diff --git a/pytorchvideo.py b/pytorchvideo.py index 86c32bd..5808e77 100644 --- a/pytorchvideo.py +++ b/pytorchvideo.py @@ -65,7 +65,8 @@ class PytorchVideo(NNOperator): self.classmap[v] = str(k).replace('"', '') else: self.classmap = classmap - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + # todo: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = 'cpu' self.model = torch.hub.load('facebookresearch/pytorchvideo', model=model_name, pretrained=True) self.model.eval()