logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

168 lines
5.7 KiB

from towhee import ops
import torch
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil
# Model names exercised by this script. The list was presumably produced with
# the snippet below (every supported model minus those already checked for the
# 'onnx' format) -- confirm before regenerating it:
#   full_models = AutoTransformers.supported_model_names()
#   checked_models = AutoTransformers.supported_model_names(format='onnx')
#   models = [x for x in full_models if x not in checked_models]
models = [
    'allenai/led-base-16384',
    'cl-tohoku/bert-base-japanese-char',
    'cl-tohoku/bert-base-japanese-char-whole-word-masking',
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    'ctrl',
    'facebook/wmt19-ru-en',
    'funnel-transformer/intermediate',
    'funnel-transformer/intermediate-base',
    'funnel-transformer/large',
    'funnel-transformer/large-base',
    'funnel-transformer/medium',
    'funnel-transformer/medium-base',
    'funnel-transformer/small',
    'funnel-transformer/small-base',
    'funnel-transformer/xlarge',
    'funnel-transformer/xlarge-base',
    'google/bert_for_seq_generation_L-24_bbc_encoder',
    'google/canine-c',
    'google/canine-s',
    'google/fnet-base',
    'google/fnet-large',
    'google/reformer-crime-and-punishment',
    'microsoft/mpnet-base',
    'openai-gpt',
    'tau/splinter-base',
    'tau/splinter-base-qass',
    'tau/splinter-large',
    'tau/splinter-large-qass',
    'transfo-xl-wt103',
    'uw-madison/nystromformer-512',
    'uw-madison/yoso-4096',
    'xlnet-base-cased',
    'xlnet-large-cased',
]

# Sample sentence fed to both the original operator and the exported ONNX model.
test_txt = 'hello, world.'

# Absolute tolerance handed to numpy.allclose when the two outputs are compared.
atol = 1e-3

# Path of the log file that receives the DEBUG-level record of the run.
log_path = 'transformers_onnx.log'
# Results table: one CSV row per model is appended by the evaluation loop
# further down, using the column order of the header written here.
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
# Push the header to disk right away so the file is a valid (if empty) table
# even when the process dies while the first model is being handled.
f.flush()

# Logger: everything (DEBUG and up) goes to the log file, while only ERROR and
# above is echoed to the console.
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

fh = logging.FileHandler(log_path)
ch = logging.StreamHandler()
for handler, level in ((fh, logging.DEBUG), (ch, logging.ERROR)):
    handler.setLevel(level)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# Record the environment the run was executed on. A single memory snapshot is
# taken so that the free/available/total figures describe the same instant.
bytes_per_gb = 1024.0 ** 3
mem = psutil.virtual_memory()
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(
    f'free/available/total mem: {round(mem.free / bytes_per_gb)}'
    f'/{round(mem.available / bytes_per_gb)}'
    f'/{round(mem.total / bytes_per_gb)} GB'
)
logger.debug(f'cpu: {psutil.cpu_count()}')
def _export_to_onnx(op, inputs, output_names, onnx_path):
    """Export ``op.model`` to ``onnx_path`` using the given output names.

    Args:
        op: Loaded operator; its ``model`` attribute is what gets exported.
        inputs: Mapping produced by ``op.tokenizer(..., return_tensors='pt')``;
            its keys become the ONNX input names and its values the example
            inputs used for the export.
        output_names: Names assigned to the exported model's outputs.
        onnx_path: Destination file path (string).

    Raises:
        Exception: Whatever ``torch.onnx.export`` raises; the caller decides
            whether to retry or to record a failure.
    """
    input_names = list(inputs.keys())
    # Axes 0 and 1 of every input and output are declared dynamic.
    # NOTE(review): this labels axis 1 of 'pooler_output' as 'sequence_length'
    # as well, which looks inaccurate for a pooled output -- kept unchanged to
    # preserve the exported graphs; confirm before altering.
    dynamic_axes = {
        tensor_name: {0: 'batch_size', 1: 'sequence_length'}
        for tensor_name in input_names + output_names
    }
    torch.onnx.export(
        op.model,
        tuple(inputs.values()),
        onnx_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )


def _evaluate_model(name):
    """Run the load / export / check / run / accuracy pipeline for one model.

    Args:
        name: Model name passed to ``ops.text_embedding.transformers``.

    Returns:
        A list of six strings, ``[name, load_op, save_onnx, check_onnx,
        run_onnx, accuracy]``, where every stage entry is ``'success'`` or
        ``'fail'``. Stages that were never reached stay ``'fail'``.
    """
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')  # '/' would otherwise create sub-directories
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    status = [name] + ['fail'] * 5

    # Stage 1: load the operator and compute the reference output.
    try:
        op = ops.text_embedding.transformers(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    # Stage 2: export to ONNX, first with one output and, if that fails, with two.
    try:
        # The export does not create missing directories, so make sure the
        # destination exists; otherwise every model fails on a fresh checkout.
        Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)
        inputs = op.tokenizer(test_txt, return_tensors='pt')
        try:
            _export_to_onnx(op, inputs, ['last_hidden_state'], onnx_path)
        except Exception as e:
            print(e, '\nTrying with 2 outputs...')
            _export_to_onnx(op, inputs, ['last_hidden_state', 'pooler_output'], onnx_path)
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    # Stage 3: validate the exported file. A failure here is recorded but does
    # not stop the run stage below.
    try:
        try:
            onnx.checker.check_model(onnx.load(onnx_path))
        except Exception:
            # Retrying with onnx.load(..., load_external_data=True) would just
            # repeat the call above (that flag is already the default). The
            # ONNX documentation asks for the *path* to be given to the checker
            # for models too large to be checked in memory.
            onnx.checker.check_model(onnx_path)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    # Stages 4 and 5: run the ONNX model and compare against the reference.
    try:
        sess = onnxruntime.InferenceSession(
            onnx_path,
            providers=onnxruntime.get_available_providers(),
        )
        np_inputs = op.tokenizer(test_txt, return_tensors='np')
        # Feed only the inputs the exported graph actually declares.
        input_feed = {node.name: np_inputs[node.name] for node in sess.get_inputs()}
        out2 = sess.run(output_names=['last_hidden_state'], input_feed=input_feed)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        # NOTE(review): out2 is passed exactly as returned by sess.run and the
        # comparison relies on numpy broadcasting it against out1 -- confirm
        # the shapes line up for every operator.
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')

    return status


# Evaluate every model and append its row to the CSV as soon as it is known.
status = None
try:
    for name in models:
        status = _evaluate_model(name)
        f.write(','.join(status) + '\n')
        # Flush per row so finished results survive a crash on a later model.
        f.flush()
finally:
    # The results file was opened without a context manager; close it here so
    # the handle is released even if the loop is interrupted.
    f.close()

print('Finished.')