diff --git a/tsm.py b/tsm.py index 268023d..b791be4 100644 --- a/tsm.py +++ b/tsm.py @@ -57,9 +57,12 @@ class Tsm(NNOperator): else: self.classmap = classmap self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) if model_name == 'tsm_k400_r50_seg8': self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') + elif model_name == 'tsm_k400_r50_seg16': + self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth') + self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) + if model_name == 'tsm_k400_r50_seg8': self.transform_cfgs = get_configs( side_size=256, crop_size=224, @@ -68,7 +71,6 @@ class Tsm(NNOperator): std=self.model.input_std, ) elif model_name == 'tsm_k400_r50_seg16': - self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth') self.transform_cfgs = get_configs( side_size=256, crop_size=224,