from towhee import ops
import torch
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil

# Checkpoints to try exporting to ONNX.
#
# The list is the set difference between everything the AutoTransformers
# helper reports as supported and what it reports as already supported in
# ONNX format.  It was computed once with
#
#     full_models = AutoTransformers.supported_model_names()
#     checked_models = AutoTransformers.supported_model_names(format='onnx')
#     models = [x for x in full_models if x not in checked_models]
#
# and is hard-coded here (AutoTransformers is not imported by this script).
models = [
    # AllenAI
    'allenai/led-base-16384',
    # Tohoku University (Japanese BERT)
    'cl-tohoku/bert-base-japanese-char',
    'cl-tohoku/bert-base-japanese-char-whole-word-masking',
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    # Salesforce CTRL
    'ctrl',
    # Facebook FSMT
    'facebook/wmt19-ru-en',
    # Funnel Transformer
    'funnel-transformer/intermediate',
    'funnel-transformer/intermediate-base',
    'funnel-transformer/large',
    'funnel-transformer/large-base',
    'funnel-transformer/medium',
    'funnel-transformer/medium-base',
    'funnel-transformer/small',
    'funnel-transformer/small-base',
    'funnel-transformer/xlarge',
    'funnel-transformer/xlarge-base',
    # Google
    'google/bert_for_seq_generation_L-24_bbc_encoder',
    'google/canine-c',
    'google/canine-s',
    'google/fnet-base',
    'google/fnet-large',
    'google/reformer-crime-and-punishment',
    # Microsoft
    'microsoft/mpnet-base',
    # OpenAI
    'openai-gpt',
    # Splinter
    'tau/splinter-base',
    'tau/splinter-base-qass',
    'tau/splinter-large',
    'tau/splinter-large-qass',
    # Transformer-XL
    'transfo-xl-wt103',
    # UW-Madison
    'uw-madison/nystromformer-512',
    'uw-madison/yoso-4096',
    # XLNet
    'xlnet-base-cased',
    'xlnet-large-cased',
]

# Input sentence fed both to the towhee op and to the exported ONNX model.
test_txt = 'hello, world.'
# For every name in `models`: load it as a towhee text-embedding op, export
# the underlying model to ONNX, validate the file with the ONNX checker, run it
# with onnxruntime, and compare the result with the op's own output.
#
# One CSV row per model is written to onnx.csv:
#     model,load_op,save_onnx,check_onnx,run_onnx,accuracy
# Each stage column is 'success' or 'fail'; stages never reached stay 'fail'.
# Details go to transformers_onnx.log; only ERROR records reach the console.

atol = 1e-3
log_path = 'transformers_onnx.log'
csv_path = 'onnx.csv'
onnx_dir = 'saved/onnx'

logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)

ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)


def _log_machine_info():
    """Log platform, memory (rounded to whole GB) and CPU count at DEBUG level."""
    gb = 1024.0 ** 3
    # Sample once so free/available/total come from the same snapshot.
    mem = psutil.virtual_memory()
    logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
    logger.debug(f'free/available/total mem: {round(mem.free / gb)}'
                 f'/{round(mem.available / gb)}'
                 f'/{round(mem.total / gb)} GB')
    logger.debug(f'cpu: {psutil.cpu_count()}')


def _export_onnx(op, inputs, onnx_path):
    """Export ``op.model`` to ``onnx_path`` with dynamic batch/sequence axes.

    First tries a single ``last_hidden_state`` output; if that export raises,
    retries with ``last_hidden_state`` plus ``pooler_output``.  An exception
    from the retry propagates to the caller.

    Args:
        op: the towhee op; its ``.model`` is what gets exported.
        inputs: tokenizer output created with ``return_tensors='pt'``; its
            keys become the ONNX input names and its values the example args.
        onnx_path: destination file path.
    """
    input_names = list(inputs.keys())
    example_args = tuple(inputs.values())

    def export(output_axes):
        # Fresh dict per attempt so the retry never inherits stale entries.
        dynamic_axes = {i_n: {0: 'batch_size', 1: 'sequence_length'} for i_n in input_names}
        dynamic_axes.update(output_axes)
        torch.onnx.export(
            op.model,
            example_args,
            onnx_path,
            input_names=input_names,
            output_names=list(output_axes),
            dynamic_axes=dynamic_axes,
            opset_version=14,
            do_constant_folding=True,
        )

    try:
        export({'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}})
    except Exception as e:
        logger.warning(f'{e}\nTrying with 2 outputs...')
        # NOTE(review): pooler_output is assumed to be (batch, hidden), as in
        # HF BaseModelOutputWithPooling — only the batch axis is dynamic.
        # Naming its axis 1 'sequence_length' would tie the hidden size to the
        # sequence-length symbol.
        export({
            'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
            'pooler_output': {0: 'batch_size'},
        })


def _check_onnx(onnx_path):
    """Validate the exported file with the ONNX checker; raise if it is invalid."""
    try:
        onnx.checker.check_model(onnx.load(onnx_path))
    except Exception:
        # In-memory checking fails for models above the 2GB protobuf limit;
        # the checker must then be given the file path instead.  (Reloading
        # with load_external_data=True would just repeat the first attempt,
        # since that is already onnx.load's default.)
        onnx.checker.check_model(onnx_path)


def _run_onnx(op, onnx_path):
    """Run the exported model on ``test_txt`` with onnxruntime.

    Feeds the tokenizer outputs whose names match the session's inputs and
    returns the list produced by ``sess.run`` for ``last_hidden_state``.
    """
    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
    encoded = op.tokenizer(test_txt, return_tensors='np')
    feed = {node.name: encoded[node.name] for node in sess.get_inputs()}
    return sess.run(output_names=['last_hidden_state'], input_feed=feed)


def _evaluate_model(name):
    """Run every stage for one model and return its CSV row as a list of str.

    A failure while loading, saving or running stops the remaining stages.
    A checker failure is logged and the runtime stage still runs.
    """
    status = [name] + ['fail'] * 5
    saved_name = name.replace('/', '-')
    onnx_path = f'{onnx_dir}/{saved_name}.onnx'

    try:
        op = ops.text_embedding.transformers(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    try:
        _export_onnx(op, op.tokenizer(test_txt, return_tensors='pt'), onnx_path)
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    try:
        _check_onnx(onnx_path)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    try:
        out2 = _run_onnx(op, onnx_path)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        return status

    # Separate from the run stage so a failing comparison (e.g. a shape
    # mismatch) is not reported as a runtime failure.  out2 is sess.run's
    # list of outputs; allclose broadcasts it against out1.
    try:
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO CHECK ACCURACY: {e}')
    return status


_log_machine_info()

# Nothing else in this script creates the export directory, so make it here.
Path(onnx_dir).mkdir(parents=True, exist_ok=True)

with open(csv_path, 'w', encoding='utf-8') as f:
    f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
    for name in models:
        logger.info(f'***{name}***')
        f.write(','.join(_evaluate_model(name)) + '\n')
        # Flush per model so completed rows survive a crash or Ctrl-C mid-run.
        f.flush()

print('Finished.')