From fd3109ff51810624e3a66df95791bce95458df85 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 13 Dec 2022 18:02:30 +0800 Subject: [PATCH] Add onnx test Signed-off-by: Jael Gu --- isc.py | 60 +++++++++++++++++++++++++++--- test_onnx.py | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 5 deletions(-) create mode 100644 test_onnx.py diff --git a/isc.py b/isc.py index 9c1d6d8..28d5704 100644 --- a/isc.py +++ b/isc.py @@ -33,7 +33,7 @@ import timm import warnings warnings.filterwarnings('ignore') -log = logging.getLogger() +log = logging.getLogger('isc_op') @register(output_schema=['vec']) @@ -46,13 +46,19 @@ class Isc(NNOperator): Whether skip image transforms. """ - def __init__(self, timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k', - skip_preprocess: bool = False, checkpoint_path: str = None, device: str = None) -> None: + def __init__(self, + timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k', + img_size: int = 512, + checkpoint_path: str = None, + skip_preprocess: bool = False, + device: str = None) -> None: super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.skip_tfms = skip_preprocess + self.timm_backbone = timm_backbone + if checkpoint_path is None: checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') @@ -62,7 +68,7 @@ class Isc(NNOperator): self.model.eval() self.tfms = transforms.Compose([ - transforms.Resize((512, 512)), + transforms.Resize((img_size, img_size)), transforms.ToTensor(), transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']) @@ -82,7 +88,7 @@ class Isc(NNOperator): inputs = torch.stack(img_list) inputs = inputs.to(self.device) features = self.model(inputs) - features = features.to('cpu').flatten(1) + features = features.to('cpu') if isinstance(data, list): vecs = list(features.detach().numpy()) @@ -90,6 +96,50 @@ class Isc(NNOperator): vecs = features.squeeze(0).detach().numpy() return vecs + def save_model(self, format: str = 'pytorch', path: str = 'default'): + if path == 'default': + path = str(Path(__file__).parent) + path = os.path.join(path, 'saved', format) + os.makedirs(path, exist_ok=True) + name = self.timm_backbone.replace('/', '-') + path = os.path.join(path, name) + dummy_input = torch.rand(1, 3, 224, 224) + if format == 'pytorch': + path = path + '.pt' + torch.save(self.model, path) + elif format == 'torchscript': + path = path + '.pt' + try: + try: + jit_model = torch.jit.script(self.model) + except Exception: + jit_model = torch.jit.trace(self.model, dummy_input, strict=False) + torch.jit.save(jit_model, path) + except Exception as e: + log.error(f'Fail to save as torchscript: {e}.') + raise RuntimeError(f'Fail to save as torchscript: {e}.') + elif format == 'onnx': + path = path + '.onnx' + try: + torch.onnx.export(self.model, + dummy_input, + path, + input_names=['input_0'], + output_names=['output_0'], + opset_version=14, + dynamic_axes={ + 'input_0': {0: 'batch_size', 2: 'height', 3: 'width'}, + 'output_0': {0: 'batch_size', 1: 'dim'} + }, + do_constant_folding=True + ) + except Exception as e: + log.error(f'Fail to save as onnx: {e}.') + raise RuntimeError(f'Fail to save as onnx: {e}.') + # todo: elif format == 'tensorrt': + else: + log.error(f'Unsupported format "{format}".') + @arg(1, to_image_color('RGB')) def convert_img(self, img: towhee._types.Image): img = PILImage.fromarray(img.astype('uint8'), 'RGB') diff --git a/test_onnx.py b/test_onnx.py new file mode 100644 index 0000000..e01ad82 --- /dev/null +++ b/test_onnx.py @@ -0,0 +1,102 @@ +import onnx +from isc import Isc + +from towhee import ops +import torch +import numpy +import onnx +import onnxruntime + +import os +from pathlib import Path +import logging +import platform +import psutil + +models = ['tf_efficientnetv2_m_in21ft1k'] + +atol = 1e-3 +log_path = 'isc_onnx.log' +f = open('onnx.csv', 'w+') +f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') + +logger = logging.getLogger('isc_onnx') +logger.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +fh = logging.FileHandler(log_path) +fh.setLevel(logging.DEBUG) +fh.setFormatter(formatter) +logger.addHandler(fh) +ch = logging.StreamHandler() +ch.setLevel(logging.ERROR) +ch.setFormatter(formatter) +logger.addHandler(ch) + +logger.debug(f'machine: {platform.platform()}-{platform.processor()}') +logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' + f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' + f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') +logger.debug(f'cpu: {psutil.cpu_count()}') + + +status = None +for name in models: + logger.info(f'***{name}***') + saved_name = name.replace('/', '-') + onnx_path = f'saved/onnx/{saved_name}.onnx' + + try: + op = Isc(timm_backbone=name, device='cpu') + except Exception as e: + logger.error(f'Fail to load model {name}. Please check weights.') + + data = torch.rand(1, 3, 224, 224) + if status: + f.write(','.join(status) + '\n') + status = [name] + ['fail'] * 5 + + try: + out1 = op.model(data).detach().numpy() + logger.info('OP LOADED.') + status[1] = 'success' + except Exception as e: + logger.error(f'FAIL TO LOAD OP: {e}') + continue + try: + op.save_model(format='onnx') + logger.info('ONNX SAVED.') + status[2] = 'success' + except Exception as e: + logger.error(f'FAIL TO SAVE ONNX: {e}') + continue + try: + try: + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + except Exception: + saved_onnx = onnx.load(onnx_path, load_external_data=False) + onnx.checker.check_model(saved_onnx) + logger.info('ONNX CHECKED.') + status[3] = 'success' + except Exception as e: + logger.error(f'FAIL TO CHECK ONNX: {e}') + pass + try: + sess = onnxruntime.InferenceSession(onnx_path, + providers=onnxruntime.get_available_providers()) + out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) + logger.info('ONNX WORKED.') + status[4] = 'success' + if numpy.allclose(out1, out2, atol=atol): + logger.info('Check accuracy: OK') + status[5] = 'success' + else: + logger.info(f'Check accuracy: atol is larger than {atol}.') + except Exception as e: + logger.error(f'FAIL TO RUN ONNX: {e}') + continue + +if status: + f.write(','.join(status) + '\n') + +print('Finished.')