logo
Browse Source

Add onnx test

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
fd3109ff51
  1. 60
      isc.py
  2. 102
      test_onnx.py

60
isc.py

@ -33,7 +33,7 @@ import timm
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
log = logging.getLogger()
log = logging.getLogger('isc_op')
@register(output_schema=['vec']) @register(output_schema=['vec'])
@ -46,13 +46,19 @@ class Isc(NNOperator):
Whether skip image transforms. Whether skip image transforms.
""" """
def __init__(self, timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k',
skip_preprocess: bool = False, checkpoint_path: str = None, device: str = None) -> None:
def __init__(self,
timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k',
img_size: int = 512,
checkpoint_path: str = None,
skip_preprocess: bool = False,
device: str = None) -> None:
super().__init__() super().__init__()
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device self.device = device
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
self.timm_backbone = timm_backbone
if checkpoint_path is None: if checkpoint_path is None:
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')
@ -62,7 +68,7 @@ class Isc(NNOperator):
self.model.eval() self.model.eval()
self.tfms = transforms.Compose([ self.tfms = transforms.Compose([
transforms.Resize((512, 512)),
transforms.Resize((img_size, img_size)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=backbone.default_cfg['mean'], transforms.Normalize(mean=backbone.default_cfg['mean'],
std=backbone.default_cfg['std']) std=backbone.default_cfg['std'])
@ -82,7 +88,7 @@ class Isc(NNOperator):
inputs = torch.stack(img_list) inputs = torch.stack(img_list)
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
features = self.model(inputs) features = self.model(inputs)
features = features.to('cpu').flatten(1)
features = features.to('cpu')
if isinstance(data, list): if isinstance(data, list):
vecs = list(features.detach().numpy()) vecs = list(features.detach().numpy())
@ -90,6 +96,50 @@ class Isc(NNOperator):
vecs = features.squeeze(0).detach().numpy() vecs = features.squeeze(0).detach().numpy()
return vecs return vecs
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True)
name = self.timm_backbone.replace('/', '-')
path = os.path.join(path, name)
dummy_input = torch.rand(1, 3, 224, 224)
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
try:
try:
jit_model = torch.jit.script(self.model)
except Exception:
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
dummy_input,
path,
input_names=['input_0'],
output_names=['output_0'],
opset_version=14,
dynamic_axes={
'input_0': {0: 'batch_size', 2: 'height', 3: 'width'},
'output_0': {0: 'batch_size', 1: 'dim'}
},
do_constant_folding=True
)
except Exception as e:
log.error(f'Fail to save as onnx: {e}.')
raise RuntimeError(f'Fail to save as onnx: {e}.')
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image): def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB') img = PILImage.fromarray(img.astype('uint8'), 'RGB')

102
test_onnx.py

@ -0,0 +1,102 @@
import onnx
from isc import Isc
from towhee import ops
import torch
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil
models = ['tf_efficientnetv2_m_in21ft1k']
atol = 1e-3
log_path = 'isc_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
logger = logging.getLogger('isc_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
status = None
for name in models:
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
try:
op = Isc(timm_backbone=name, device='cpu')
except Exception as e:
logger.error(f'Fail to load model {name}. Please check weights.')
data = torch.rand(1, 3, 224, 224)
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
try:
out1 = op.model(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e:
logger.error(f'FAIL TO SAVE ONNX: {e}')
continue
try:
try:
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
except Exception:
saved_onnx = onnx.load(onnx_path, load_external_data=False)
onnx.checker.check_model(saved_onnx)
logger.info('ONNX CHECKED.')
status[3] = 'success'
except Exception as e:
logger.error(f'FAIL TO CHECK ONNX: {e}')
pass
try:
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol):
logger.info('Check accuracy: OK')
status[5] = 'success'
else:
logger.info(f'Check accuracy: atol is larger than {atol}.')
except Exception as e:
logger.error(f'FAIL TO RUN ONNX: {e}')
continue
if status:
f.write(','.join(status) + '\n')
print('Finished.')
Loading…
Cancel
Save