Browse Source
add 16 frames model
Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5
2 years ago
1 changed files with
1 additions and
1 deletions
-
tsm.py
|
|
@ -57,6 +57,7 @@ class Tsm(NNOperator): |
|
|
|
else: |
|
|
|
self.classmap = classmap |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
|
if model_name == 'tsm_k400_r50_seg8': |
|
|
|
self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') |
|
|
|
self.transform_cfgs = get_configs( |
|
|
@ -75,7 +76,6 @@ class Tsm(NNOperator): |
|
|
|
mean=self.model.input_mean, |
|
|
|
std=self.model.input_std, |
|
|
|
) |
|
|
|
self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
def __call__(self, video: List[VideoFrame]): |
|
|
|