diff --git a/test_onnx.py b/test_onnx.py
index d8b018c..286e07b 100644
--- a/test_onnx.py
+++ b/test_onnx.py
@@ -1,110 +1,97 @@
 from towhee import ops
 from timm_image import TimmImage
+import torch
+import numpy
 import onnx
+import onnxruntime
 
-f = open('onnx.csv', 'a+')
-f.write('model_name, run_op, save_onnx, check_onnx\n')
+import os
+from pathlib import Path
+import logging
+import platform
+import psutil
 
 # models = TimmImage.supported_model_names()[:1]
-models = [
-    'vgg11',
-    'resnet18',
-    'resnetv2_50',
-    'seresnet33ts',
-    'skresnet18',
-    'resnext26ts',
-    'seresnext26d_32x4d',
-    'skresnext50_32x4d',
-    'convit_base',
-    'inception_v4',
-    'efficientnet_b0',
-    'tf_efficientnet_b0',
-    'swin_base_patch4_window7_224',
-    'vit_base_patch8_224',
-    'beit_base_patch16_224',
-    'convnext_base',
-    'crossvit_9_240',
-    'convmixer_768_32',
-    'coat_lite_mini',
-    'inception_v3',
-    'cait_m36_384',
-    'cspdarknet53',
-    'deit_base_distilled_patch16_224',
-    'densenet121',
-    'dla34',
-    'dm_nfnet_f0',
-    'nf_regnet_b1',
-    'nf_resnet50',
-    'dpn68',
-    'ese_vovnet19b_dw',
-    'fbnetc_100',
-    'fbnetv3_b',
-    'halonet26t',
-    'eca_halonext26ts',
-    'sehalonet33ts',
-    'hardcorenas_a',
-    'hrnet_w18',
-    'jx_nest_base',
-    'lcnet_050',
-    'levit_128',
-    'mixer_b16_224',
-    'mixnet_s',
-    'mnasnet_100',
-    'mobilenetv2_050',
-    'mobilenetv3_large_100',
-    'nasnetalarge',
-    'pit_b_224',
-    'pnasnet5large',
-    'regnetx_002',
-    'repvgg_a2',
-    'res2net50_14w_8s',
-    'res2next50',
-    'resmlp_12_224',
-    'resnest14d',
-    'rexnet_100',
-    'selecsls42b',
-    'semnasnet_075',
-    'tinynet_a',
-    'tnt_s_patch16_224',
-    'tresnet_l',
-    'twins_pcpvt_base',
-    'visformer_small',
-    'xception',
-    'xcit_large_24_p8_224',
-    'ghostnet_100',
-    'gmlp_s16_224',
-    'lambda_resnet26rpt_256',
-    'spnasnet_100',
-]
+models = ['resnet50']
 
-decoder = ops.image_decode()
-data = decoder('./towhee.jpeg')
+atol = 1e-3
+log_path = 'timm_onnx.log'
+f = open('onnx.csv', 'w+')
+f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
+logger = logging.getLogger('timm_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
+
+
+status = None
 
 for name in models:
-    f.write(f'{name},')
+    logger.info(f'***{name}***')
+    saved_name = name.replace('/', '-')
+    onnx_path = f'saved/onnx/{saved_name}.onnx'
+    if status:
+        f.write(','.join(status) + '\n')
+    status = [name] + ['fail'] * 5
+
+    op = TimmImage(model_name=name, device='cpu')
+    data = torch.rand((1,) + op.config['input_size'])
     try:
-        op = TimmImage(model_name=name)
-        out1 = op(data)
-        f.write('success,')
+        out1 = op.model(data).detach().numpy()
+        logger.info('OP LOADED.')
+        status[1] = 'success'
     except Exception as e:
-        f.write('fail\n')
-        print(f'Fail to load op for {name}: {e}')
+        logger.error(f'FAIL TO LOAD OP: {e}')
         continue
     try:
         op.save_model(format='onnx')
-        f.write('success,')
+        logger.info('ONNX SAVED.')
+        status[2] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO SAVE ONNX: {e}')
+        continue
+    try:
+        try:
+            onnx_model = onnx.load(onnx_path)
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            saved_onnx = onnx.load(onnx_path, load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logger.info('ONNX CHECKED.')
+        status[3] = 'success'
     except Exception as e:
-        f.write('fail\n')
-        print(f'Fail to save onnx for {name}: {e}')
+        logger.error(f'FAIL TO CHECK ONNX: {e}')
         continue
     try:
-        saved_name = name.replace('/', '-')
-        onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
-        onnx.checker.check_model(onnx_model)
-        f.write('success')
+        sess = onnxruntime.InferenceSession(onnx_path,
+                                            providers=onnxruntime.get_available_providers())
+
+        out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
+        logger.info('ONNX WORKED.')
+        status[4] = 'success'
+        if numpy.allclose(out1, out2, atol=atol):
+            logger.info('Check accuracy: OK')
+            status[5] = 'success'
+        else:
+            logger.info(f'Check accuracy: atol is larger than {atol}.')
     except Exception as e:
-        f.write('fail\n')
-        print(f'Fail to check onnx for {name}: {e}')
+        logger.error(f'FAIL TO RUN ONNX: {e}')
         continue
-    f.write('\n')
-print('Finished.')
+
+if status:
+    f.write(','.join(status) + '\n')
+
+print('Finished.')
\ No newline at end of file
diff --git a/timm_image.py b/timm_image.py
index 3070762..c8fd104 100644
--- a/timm_image.py
+++ b/timm_image.py
@@ -37,7 +37,7 @@ from timm.models.factory import create_model
 import warnings
 warnings.filterwarnings('ignore')
 
-log = logging.getLogger()
+log = logging.getLogger('timm_op')
 
 
 @register(output_schema=['vec'])
@@ -53,13 +53,22 @@ class TimmImage(NNOperator):
             Whether skip image transforms.
     """
 
-    def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None:
+    def __init__(self,
+                 model_name: str = 'resnet50',
+                 num_classes: int = 1000,
+                 skip_preprocess: bool = False,
+                 pretrained: bool = True,
+                 device: str = None
+                 ) -> None:
         super().__init__()
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = device
         self.model_name = model_name
-        self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
-        self.model.to(self.device)
+        self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
         self.model.eval()
+        self.model.to(self.device)
+
         self.config = resolve_data_config({}, model=self.model)
         self.tfms = create_transform(**self.config)
         self.skip_tfms = skip_preprocess
@@ -120,9 +129,15 @@
             torch.onnx.export(self.model,
                               dummy_input,
                               path,
-                              input_names=["input"],
-                              output_names=["output"],
-                              opset_version=12)
+                              input_names=['input_0'],
+                              output_names=['output_0'],
+                              opset_version=13,
+                              dynamic_axes={
+                                  'input_0': {0: 'batch_size'},
+                                  'output_0': {0: 'batch_size'}
+                              },
+                              do_constant_folding=True
+                              )
         except Exception as e:
             log.error(f'Fail to save as onnx: {e}.')
             raise RuntimeError(f'Fail to save as onnx: {e}.')
@@ -140,6 +155,10 @@
             to_remove = []
             assert set(to_remove).issubset(set(full_list))
             model_list = list(set(full_list) - set(to_remove))
+        elif format == 'onnx':
+            to_remove = []
+            assert set(to_remove).issubset(set(full_list))
+            model_list = list(set(full_list) - set(to_remove))
         # todo: elif format == 'torchscript':
         # todo: elif format == 'tensorrt'
         else:
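
A minimal usage sketch of the changed API (not part of the patch): it only assumes what the diff itself shows, i.e. the new TimmImage keyword arguments, op.config['input_size'], save_model(format='onnx'), the 'input_0' tensor name, and the 'saved/onnx/<model_name>.onnx' output path used by test_onnx.py.

import numpy
import torch
import onnxruntime
from timm_image import TimmImage

# Load the operator on CPU via the new 'device' argument and build a dummy input batch.
op = TimmImage(model_name='resnet50', device='cpu')
dummy = torch.rand((1,) + op.config['input_size'])
torch_out = op.model(dummy).detach().numpy()

# Export with the patched save_model() and run the graph through ONNX Runtime;
# the output path follows the convention assumed in test_onnx.py above.
op.save_model(format='onnx')
sess = onnxruntime.InferenceSession('saved/onnx/resnet50.onnx',
                                    providers=onnxruntime.get_available_providers())
onnx_out = sess.run(None, {'input_0': dummy.numpy()})[0]

# The exported graph should match the PyTorch output within the tolerance used by the test.
assert numpy.allclose(torch_out, onnx_out, atol=1e-3)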