# Provenance (from the original patch header):
#   From 59f19c7f1721afb79c9f724dbb8f93bf1bc0a1ba Mon Sep 17 00:00:00 2001
#   From: Jael Gu
#   Date: Thu, 8 Dec 2022 19:00:11 +0800
#   Subject: [PATCH] Add test onnx
#   Signed-off-by: Jael Gu
#   Adds new file: test_onnx2.py
"""Export a list of transformer models to ONNX and record how far each one gets.

For every name in ``models`` the script runs five stages and writes one CSV
row to ``onnx.csv`` with a ``success``/``fail`` flag per stage:

1. ``load_op``    - build the ``AutoTransformers`` op and run it on ``test_txt``
2. ``save_onnx``  - export ``op.model`` with ``torch.onnx.export``
3. ``check_onnx`` - validate the exported file with ``onnx.checker``
4. ``run_onnx``   - run the exported file with ``onnxruntime``
5. ``accuracy``   - compare the op output and the ONNX output with
   ``numpy.allclose`` at tolerance ``atol``

Details are logged to ``transformers_onnx.log`` (DEBUG and above); only errors
are echoed to the console.
"""
import logging
import os
import platform
from pathlib import Path

import numpy
import onnx
import onnxruntime
import psutil
import torch

from auto_transformers import AutoTransformers

# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = [
    'allenai/led-base-16384',
    'cl-tohoku/bert-base-japanese-char',
    'cl-tohoku/bert-base-japanese-char-whole-word-masking',
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    'ctrl',
    'facebook/wmt19-ru-en',
    'funnel-transformer/intermediate',
    'funnel-transformer/intermediate-base',
    'funnel-transformer/large',
    'funnel-transformer/large-base',
    'funnel-transformer/medium',
    'funnel-transformer/medium-base',
    'funnel-transformer/small',
    'funnel-transformer/small-base',
    'funnel-transformer/xlarge',
    'funnel-transformer/xlarge-base',
    'google/bert_for_seq_generation_L-24_bbc_encoder',
    'google/canine-c',
    'google/canine-s',
    'google/fnet-base',
    'google/fnet-large',
    'google/reformer-crime-and-punishment',
    'microsoft/mpnet-base',
    'openai-gpt',
    'tau/splinter-base',
    'tau/splinter-base-qass',
    'tau/splinter-large',
    'tau/splinter-large-qass',
    'transfo-xl-wt103',
    'uw-madison/nystromformer-512',
    'uw-madison/yoso-4096',
    'xlnet-base-cased',
    'xlnet-large-cased',
]
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'transformers_onnx.log'

CSV_PATH = 'onnx.csv'
CSV_HEADER = 'model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n'
ONNX_DIR = Path('saved/onnx')

logger = logging.getLogger('transformers_onnx')


def _configure_logging():
    """Attach a DEBUG file handler and an ERROR console handler to ``logger``."""
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    ch = logging.StreamHandler()
    ch.setLevel(logging.ERROR)
    ch.setFormatter(formatter)
    logger.addHandler(ch)


def _log_machine_info():
    """Log platform, memory (in GB) and CPU count at DEBUG level."""
    gib = 1024.0 ** 3
    mem = psutil.virtual_memory()  # one snapshot so the three figures are consistent
    logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
    logger.debug(f'free/available/total mem: {round(mem.free / gib)}'
                 f'/{round(mem.available / gib)}'
                 f'/{round(mem.total / gib)} GB')
    logger.debug(f'cpu: {psutil.cpu_count()}')


def _export_onnx(op, onnx_path):
    """Export ``op.model`` to *onnx_path*.

    The export is first attempted with a single ``last_hidden_state`` output;
    if that raises, it is retried once with ``last_hidden_state`` and
    ``pooler_output``. An exception from the retry propagates to the caller.
    """
    inputs = op.tokenizer(test_txt, return_tensors='pt')
    input_names = list(inputs.keys())

    def export_with(output_names):
        # NOTE(review): every input AND output gets axis 1 marked as
        # 'sequence_length', including 'pooler_output' - confirm that is
        # intended for the pooled output.
        dynamic_axes = {
            tensor_name: {0: 'batch_size', 1: 'sequence_length'}
            for tensor_name in input_names + output_names
        }
        torch.onnx.export(
            op.model,
            tuple(inputs.values()),
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=14,
            do_constant_folding=True,
        )

    try:
        export_with(['last_hidden_state'])
    except Exception as e:
        logger.warning(f'{e} \nTrying with 2 outputs...')
        export_with(['last_hidden_state', 'pooler_output'])


def _check_onnx(onnx_path):
    """Validate the ONNX file at *onnx_path*; raise if it is invalid.

    The in-memory check is tried first. If it raises (e.g. because the model
    is too large to be checked as a loaded proto), the checker is called with
    the file path instead, which is how ONNX expects large models to be
    checked.
    """
    try:
        onnx.checker.check_model(onnx.load(onnx_path))
    except Exception:
        # The original retried onnx.load(..., load_external_data=True), which
        # is already the default and therefore repeated the same failure.
        onnx.checker.check_model(onnx_path)


def _run_onnx(op, onnx_path):
    """Run the exported model on ``test_txt`` and return ``sess.run``'s result."""
    sess = onnxruntime.InferenceSession(
        onnx_path, providers=onnxruntime.get_available_providers())
    inputs = op.tokenizer(test_txt, return_tensors='np')
    # Feed only the inputs the exported graph actually declares.
    feed = {graph_input.name: inputs[graph_input.name] for graph_input in sess.get_inputs()}
    return sess.run(output_names=['last_hidden_state'], input_feed=feed)


def evaluate_model(name):
    """Run all stages for model *name* and return its CSV row as a list.

    The row is ``[name, load_op, save_onnx, check_onnx, run_onnx, accuracy]``
    where each flag is ``'success'`` or ``'fail'``. A failure in loading or
    exporting stops processing for this model; a failed ONNX check is logged
    but does not prevent the run stage.
    """
    logger.info(f'***{name}***')
    status = [name] + ['fail'] * 5
    saved_name = name.replace('/', '-')
    onnx_path = str(ONNX_DIR / f'{saved_name}.onnx')

    try:
        op = AutoTransformers(model_name=name, device='cpu')
        out1 = op(test_txt)
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status
    logger.info('OP LOADED.')
    status[1] = 'success'

    try:
        _export_onnx(op, onnx_path)
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status
    logger.info('ONNX SAVED.')
    status[2] = 'success'

    try:
        _check_onnx(onnx_path)
    except Exception as e:
        # Deliberately not fatal: the model may still run under onnxruntime.
        logger.error(f'FAIL TO CHECK ONNX: {e}')
    else:
        logger.info('ONNX CHECKED.')
        status[3] = 'success'

    try:
        out2 = _run_onnx(op, onnx_path)
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        return status
    logger.info('ONNX WORKED.')
    status[4] = 'success'

    try:
        # NOTE(review): out2 is the list returned by sess.run while out1 is
        # whatever the op returns; this relies on numpy broadcasting the two
        # - confirm the shapes line up as intended.
        matches = numpy.allclose(out1, out2, atol=atol)
    except Exception as e:
        # Reported separately so a comparison error is not mistaken for an
        # onnxruntime failure.
        logger.error(f'FAIL TO CHECK ACCURACY: {e}')
        return status
    if matches:
        logger.info('Check accuracy: OK')
        status[5] = 'success'
    else:
        logger.info(f'Check accuracy: atol is larger than {atol}.')
    return status


def main():
    """Evaluate every model in ``models`` and write the results to ``onnx.csv``."""
    _configure_logging()
    _log_machine_info()
    # torch.onnx.export does not create missing parent directories.
    ONNX_DIR.mkdir(parents=True, exist_ok=True)

    with open(CSV_PATH, 'w') as f:
        f.write(CSV_HEADER)
        for name in models:
            f.write(','.join(evaluate_model(name)) + '\n')
            # Each model can take minutes; flush so finished rows survive a crash.
            f.flush()

    print('Finished.')


if __name__ == '__main__':
    main()