|
@ -57,9 +57,12 @@ class Tsm(NNOperator): |
|
|
else: |
|
|
else: |
|
|
self.classmap = classmap |
|
|
self.classmap = classmap |
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
|
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') |
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') |
|
|
|
|
|
elif model_name == 'tsm_k400_r50_seg16': |
|
|
|
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth') |
|
|
|
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
|
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
self.transform_cfgs = get_configs( |
|
|
self.transform_cfgs = get_configs( |
|
|
side_size=256, |
|
|
side_size=256, |
|
|
crop_size=224, |
|
|
crop_size=224, |
|
@ -68,7 +71,6 @@ class Tsm(NNOperator): |
|
|
std=self.model.input_std, |
|
|
std=self.model.input_std, |
|
|
) |
|
|
) |
|
|
elif model_name == 'tsm_k400_r50_seg16': |
|
|
elif model_name == 'tsm_k400_r50_seg16': |
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth') |
|
|
|
|
|
self.transform_cfgs = get_configs( |
|
|
self.transform_cfgs = get_configs( |
|
|
side_size=256, |
|
|
side_size=256, |
|
|
crop_size=224, |
|
|
crop_size=224, |
|
|