From ae13361b67361a649b8c97a62d035638be13ff51 Mon Sep 17 00:00:00 2001 From: gexy5 Date: Mon, 13 Jun 2022 17:25:33 +0800 Subject: [PATCH] modify Signed-off-by: gexy5 --- tsm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tsm.py b/tsm.py index 4bcdc87..8e00784 100644 --- a/tsm.py +++ b/tsm.py @@ -60,6 +60,7 @@ class Tsm(NNOperator): if model_name == 'tsm_k400_r50_seg8': self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) + self.model.eval() self.transform_cfgs = get_configs( side_size=224, crop_size=224, @@ -95,8 +96,6 @@ class Tsm(NNOperator): ) inputs = data.to(self.device)[None, ...] - self.model.eval() - feats = self.model.forward_features(inputs) if self.model.reshape: if self.model.is_shift and self.model.temporal_pool: