logo
Browse Source

add movinet

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 3 years ago
parent
commit
e0b7c347f8
  1. 6
      movinet.py

6
movinet.py

@ -41,14 +41,14 @@ class Movinet(NNOperator):
def __init__(self,
model_name: str = 'movineta0',
framework: str = 'pytorch',
casual: str = False,
causal: str = False,
skip_preprocess: bool = False,
classmap: dict = None,
topk: int = 5,
):
super().__init__(framework=framework)
self.model_name = model_name
self.casual = casual
self.causal = causal
self.skip_preprocess = skip_preprocess
self.topk = topk
self.dataset_name = 'kinetics_600'
@ -65,7 +65,7 @@ class Movinet(NNOperator):
else:
self.classmap = classmap
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = create_model(model_name=model_name, pretrained=True, casual=self.casual, device=self.device)
self.model = create_model(model_name=model_name, pretrained=True, causal=self.causal, device=self.device)
self.input_mean=[0.485, 0.456, 0.406]
self.input_std=[0.229, 0.224, 0.225]
self.transform_cfgs = get_configs(

Loading…
Cancel
Save