diff --git a/pytorchvideo.py b/pytorchvideo.py index 5808e77..7aa32bd 100644 --- a/pytorchvideo.py +++ b/pytorchvideo.py @@ -65,8 +65,8 @@ class PytorchVideo(NNOperator): self.classmap[v] = str(k).replace('"', '') else: self.classmap = classmap - # todo: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = 'cpu' + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + # self.device = 'cpu' self.model = torch.hub.load('facebookresearch/pytorchvideo', model=model_name, pretrained=True) self.model.eval()