diff --git a/test_onnx.py b/test_onnx.py index 55650a7..294ebce 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -57,6 +57,8 @@ for name in models: status = [name] + ['fail'] * 5 try: + out0 = op(data.cpu().detach().numpy().squeeze(0)) + assert out0.shape out1 = op.model(data).detach().numpy() logger.info('OP LOADED.') status[1] = 'success' diff --git a/timm_image.py b/timm_image.py index ef01ebc..5da4d2f 100644 --- a/timm_image.py +++ b/timm_image.py @@ -44,7 +44,7 @@ log = logging.getLogger('timm_op') def torch_no_grad(f): def wrap(*args, **kwargs): with torch.no_grad(): - f(*args, **kwargs) + return f(*args, **kwargs) return wrap