logo
Browse Source

add 16 frames model

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
68dc79bf50
  1. 1
      README.md
  2. BIN
      TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth
  3. 30
      tsm.py

1
README.md

@ -68,6 +68,7 @@ model_name='tsm_k400_r50_seg8', skip_preprocess=False, classmap=None, topk=5)***
​ Supported model names:
- tsm_k400_r50_seg8
- tsm_k400_r50_seg16
***skip_preprocess***: *bool*

BIN
TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth (Stored with Git LFS)

Binary file not shown.

30
tsm.py

@ -24,12 +24,13 @@ class Tsm(NNOperator):
model_name (`str`):
Supported model names:
- tsm_k400_r50_seg8
- tsm_k400_r50_seg16
skip_preprocess (`bool`):
Flag to skip video transforms.
predict (`bool`):
Flag to control whether to predict labels. If False, the video embedding is returned instead.
classmap (`str=None`):
Path of the json file to match class names.
classmap (`dict=None`):
The dictionary maps classes to integers.
topk (`int=5`):
The number of classification labels to be returned (ordered by probability from high to low).
"""
@ -37,7 +38,7 @@ class Tsm(NNOperator):
model_name: str = 'tsm_k400_r50_seg8',
framework: str = 'pytorch',
skip_preprocess: bool = False,
classmap: str = None,
classmap: dict = None,
topk: int = 5,
):
super().__init__(framework=framework)
@ -58,15 +59,24 @@ class Tsm(NNOperator):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if model_name == 'tsm_k400_r50_seg8':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=8,
mean=self.model.input_mean,
std=self.model.input_std,
)
elif model_name == 'tsm_k400_r50_seg16':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=16,
mean=self.model.input_mean,
std=self.model.input_std,
)
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
self.model.eval()
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=8,
mean=self.model.input_mean,
std=self.model.input_std,
)
def __call__(self, video: List[VideoFrame]):
"""

Loading…
Cancel
Save