transformers
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
167 lines
5.7 KiB
167 lines
5.7 KiB
import inspect
import logging
import os
import platform
from pathlib import Path

import numpy
import onnx
import onnxruntime
import psutil
import torch
from towhee import ops

# Hugging Face model ids to export to ONNX and verify, one row per id in
# onnx.csv. The list is hard-coded. Judging by the snippet it replaces, it
# was produced as "models an AutoTransformers helper supports, minus those it
# already lists under format='onnx'" (that helper is not imported here), so
# refresh it by hand if those lists change.
models = [
    # AllenAI
    'allenai/led-base-16384',
    # Tohoku University Japanese BERT variants
    'cl-tohoku/bert-base-japanese-char',
    'cl-tohoku/bert-base-japanese-char-whole-word-masking',
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    # Salesforce CTRL (un-namespaced id)
    'ctrl',
    # Facebook FSMT
    'facebook/wmt19-ru-en',
    # Funnel Transformer: each size in a plain and a "-base" flavour
    'funnel-transformer/intermediate',
    'funnel-transformer/intermediate-base',
    'funnel-transformer/large',
    'funnel-transformer/large-base',
    'funnel-transformer/medium',
    'funnel-transformer/medium-base',
    'funnel-transformer/small',
    'funnel-transformer/small-base',
    'funnel-transformer/xlarge',
    'funnel-transformer/xlarge-base',
    # Google
    'google/bert_for_seq_generation_L-24_bbc_encoder',
    'google/canine-c',
    'google/canine-s',
    'google/fnet-base',
    'google/fnet-large',
    'google/reformer-crime-and-punishment',
    # Microsoft
    'microsoft/mpnet-base',
    # OpenAI GPT (un-namespaced id)
    'openai-gpt',
    # Splinter: base/large, each with a "-qass" flavour
    'tau/splinter-base',
    'tau/splinter-base-qass',
    'tau/splinter-large',
    'tau/splinter-large-qass',
    # Transformer-XL (un-namespaced id)
    'transfo-xl-wt103',
    # UW-Madison
    'uw-madison/nystromformer-512',
    'uw-madison/yoso-4096',
    # XLNet (un-namespaced ids)
    'xlnet-base-cased',
    'xlnet-large-cased',
]

# Input text fed to every model, both through the towhee op and ONNX Runtime.
test_txt = 'hello, world.'
# Absolute tolerance passed to numpy.allclose when comparing the op's output
# with the ONNX Runtime output.
atol = 1e-3
log_path = 'transformers_onnx.log'

# Results table: one row per model; each stage column is 'success' or 'fail'.
# Opened write-only ('w'): the script never reads it back.
f = open('onnx.csv', 'w', encoding='utf-8')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

# Everything (DEBUG and up) goes to the log file; only ERROR and above is
# echoed to the console so the run stays quiet unless something fails.
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)

ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

# Record the host so result tables from different machines can be compared.
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
# A single snapshot keeps free/available/total mutually consistent (three
# separate virtual_memory() calls could each observe a different moment).
_mem = psutil.virtual_memory()
_gib = 1024.0 ** 3
logger.debug(f'free/available/total mem: {round(_mem.free / _gib)}'
             f'/{round(_mem.available / _gib)}'
             f'/{round(_mem.total / _gib)} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')

def _export_inputs(model, inputs):
    """Choose ONNX input names and ``torch.onnx.export`` args for ``model``.

    Tokenizers emit tensors in their own key order. BERT-style tokenizers
    yield input_ids, token_type_ids, attention_mask. Hugging Face models
    such as ``BertModel`` instead declare
    ``forward(input_ids, attention_mask, token_type_ids, ...)``.
    Passing the values positionally would therefore bind them to the wrong
    parameters and mislabel the graph inputs.

    When ``forward`` names every tokenizer key as an explicit parameter, the
    tensors are passed as keyword arguments (a trailing dict in the args
    tuple). They are named in signature order, which is the order
    torch.onnx gives the graph inputs. Otherwise (e.g. a ``*args/**kwargs``
    wrapper) the previous positional, tokenizer-ordered behaviour is kept.

    Args:
        model: the module to export.
        inputs: mapping of input name -> tensor, as returned by the tokenizer.

    Returns:
        ``(input_names, args)`` for ``torch.onnx.export``.
    """
    try:
        params = inspect.signature(model.forward).parameters
    except (AttributeError, TypeError, ValueError):
        params = {}
    explicit = [
        p for p, param in params.items()
        if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    ]
    if inputs and all(k in explicit for k in inputs):
        input_names = [p for p in explicit if p in inputs]
        return input_names, ({k: inputs[k] for k in input_names},)
    return list(inputs.keys()), tuple(inputs.values())


