Browse Source
modify
Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5
2 years ago
1 changed files with
1 additions and
2 deletions
-
tsm.py
|
@ -60,6 +60,7 @@ class Tsm(NNOperator): |
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') |
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') |
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
|
|
|
self.model.eval() |
|
|
self.transform_cfgs = get_configs( |
|
|
self.transform_cfgs = get_configs( |
|
|
side_size=224, |
|
|
side_size=224, |
|
|
crop_size=224, |
|
|
crop_size=224, |
|
@ -95,8 +96,6 @@ class Tsm(NNOperator): |
|
|
) |
|
|
) |
|
|
inputs = data.to(self.device)[None, ...] |
|
|
inputs = data.to(self.device)[None, ...] |
|
|
|
|
|
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
|
|
feats = self.model.forward_features(inputs) |
|
|
feats = self.model.forward_features(inputs) |
|
|
if self.model.reshape: |
|
|
if self.model.reshape: |
|
|
if self.model.is_shift and self.model.temporal_pool: |
|
|
if self.model.is_shift and self.model.temporal_pool: |
|
|