diff --git a/README.md b/README.md index ad9330f..d49999f 100644 --- a/README.md +++ b/README.md @@ -35,32 +35,32 @@ import towhee torch.manual_seed(42) batch_size = 8 -experts = {"audio": torch.rand(batch_size, 29, 128), - "face": torch.rand(batch_size, 512), - "i3d.i3d.0": torch.rand(batch_size, 1024), - "imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048), - "imagenet.senet154.0": torch.rand(batch_size, 2048), - "ocr": torch.rand(batch_size, 49, 300), - "r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512), - "scene.densenet161.0": torch.rand(batch_size, 2208), - "speech": torch.rand(batch_size, 32, 300) +device = 'cuda' if torch.cuda.is_available() else 'cpu' +experts = {"audio": torch.rand(batch_size, 29, 128).to(device), + "face": torch.rand(batch_size, 512).to(device), + "i3d.i3d.0": torch.rand(batch_size, 1024).to(device), + "imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048).to(device), + "imagenet.senet154.0": torch.rand(batch_size, 2048).to(device), + "ocr": torch.rand(batch_size, 49, 300).to(device), + "r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512).to(device), + "scene.densenet161.0": torch.rand(batch_size, 2208).to(device), + "speech": torch.rand(batch_size, 32, 300).to(device) } ind = { - "r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,), - "imagenet.senet154.0": torch.ones(batch_size,), - "imagenet.resnext101_32x48d.0": torch.ones(batch_size,), - "scene.densenet161.0": torch.ones(batch_size,), - "audio": torch.ones(batch_size,), - "speech": torch.ones(batch_size,), - "ocr": torch.randint(low=0, high=2, size=(batch_size,)), - "face": torch.randint(low=0, high=2, size=(batch_size,)), - "i3d.i3d.0": torch.ones(batch_size,), + "r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,).to(device), + "imagenet.senet154.0": torch.ones(batch_size,).to(device), + "imagenet.resnext101_32x48d.0": torch.ones(batch_size,).to(device), + "scene.densenet161.0": torch.ones(batch_size,).to(device), + "audio": torch.ones(batch_size,).to(device), + "speech": torch.ones(batch_size,).to(device), + "ocr": torch.randint(low=0, high=2, size=(batch_size,)).to(device), + "face": torch.randint(low=0, high=2, size=(batch_size,)).to(device), + "i3d.i3d.0": torch.ones(batch_size,).to(device), } - -text = torch.randn(batch_size, 1, 37, 768) +text = torch.randn(batch_size, 1, 37, 768).to(device) towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \ - .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]().show() + .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')](device=device).show() ``` ![](img.png)