diff --git a/timesformer.py b/timesformer.py index 04914f6..8fbc73f 100644 --- a/timesformer.py +++ b/timesformer.py @@ -28,8 +28,8 @@ class Timesformer(NNOperator): - timesformer_k400_8x224 skip_preprocess (`str`): Flag to skip video transforms. - classmap (`str=None`): - Path of the json file to match class names. + classmap (`dict=None`): + The dictionary maps classes to integers. topk (`int=5`): The number of classification labels to be returned (ordered by possibility from high to low). """ @@ -37,7 +37,7 @@ class Timesformer(NNOperator): model_name: str = 'timesformer_k400_8x224', framework: str = 'pytorch', skip_preprocess: bool = False, - classmap: str = None, + classmap: dict = None, topk: int = 5, ): super().__init__(framework=framework)