|
@ -35,32 +35,32 @@ import towhee |
|
|
torch.manual_seed(42) |
|
|
torch.manual_seed(42) |
|
|
|
|
|
|
|
|
batch_size = 8 |
|
|
batch_size = 8 |
|
|
experts = {"audio": torch.rand(batch_size, 29, 128), |
|
|
|
|
|
"face": torch.rand(batch_size, 512), |
|
|
|
|
|
"i3d.i3d.0": torch.rand(batch_size, 1024), |
|
|
|
|
|
"imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048), |
|
|
|
|
|
"imagenet.senet154.0": torch.rand(batch_size, 2048), |
|
|
|
|
|
"ocr": torch.rand(batch_size, 49, 300), |
|
|
|
|
|
"r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512), |
|
|
|
|
|
"scene.densenet161.0": torch.rand(batch_size, 2208), |
|
|
|
|
|
"speech": torch.rand(batch_size, 32, 300) |
|
|
|
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
experts = {"audio": torch.rand(batch_size, 29, 128).to(device), |
|
|
|
|
|
"face": torch.rand(batch_size, 512).to(device), |
|
|
|
|
|
"i3d.i3d.0": torch.rand(batch_size, 1024).to(device), |
|
|
|
|
|
"imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048).to(device), |
|
|
|
|
|
"imagenet.senet154.0": torch.rand(batch_size, 2048).to(device), |
|
|
|
|
|
"ocr": torch.rand(batch_size, 49, 300).to(device), |
|
|
|
|
|
"r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512).to(device), |
|
|
|
|
|
"scene.densenet161.0": torch.rand(batch_size, 2208).to(device), |
|
|
|
|
|
"speech": torch.rand(batch_size, 32, 300).to(device) |
|
|
} |
|
|
} |
|
|
ind = { |
|
|
ind = { |
|
|
"r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,), |
|
|
|
|
|
"imagenet.senet154.0": torch.ones(batch_size,), |
|
|
|
|
|
"imagenet.resnext101_32x48d.0": torch.ones(batch_size,), |
|
|
|
|
|
"scene.densenet161.0": torch.ones(batch_size,), |
|
|
|
|
|
"audio": torch.ones(batch_size,), |
|
|
|
|
|
"speech": torch.ones(batch_size,), |
|
|
|
|
|
"ocr": torch.randint(low=0, high=2, size=(batch_size,)), |
|
|
|
|
|
"face": torch.randint(low=0, high=2, size=(batch_size,)), |
|
|
|
|
|
"i3d.i3d.0": torch.ones(batch_size,), |
|
|
|
|
|
|
|
|
"r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"imagenet.senet154.0": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"imagenet.resnext101_32x48d.0": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"scene.densenet161.0": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"audio": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"speech": torch.ones(batch_size,).to(device), |
|
|
|
|
|
"ocr": torch.randint(low=0, high=2, size=(batch_size,)).to(device), |
|
|
|
|
|
"face": torch.randint(low=0, high=2, size=(batch_size,)).to(device), |
|
|
|
|
|
"i3d.i3d.0": torch.ones(batch_size,).to(device), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
text = torch.randn(batch_size, 1, 37, 768) |
|
|
|
|
|
|
|
|
text = torch.randn(batch_size, 1, 37, 768).to(device) |
|
|
|
|
|
|
|
|
towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \ |
|
|
towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \ |
|
|
.video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]().show() |
|
|
|
|
|
|
|
|
.video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')](device=device).show() |
|
|
``` |
|
|
``` |
|
|
|
|
|
|
|
|
 |
|
|
 |
|
|