logo
Browse Source

modify

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
ae13361b67
  1. 3
      tsm.py

3
tsm.py

@ -60,6 +60,7 @@ class Tsm(NNOperator):
if model_name == 'tsm_k400_r50_seg8': if model_name == 'tsm_k400_r50_seg8':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
self.model.eval()
self.transform_cfgs = get_configs( self.transform_cfgs = get_configs(
side_size=224, side_size=224,
crop_size=224, crop_size=224,
@ -95,8 +96,6 @@ class Tsm(NNOperator):
) )
inputs = data.to(self.device)[None, ...] inputs = data.to(self.device)[None, ...]
self.model.eval()
feats = self.model.forward_features(inputs) feats = self.model.forward_features(inputs)
if self.model.reshape: if self.model.reshape:
if self.model.is_shift and self.model.temporal_pool: if self.model.is_shift and self.model.temporal_pool:

Loading…
Cancel
Save