logo
Browse Source

Add torch no_grad in inference

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
efb75f26b3
  1. 2
      test_onnx.py
  2. 17
      timm_image.py

2
test_onnx.py

@ -45,7 +45,7 @@ for name in models:
onnx_path = f'saved/onnx/{saved_name}.onnx'
try:
op = TimmImage(model_name=name, device='cpu')
op = ops.image_embedding.timm(model_name=name, device='cpu').get_op()
data = torch.rand((1,) + op.config['input_size'])
except Exception as e:
print(f'***Please re-download model {name}.***')

17
timm_image.py

@ -41,6 +41,13 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('timm_op')
def torch_no_grad(f):
def wrap(*args, **kwargs):
with torch.no_grad():
f(*args, **kwargs)
return wrap
# @accelerate
class Model:
def __init__(self, model_name, device, num_classes):
@ -88,6 +95,7 @@ class TimmImage(NNOperator):
log.warning('The operator is initialized without specified model.')
pass
@torch_no_grad
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
if not isinstance(data, list):
imgs = [data]
@ -270,3 +278,12 @@ class TimmImage(NNOperator):
return ['onnx']
else:
return []
def input_schema(self):
return [(Image, (-1, -1, 3))]
def output_schema(self):
image = Image(numpy.random.randn(480, 480, 3), "RGB")
ret = self(image)
data_type = type(ret.reshape(-1)[0])
return [(data_type, ret.shape)]

Loading…
Cancel
Save