diff --git a/test_onnx.py b/test_onnx.py index ab7f718..55650a7 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -45,7 +45,7 @@ for name in models: onnx_path = f'saved/onnx/{saved_name}.onnx' try: - op = TimmImage(model_name=name, device='cpu') + op = ops.image_embedding.timm(model_name=name, device='cpu').get_op() data = torch.rand((1,) + op.config['input_size']) except Exception as e: print(f'***Please re-download model {name}.***') diff --git a/timm_image.py b/timm_image.py index 3d68bb3..ef01ebc 100644 --- a/timm_image.py +++ b/timm_image.py @@ -41,6 +41,13 @@ warnings.filterwarnings('ignore') log = logging.getLogger('timm_op') +def torch_no_grad(f): + def wrap(*args, **kwargs): + with torch.no_grad(): + f(*args, **kwargs) + return wrap + + # @accelerate class Model: def __init__(self, model_name, device, num_classes): @@ -88,6 +95,7 @@ class TimmImage(NNOperator): log.warning('The operator is initialized without specified model.') pass + @torch_no_grad def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): if not isinstance(data, list): imgs = [data] @@ -270,3 +278,12 @@ class TimmImage(NNOperator): return ['onnx'] else: return [] + + def input_schema(self): + return [(Image, (-1, -1, 3))] + + def output_schema(self): + image = Image(numpy.random.randn(480, 480, 3), "RGB") + ret = self(image) + data_type = type(ret.reshape(-1)[0]) + return [(data_type, ret.shape)]