Browse Source
Add torch no_grad in inference
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
18 additions and
1 deletions
-
test_onnx.py
-
timm_image.py
|
|
@ -45,7 +45,7 @@ for name in models: |
|
|
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|
|
|
|
|
|
|
try: |
|
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
|
op = ops.image_embedding.timm(model_name=name, device='cpu').get_op() |
|
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
|
except Exception as e: |
|
|
|
print(f'***Please re-download model {name}.***') |
|
|
|
|
|
@ -41,6 +41,13 @@ warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger('timm_op') |
|
|
|
|
|
|
|
|
|
|
|
def torch_no_grad(f): |
|
|
|
def wrap(*args, **kwargs): |
|
|
|
with torch.no_grad(): |
|
|
|
f(*args, **kwargs) |
|
|
|
return wrap |
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, device, num_classes): |
|
|
@ -88,6 +95,7 @@ class TimmImage(NNOperator): |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
|
|
|
|
|
@torch_no_grad |
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
|
if not isinstance(data, list): |
|
|
|
imgs = [data] |
|
|
@ -270,3 +278,12 @@ class TimmImage(NNOperator): |
|
|
|
return ['onnx'] |
|
|
|
else: |
|
|
|
return [] |
|
|
|
|
|
|
|
def input_schema(self): |
|
|
|
return [(Image, (-1, -1, 3))] |
|
|
|
|
|
|
|
def output_schema(self): |
|
|
|
image = Image(numpy.random.randn(480, 480, 3), "RGB") |
|
|
|
ret = self(image) |
|
|
|
data_type = type(ret.reshape(-1)[0]) |
|
|
|
return [(data_type, ret.shape)] |
|
|
|