logo
Browse Source

Add onnx test

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
fd3109ff51
  1. 60
      isc.py
  2. 102
      test_onnx.py

60
isc.py

@@ -33,7 +33,7 @@ import timm
import warnings
warnings.filterwarnings('ignore')
log = logging.getLogger()
log = logging.getLogger('isc_op')
@register(output_schema=['vec'])
@@ -46,13 +46,19 @@ class Isc(NNOperator):
Whether skip image transforms.
"""
def __init__(self, timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k',
skip_preprocess: bool = False, checkpoint_path: str = None, device: str = None) -> None:
def __init__(self,
timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k',
img_size: int = 512,
checkpoint_path: str = None,
skip_preprocess: bool = False,
device: str = None) -> None:
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.skip_tfms = skip_preprocess
self.timm_backbone = timm_backbone
if checkpoint_path is None:
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')
@@ -62,7 +68,7 @@ class Isc(NNOperator):
self.model.eval()
self.tfms = transforms.Compose([
transforms.Resize((512, 512)),
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=backbone.default_cfg['mean'],
std=backbone.default_cfg['std'])
@@ -82,7 +88,7 @@ class Isc(NNOperator):
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
features = self.model(inputs)
features = features.to('cpu').flatten(1)
features = features.to('cpu')
if isinstance(data, list):
vecs = list(features.detach().numpy())
@@ -90,6 +96,50 @@ class Isc(NNOperator):
vecs = features.squeeze(0).detach().numpy()
return vecs
def save_model(self, format: str = 'pytorch', path: str = 'default'):
    """Export the wrapped backbone model to disk.

    Args:
        format: Export format — 'pytorch' (torch.save of the full model),
            'torchscript' (jit.script with jit.trace fallback), or 'onnx'.
        path: Destination directory; 'default' resolves to this file's
            directory. The file is written under `<path>/saved/<format>/`
            as `<backbone-name>.pt` or `<backbone-name>.onnx`.

    Raises:
        RuntimeError: if torchscript or onnx export fails.
        ValueError: if `format` is not one of the supported values.
    """
    if path == 'default':
        path = str(Path(__file__).parent)
    path = os.path.join(path, 'saved', format)
    os.makedirs(path, exist_ok=True)
    # Some hub backbone names contain '/', which is illegal in file names.
    name = self.timm_backbone.replace('/', '-')
    path = os.path.join(path, name)
    # NOTE(review): fixed 224x224 dummy input — presumably acceptable
    # because the ONNX export declares dynamic height/width axes, but
    # confirm the backbone tolerates 224 when configured img_size differs.
    dummy_input = torch.rand(1, 3, 224, 224)
    if format == 'pytorch':
        path = path + '.pt'
        torch.save(self.model, path)
    elif format == 'torchscript':
        path = path + '.pt'
        try:
            try:
                # Prefer scripting; fall back to tracing for models that
                # torch.jit.script cannot handle.
                jit_model = torch.jit.script(self.model)
            except Exception:
                jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
            torch.jit.save(jit_model, path)
        except Exception as e:
            log.error(f'Fail to save as torchscript: {e}.')
            raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
    elif format == 'onnx':
        path = path + '.onnx'
        try:
            torch.onnx.export(self.model,
                              dummy_input,
                              path,
                              input_names=['input_0'],
                              output_names=['output_0'],
                              opset_version=14,
                              dynamic_axes={
                                  'input_0': {0: 'batch_size', 2: 'height', 3: 'width'},
                                  'output_0': {0: 'batch_size', 1: 'dim'}
                              },
                              do_constant_folding=True
                              )
        except Exception as e:
            log.error(f'Fail to save as onnx: {e}.')
            raise RuntimeError(f'Fail to save as onnx: {e}.') from e
    # todo: elif format == 'tensorrt':
    else:
        log.error(f'Unsupported format "{format}".')
        # Fix: previously an unsupported format only logged and returned
        # None silently; raise so callers learn nothing was exported.
        raise ValueError(f'Unsupported format "{format}".')
@arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB')

102
test_onnx.py

@@ -0,0 +1,102 @@
# Smoke test: export each listed ISC backbone to ONNX and compare
# onnxruntime outputs against the original PyTorch model. Results are
# appended to onnx.csv and detailed logs go to isc_onnx.log.
import onnx
from isc import Isc
from towhee import ops
import torch
import numpy
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil

models = ['tf_efficientnetv2_m_in21ft1k']
atol = 1e-3  # max elementwise tolerance for PyTorch-vs-ONNX comparison
log_path = 'isc_onnx.log'

f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

logger = logging.getLogger('isc_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')

status = None
for name in models:
    # Flush the previous model's status row before starting this model,
    # so rows are written even when an iteration bails out early.
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    try:
        op = Isc(timm_backbone=name, device='cpu')
    except Exception as e:
        logger.error(f'Fail to load model {name}. Please check weights.')
        # Fix: without this continue, `op` is undefined below and the
        # whole script dies with NameError on the first broken model.
        continue
    data = torch.rand(1, 3, 224, 224)
    try:
        out1 = op.model(data).detach().numpy()
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Large models may store weights as external data; retry
            # the structural check without loading them.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        # Check failure is non-fatal: still try running the model below.
        logger.error(f'FAIL TO CHECK ONNX: {e}')
    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue

# Flush the final model's status row and release the CSV handle.
if status:
    f.write(','.join(status) + '\n')
f.close()
print('Finished.')
Loading…
Cancel
Save