Browse Source
add movinet
Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5
3 years ago
1 changed files with
3 additions and
3 deletions
-
movinet.py
|
|
@ -41,14 +41,14 @@ class Movinet(NNOperator): |
|
|
|
def __init__(self, |
|
|
|
model_name: str = 'movineta0', |
|
|
|
framework: str = 'pytorch', |
|
|
|
casual: str = False, |
|
|
|
causal: str = False, |
|
|
|
skip_preprocess: bool = False, |
|
|
|
classmap: dict = None, |
|
|
|
topk: int = 5, |
|
|
|
): |
|
|
|
super().__init__(framework=framework) |
|
|
|
self.model_name = model_name |
|
|
|
self.casual = casual |
|
|
|
self.causal = causal |
|
|
|
self.skip_preprocess = skip_preprocess |
|
|
|
self.topk = topk |
|
|
|
self.dataset_name = 'kinetics_600' |
|
|
@ -65,7 +65,7 @@ class Movinet(NNOperator): |
|
|
|
else: |
|
|
|
self.classmap = classmap |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model = create_model(model_name=model_name, pretrained=True, casual=self.casual, device=self.device) |
|
|
|
self.model = create_model(model_name=model_name, pretrained=True, causal=self.causal, device=self.device) |
|
|
|
self.input_mean=[0.485, 0.456, 0.406] |
|
|
|
self.input_std=[0.229, 0.224, 0.225] |
|
|
|
self.transform_cfgs = get_configs( |
|
|
|