diff --git a/timm_image.py b/timm_image.py index aafd30b..3070762 100644 --- a/timm_image.py +++ b/timm_image.py @@ -22,6 +22,7 @@ import towhee from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register +from towhee.types import Image import torch from torch import nn @@ -144,3 +145,12 @@ class TimmImage(NNOperator): else: log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') return model_list + + def input_schema(self): + return [(Image, (-1, -1, 3))] + + def output_schema(self): + image = Image(numpy.random.randn(480, 480, 3), "RGB") + ret = self(image) + data_type = type(ret.reshape(-1)[0]) + return [(data_type, ret.shape)]