logo
Browse Source

add 16 frames model

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
6b0470d6aa
  1. 2
      tsm.py

2
tsm.py

@ -57,6 +57,7 @@ class Tsm(NNOperator):
else: else:
self.classmap = classmap self.classmap = classmap
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
if model_name == 'tsm_k400_r50_seg8': if model_name == 'tsm_k400_r50_seg8':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
self.transform_cfgs = get_configs( self.transform_cfgs = get_configs(
@ -75,7 +76,6 @@ class Tsm(NNOperator):
mean=self.model.input_mean, mean=self.model.input_mean,
std=self.model.input_std, std=self.model.input_std,
) )
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
self.model.eval() self.model.eval()
def __call__(self, video: List[VideoFrame]): def __call__(self, video: List[VideoFrame]):

Loading…
Cancel
Save