logo
Browse Source

repair example in README.md

main
ChengZi 2 years ago
parent
commit
f4c598e7dd
  1. 42
      README.md

42
README.md

@ -35,32 +35,32 @@ import towhee
torch.manual_seed(42)
batch_size = 8
experts = {"audio": torch.rand(batch_size, 29, 128),
"face": torch.rand(batch_size, 512),
"i3d.i3d.0": torch.rand(batch_size, 1024),
"imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048),
"imagenet.senet154.0": torch.rand(batch_size, 2048),
"ocr": torch.rand(batch_size, 49, 300),
"r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512),
"scene.densenet161.0": torch.rand(batch_size, 2208),
"speech": torch.rand(batch_size, 32, 300)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
experts = {"audio": torch.rand(batch_size, 29, 128).to(device),
"face": torch.rand(batch_size, 512).to(device),
"i3d.i3d.0": torch.rand(batch_size, 1024).to(device),
"imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048).to(device),
"imagenet.senet154.0": torch.rand(batch_size, 2048).to(device),
"ocr": torch.rand(batch_size, 49, 300).to(device),
"r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512).to(device),
"scene.densenet161.0": torch.rand(batch_size, 2208).to(device),
"speech": torch.rand(batch_size, 32, 300).to(device)
}
ind = {
"r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,),
"imagenet.senet154.0": torch.ones(batch_size,),
"imagenet.resnext101_32x48d.0": torch.ones(batch_size,),
"scene.densenet161.0": torch.ones(batch_size,),
"audio": torch.ones(batch_size,),
"speech": torch.ones(batch_size,),
"ocr": torch.randint(low=0, high=2, size=(batch_size,)),
"face": torch.randint(low=0, high=2, size=(batch_size,)),
"i3d.i3d.0": torch.ones(batch_size,),
"r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,).to(device),
"imagenet.senet154.0": torch.ones(batch_size,).to(device),
"imagenet.resnext101_32x48d.0": torch.ones(batch_size,).to(device),
"scene.densenet161.0": torch.ones(batch_size,).to(device),
"audio": torch.ones(batch_size,).to(device),
"speech": torch.ones(batch_size,).to(device),
"ocr": torch.randint(low=0, high=2, size=(batch_size,)).to(device),
"face": torch.randint(low=0, high=2, size=(batch_size,)).to(device),
"i3d.i3d.0": torch.ones(batch_size,).to(device),
}
text = torch.randn(batch_size, 1, 37, 768)
text = torch.randn(batch_size, 1, 37, 768).to(device)
towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
.video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]().show()
.video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')](device=device).show()
```
![](img.png)

Loading…
Cancel
Save