def _export_onnx(model, inputs, onnx_path):
    """Export ``model`` to ``onnx_path`` (opset 14) with dynamic axes.

    The first attempt declares only ``last_hidden_state`` as output. If that
    export raises, a second attempt declares ``last_hidden_state`` and
    ``pooler_output``. An exception from the second attempt propagates.
    """
    input_names, args = _export_inputs(model, inputs)

    def export(output_axes):
        # Fresh dicts per name so no axes mapping is shared between entries.
        dynamic_axes = {n: {0: 'batch_size', 1: 'sequence_length'} for n in input_names}
        dynamic_axes.update(output_axes)
        torch.onnx.export(
            model,
            args,
            onnx_path,
            input_names=input_names,
            output_names=list(output_axes),
            dynamic_axes=dynamic_axes,
            opset_version=14,
            do_constant_folding=True,
        )

    try:
        export({'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}})
    except Exception as e:
        logger.warning(f'{e}\nTrying with 2 outputs...')
        export({
            'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
            # pooler_output is (batch_size, hidden_size): axis 1 is the fixed
            # hidden size, so only the batch axis is dynamic.
            'pooler_output': {0: 'batch_size'},
        })


def _check_onnx(onnx_path):
    """Validate the exported file with the ONNX checker.

    ``check_model`` on an in-memory ModelProto rejects models above the 2 GB
    protobuf limit. For those the checker has to be given the file path.
    Re-loading with ``load_external_data=True`` would not help, because that
    is already ``onnx.load``'s default. Raises the checker's error if the
    path-based check also fails.
    """
    try:
        onnx.checker.check_model(onnx.load(onnx_path))
    except Exception:
        onnx.checker.check_model(onnx_path)


def _run_onnx(tokenizer, onnx_path):
    """Run ``test_txt`` through the exported model with ONNX Runtime.

    Only the tokenizer outputs that the graph declares as inputs are fed.
    Returns the list produced by ``InferenceSession.run`` for the single
    requested output, ``last_hidden_state``.
    """
    sess = onnxruntime.InferenceSession(
        onnx_path, providers=onnxruntime.get_available_providers())
    encoded = tokenizer(test_txt, return_tensors='np')
    feed = {x.name: encoded[x.name] for x in sess.get_inputs()}
    return sess.run(output_names=['last_hidden_state'], input_feed=feed)


def _process_model(name):
    """Load, export, check and run one model; return its CSV status row.

    The row is ``[name, load_op, save_onnx, check_onnx, run_onnx, accuracy]``
    with each stage 'success' or 'fail'. A failed load, export or run skips
    the remaining stages; a failed checker run does not.
    """
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    status = [name] + ['fail'] * 5

    try:
        op = ops.text_embedding.transformers(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    try:
        # Make sure the output directory exists before writing the model.
        Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)
        inputs = op.tokenizer(test_txt, return_tensors='pt')
        _export_onnx(op.model, inputs, onnx_path)
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    try:
        _check_onnx(onnx_path)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        # Not fatal: an exported model may still run even if the checker complains.
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    try:
        out2 = _run_onnx(op.tokenizer, onnx_path)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        return status

    # out2 is the list returned by InferenceSession.run; numpy.allclose
    # broadcasts it against out1. NOTE(review): assumes the two shapes are
    # broadcast-compatible — confirm what shape the towhee op returns.
    # Kept in its own try so a shape mismatch is reported as an accuracy
    # failure rather than a run failure.
    try:
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO CHECK ACCURACY: {e}')
    return status


try:
    for name in models:
        f.write(','.join(_process_model(name)) + '\n')
        # Flush per model so finished rows survive a crash on a later model.
        f.flush()
finally:
    f.close()

print('Finished.')