@@ -24,12 +24,13 @@ class Tsm(NNOperator):
         model_name (`str`):
             Supported model names:
             - tsm_k400_r50_seg8
+            - tsm_k400_r50_seg16
         skip_preprocess (`str`):
             Flag to skip video transforms.
         predict (`bool`):
             Flag to control whether predict labels. If False, then return video embedding.
-        classmap (`str=None`):
-            Path of the json file to match class names.
+        classmap (`dict=None`):
+            The dictionary maps classes to integers.
         topk (`int=5`):
             The number of classification labels to be returned (ordered by possibility from high to low).
     """
@@ -37,7 +38,7 @@ class Tsm(NNOperator):
                  model_name: str = 'tsm_k400_r50_seg8',
                  framework: str = 'pytorch',
                  skip_preprocess: bool = False,
-                 classmap: str = None,
+                 classmap: dict = None,
                  topk: int = 5,
                  ):
         super().__init__(framework=framework)
@@ -58,15 +59,24 @@ class Tsm(NNOperator):
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if model_name == 'tsm_k400_r50_seg8':
             self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
+            self.transform_cfgs = get_configs(
+                side_size=256,
+                crop_size=224,
+                num_frames=8,
+                mean=self.model.input_mean,
+                std=self.model.input_std,
+            )
+        elif model_name == 'tsm_k400_r50_seg16':
+            self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment16_e50.pth')
+            self.transform_cfgs = get_configs(
+                side_size=256,
+                crop_size=224,
+                num_frames=16,
+                mean=self.model.input_mean,
+                std=self.model.input_std,
+            )
         self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device)
         self.model.eval()
-        self.transform_cfgs = get_configs(
-            side_size=256,
-            crop_size=224,
-            num_frames=8,
-            mean=self.model.input_mean,
-            std=self.model.input_std,
-        )
 
     def __call__(self, video: List[VideoFrame]):
         """