logo
Browse Source

add 16 frames model

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
947a323898
  1. 6
      tsm.py

6
tsm.py

@ -57,9 +57,12 @@ class Tsm(NNOperator):
else:
self.classmap = classmap
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
if model_name == 'tsm_k400_r50_seg8':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
elif model_name == 'tsm_k400_r50_seg16':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
if model_name == 'tsm_k400_r50_seg8':
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
@ -68,7 +71,6 @@ class Tsm(NNOperator):
std=self.model.input_std,
)
elif model_name == 'tsm_k400_r50_seg16':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,

Loading…
Cancel
Save