diff --git a/pytorchvideo.py b/pytorchvideo.py index 69ce409..86c32bd 100644 --- a/pytorchvideo.py +++ b/pytorchvideo.py @@ -38,8 +38,8 @@ class PytorchVideo(NNOperator): - mvit_base_32x3 skip_preprocess (`str`): Flag to skip video transforms. - classmap (`str=None`): - Path of the json file to match class names. + classmap (`dict=None`): + The dictionary maps classes to integers. topk (`int=5`): The number of classification labels to be returned (ordered by possibility from high to low). """ @@ -49,7 +49,7 @@ class PytorchVideo(NNOperator): model_name: str = 'x3d_xs', framework: str = 'pytorch', skip_preprocess: bool = False, - classmap: str = None, + classmap: dict = None, topk: int = 5, ) -> None: super().__init__(framework=framework)