diff --git a/tsm.py b/tsm.py index 808fe3d..268023d 100644 --- a/tsm.py +++ b/tsm.py @@ -57,6 +57,7 @@ class Tsm(NNOperator): else: self.classmap = classmap self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) if model_name == 'tsm_k400_r50_seg8': self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') self.transform_cfgs = get_configs( @@ -75,7 +76,6 @@ class Tsm(NNOperator): mean=self.model.input_mean, std=self.model.input_std, ) - self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) self.model.eval() def __call__(self, video: List[VideoFrame]):