diff --git a/README.md b/README.md index 9ce3d02..ab96b8f 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ model_name='tsm_k400_r50_seg8', skip_preprocess=False, classmap=None, topk=5)*** ​ Supported model names: - tsm_k400_r50_seg8 +- tsm_k400_r50_seg16 ​ ***skip_preprocess***: *bool* diff --git a/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth b/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth new file mode 100644 index 0000000..8cca177 --- /dev/null +++ b/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb1f718819b94efc2dd71125e76b7ce64c529bc15eafbe81fc091ab1faef16c7 +size 97587673 diff --git a/tsm.py b/tsm.py index d607fa4..808fe3d 100644 --- a/tsm.py +++ b/tsm.py @@ -24,12 +24,13 @@ class Tsm(NNOperator): model_name (`str`): Supported model names: - tsm_k400_r50_seg8 + - tsm_k400_r50_seg16 skip_preprocess (`str`): Flag to skip video transforms. predict (`bool`): Flag to control whether predict labels. If False, then return video embedding. - classmap (`str=None`): - Path of the json file to match class names. + classmap (`dict=None`): + The dictionary maps classes to integers. topk (`int=5`): The number of classification labels to be returned (ordered by possibility from high to low). 
"""
@@ -37,7 +38,7 @@ class Tsm(NNOperator):
                  model_name: str = 'tsm_k400_r50_seg8',
                  framework: str = 'pytorch',
                  skip_preprocess: bool = False,
-                 classmap: str = None,
+                 classmap: dict = None,
                  topk: int = 5,
                  ):
         super().__init__(framework=framework)
@@ -58,15 +59,23 @@ class Tsm(NNOperator):
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if model_name == 'tsm_k400_r50_seg8':
             self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
+            num_frames = 8
+        elif model_name == 'tsm_k400_r50_seg16':
+            self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
+            num_frames = 16
+        else:
+            # Fail fast with a clear message instead of a later AttributeError on self.weights_path.
+            raise ValueError(f'Unsupported model_name: {model_name}')
+        # The model must be created before get_configs() reads its input_mean / input_std.
         self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
         self.model.eval()
         self.transform_cfgs = get_configs(
             side_size=256,
             crop_size=224,
-            num_frames=8,
+            num_frames=num_frames,
             mean=self.model.input_mean,
             std=self.model.input_std,
         )
 
     def __call__(self, video: List[VideoFrame]):
         """