logo
Browse Source

add 16 frames model

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
68dc79bf50
  1. 1
      README.md
  2. BIN
      TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth
  3. 30
      tsm.py

1
README.md

@ -68,6 +68,7 @@ model_name='tsm_k400_r50_seg8', skip_preprocess=False, classmap=None, topk=5)***
​ Supported model names: ​ Supported model names:
- tsm_k400_r50_seg8 - tsm_k400_r50_seg8
- tsm_k400_r50_seg16
***skip_preprocess***: *bool* ***skip_preprocess***: *bool*

BIN
TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth (Stored with Git LFS)

Binary file not shown.

30
tsm.py

@ -24,12 +24,13 @@ class Tsm(NNOperator):
model_name (`str`): model_name (`str`):
Supported model names: Supported model names:
- tsm_k400_r50_seg8 - tsm_k400_r50_seg8
- tsm_k400_r50_seg16
skip_preprocess (`str`): skip_preprocess (`str`):
Flag to skip video transforms. Flag to skip video transforms.
predict (`bool`): predict (`bool`):
Flag to control whether predict labels. If False, then return video embedding. Flag to control whether predict labels. If False, then return video embedding.
classmap (`str=None`):
Path of the json file to match class names.
classmap (`dict=None`):
The dictionary maps classes to integers.
topk (`int=5`): topk (`int=5`):
The number of classification labels to be returned (ordered by possibility from high to low). The number of classification labels to be returned (ordered by possibility from high to low).
""" """
@ -37,7 +38,7 @@ class Tsm(NNOperator):
model_name: str = 'tsm_k400_r50_seg8', model_name: str = 'tsm_k400_r50_seg8',
framework: str = 'pytorch', framework: str = 'pytorch',
skip_preprocess: bool = False, skip_preprocess: bool = False,
classmap: str = None,
classmap: dict = None,
topk: int = 5, topk: int = 5,
): ):
super().__init__(framework=framework) super().__init__(framework=framework)
@ -58,15 +59,24 @@ class Tsm(NNOperator):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if model_name == 'tsm_k400_r50_seg8': if model_name == 'tsm_k400_r50_seg8':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=8,
mean=self.model.input_mean,
std=self.model.input_std,
)
elif model_name == 'tsm_k400_r50_seg16':
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=16,
mean=self.model.input_mean,
std=self.model.input_std,
)
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
self.model.eval() self.model.eval()
self.transform_cfgs = get_configs(
side_size=256,
crop_size=224,
num_frames=8,
mean=self.model.input_mean,
std=self.model.input_std,
)
def __call__(self, video: List[VideoFrame]): def __call__(self, video: List[VideoFrame]):
""" """

Loading…
Cancel
Save