diff --git a/video_swin_transformer.py b/video_swin_transformer.py index db5d0f5..9a1c67f 100644 --- a/video_swin_transformer.py +++ b/video_swin_transformer.py @@ -57,7 +57,7 @@ class VideoSwinTransformer(NNOperator): self.model = video_swin_transformer.create_model(model_name=self.model_name, pretrained=True, device=self.device) - + self.model.to(self.device) self.transform_cfgs = get_configs( side_size=224, crop_size=224,