From e0b7c347f82d125f3ea10dc2e44a3a30ab898722 Mon Sep 17 00:00:00 2001 From: gexy5 Date: Thu, 16 Jun 2022 11:31:35 +0800 Subject: [PATCH] add movinet Signed-off-by: gexy5 --- movinet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/movinet.py b/movinet.py index 48d66b1..137747b 100644 --- a/movinet.py +++ b/movinet.py @@ -41,14 +41,14 @@ class Movinet(NNOperator): def __init__(self, model_name: str = 'movineta0', framework: str = 'pytorch', - casual: str = False, + causal: str = False, skip_preprocess: bool = False, classmap: dict = None, topk: int = 5, ): super().__init__(framework=framework) self.model_name = model_name - self.casual = casual + self.causal = causal self.skip_preprocess = skip_preprocess self.topk = topk self.dataset_name = 'kinetics_600' @@ -65,7 +65,7 @@ class Movinet(NNOperator): else: self.classmap = classmap self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = create_model(model_name=model_name, pretrained=True, casual=self.casual, device=self.device) + self.model = create_model(model_name=model_name, pretrained=True, causal=self.causal, device=self.device) self.input_mean=[0.485, 0.456, 0.406] self.input_std=[0.229, 0.224, 0.225] self.transform_cfgs = get_configs(