logo
Browse Source

Add test for onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
d815531df0
  1. 167
      test_onnx.py
  2. 35
      timm_image.py

167
test_onnx.py

@ -1,110 +1,97 @@
from towhee import ops
from timm_image import TimmImage
import torch
import numpy
import onnx
import onnxruntime
f = open('onnx.csv', 'a+')
f.write('model_name, run_op, save_onnx, check_onnx\n')
import os
from pathlib import Path
import logging
import platform
import psutil
# models = TimmImage.supported_model_names()[:1]
models = [
'vgg11',
'resnet18',
'resnetv2_50',
'seresnet33ts',
'skresnet18',
'resnext26ts',
'seresnext26d_32x4d',
'skresnext50_32x4d',
'convit_base',
'inception_v4',
'efficientnet_b0',
'tf_efficientnet_b0',
'swin_base_patch4_window7_224',
'vit_base_patch8_224',
'beit_base_patch16_224',
'convnext_base',
'crossvit_9_240',
'convmixer_768_32',
'coat_lite_mini',
'inception_v3',
'cait_m36_384',
'cspdarknet53',
'deit_base_distilled_patch16_224',
'densenet121',
'dla34',
'dm_nfnet_f0',
'nf_regnet_b1',
'nf_resnet50',
'dpn68',
'ese_vovnet19b_dw',
'fbnetc_100',
'fbnetv3_b',
'halonet26t',
'eca_halonext26ts',
'sehalonet33ts',
'hardcorenas_a',
'hrnet_w18',
'jx_nest_base',
'lcnet_050',
'levit_128',
'mixer_b16_224',
'mixnet_s',
'mnasnet_100',
'mobilenetv2_050',
'mobilenetv3_large_100',
'nasnetalarge',
'pit_b_224',
'pnasnet5large',
'regnetx_002',
'repvgg_a2',
'res2net50_14w_8s',
'res2next50',
'resmlp_12_224',
'resnest14d',
'rexnet_100',
'selecsls42b',
'semnasnet_075',
'tinynet_a',
'tnt_s_patch16_224',
'tresnet_l',
'twins_pcpvt_base',
'visformer_small',
'xception',
'xcit_large_24_p8_224',
'ghostnet_100',
'gmlp_s16_224',
'lambda_resnet26rpt_256',
'spnasnet_100',
]
models = ['resnet50']
decoder = ops.image_decode()
data = decoder('./towhee.jpeg')
atol = 1e-3
log_path = 'timm_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
logger = logging.getLogger('timm_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
status = None
for name in models:
f.write(f'{name},')
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
try:
op = TimmImage(model_name=name)
out1 = op(data)
f.write('success,')
out1 = op.model(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:
f.write('fail\n')
print(f'Fail to load op for {name}: {e}')
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
f.write('success,')
logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e:
logger.error(f'FAIL TO SAVE ONNX: {e}')
continue
try:
try:
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
except Exception:
saved_onnx = onnx.load(onnx_path, load_external_data=False)
onnx.checker.check_model(saved_onnx)
logger.info('ONNX CHECKED.')
status[3] = 'success'
except Exception as e:
f.write('fail\n')
print(f'Fail to save onnx for {name}: {e}')
logger.error(f'FAIL TO CHECK ONNX: {e}')
continue
try:
saved_name = name.replace('/', '-')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx.checker.check_model(onnx_model)
f.write('success')
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol):
logger.info('Check accuracy: OK')
status[5] = 'success'
else:
logger.info(f'Check accuracy: atol is larger than {atol}.')
except Exception as e:
f.write('fail\n')
print(f'Fail to check onnx for {name}: {e}')
logger.error(f'FAIL TO RUN ONNX: {e}')
continue
f.write('\n')
if status:
f.write(','.join(status) + '\n')
print('Finished.')

35
timm_image.py

@ -37,7 +37,7 @@ from timm.models.factory import create_model
import warnings
warnings.filterwarnings('ignore')
log = logging.getLogger()
log = logging.getLogger('timm_op')
@register(output_schema=['vec'])
@ -53,13 +53,22 @@ class TimmImage(NNOperator):
Whether skip image transforms.
"""
def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None:
def __init__(self,
model_name: str = 'resnet50',
num_classes: int = 1000,
skip_preprocess: bool = False,
pretrained: bool = True,
device: str = None
) -> None:
super().__init__()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.to(self.device)
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
@ -120,9 +129,15 @@ class TimmImage(NNOperator):
torch.onnx.export(self.model,
dummy_input,
path,
input_names=["input"],
output_names=["output"],
opset_version=12)
input_names=['input_0'],
output_names=['output_0'],
opset_version=13,
dynamic_axes={
'input_0': {0: 'batch_size'},
'output_0': {0: 'batch_size'}
},
do_constant_folding=True
)
except Exception as e:
log.error(f'Fail to save as onnx: {e}.')
raise RuntimeError(f'Fail to save as onnx: {e}.')
@ -140,6 +155,10 @@ class TimmImage(NNOperator):
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'torchscript':
# todo: elif format == 'tensorrt'
else:

Loading…
Cancel
Save