transformers
copied
Jael Gu
2 years ago
1 changed files with 167 additions and 0 deletions
@ -0,0 +1,167 @@ |
|||||
|
from auto_transformers import AutoTransformers |
||||
|
import torch |
||||
|
import numpy |
||||
|
import onnx |
||||
|
import onnxruntime |
||||
|
|
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
import logging |
||||
|
import platform |
||||
|
import psutil |
||||
|
|
||||
|
# Checkpoints exercised by this script. Presumably a snapshot of the models
# the operator supports but that have not yet been verified in ONNX format;
# the snippet below is kept for reference on how to regenerate the list:
#   full_models = AutoTransformers.supported_model_names()
#   checked_models = AutoTransformers.supported_model_names(format='onnx')
#   models = [x for x in full_models if x not in checked_models]
models = [
    'allenai/led-base-16384',
    # cl-tohoku
    'cl-tohoku/bert-base-japanese-char',
    'cl-tohoku/bert-base-japanese-char-whole-word-masking',
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    'ctrl',
    'facebook/wmt19-ru-en',
    # funnel-transformer
    'funnel-transformer/intermediate',
    'funnel-transformer/intermediate-base',
    'funnel-transformer/large',
    'funnel-transformer/large-base',
    'funnel-transformer/medium',
    'funnel-transformer/medium-base',
    'funnel-transformer/small',
    'funnel-transformer/small-base',
    'funnel-transformer/xlarge',
    'funnel-transformer/xlarge-base',
    # google
    'google/bert_for_seq_generation_L-24_bbc_encoder',
    'google/canine-c',
    'google/canine-s',
    'google/fnet-base',
    'google/fnet-large',
    'google/reformer-crime-and-punishment',
    'microsoft/mpnet-base',
    'openai-gpt',
    # tau
    'tau/splinter-base',
    'tau/splinter-base-qass',
    'tau/splinter-large',
    'tau/splinter-large-qass',
    'transfo-xl-wt103',
    # uw-madison
    'uw-madison/nystromformer-512',
    'uw-madison/yoso-4096',
    # xlnet
    'xlnet-base-cased',
    'xlnet-large-cased',
]
||||
|
# Sample sentence fed to both the original operator and the exported ONNX
# model so their outputs can be compared.
test_txt = 'hello, world.'
# Absolute tolerance used when comparing the two outputs with numpy.allclose.
atol = 1e-3
# Destination of the detailed log written by the 'transformers_onnx' logger.
log_path = 'transformers_onnx.log'

# Per-model results table. The file is opened line-buffered (buffering=1) so
# that each completed row reaches disk immediately: exporting large models can
# kill the process, and with default block buffering every row recorded so far
# would be lost with it.
f = open('onnx.csv', 'w+', buffering=1, encoding='utf-8')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
||||
|
|
||||
|
# Logger shared by the whole script: everything (DEBUG and up) is recorded in
# the log file, while only errors are echoed to the console stream.
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

fh = logging.FileHandler(log_path)
ch = logging.StreamHandler()
# Both handlers share the formatter and differ only in their threshold.
for _handler, _level in ((fh, logging.DEBUG), (ch, logging.ERROR)):
    _handler.setLevel(_level)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)
||||
|
|
||||
|
# Record the host environment at the top of the log so results can be tied
# back to the machine they were produced on.
_GIB = 1024.0 ** 3  # bytes per GiB; memory figures are logged as whole GB

logger.debug(f'machine: {platform.platform()}-{platform.processor()}')

# One fresh psutil reading per figure, in free/available/total order.
_mem_gb = '/'.join(
    str(round(getattr(psutil.virtual_memory(), _field) / _GIB))
    for _field in ('free', 'available', 'total')
)
logger.debug(f'free/available/total mem: {_mem_gb} GB')

logger.debug(f'cpu: {psutil.cpu_count()}')
||||
|
|
||||
|
|
||||
|
def _export_onnx(op, inputs, onnx_path, output_names):
    """Export ``op.model`` to ``onnx_path`` with the given output names.

    Args:
        op: Loaded operator; its ``model`` attribute is the module exported.
        inputs: Tokenizer output (mapping of input name -> tensor) used both
            as the example inputs and as the source of the ONNX input names.
        onnx_path: Destination file for the exported graph.
        output_names: Names assigned to the exported graph outputs.

    Raises:
        Exception: whatever ``torch.onnx.export`` raises; callers decide
            whether to retry with a different set of outputs.
    """
    input_names = list(inputs.keys())
    # Every input and output gets axes 0 and 1 marked as dynamic.
    # NOTE(review): this is also applied to 'pooler_output', whose axis 1 is
    # presumably a hidden size rather than a sequence length - confirm.
    dynamic_axes = {
        tensor_name: {0: 'batch_size', 1: 'sequence_length'}
        for tensor_name in input_names + list(output_names)
    }
    torch.onnx.export(
        op.model,
        tuple(inputs.values()),
        onnx_path,
        input_names=input_names,
        output_names=list(output_names),
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )


def _convert_and_check(name):
    """Run the load / export / check / run / compare pipeline for one model.

    Args:
        name: Model name passed to ``AutoTransformers``.

    Returns:
        A list ``[name, load_op, save_onnx, check_onnx, run_onnx, accuracy]``
        where each stage entry is ``'success'`` or ``'fail'``. Stages after a
        failed load or export are left as ``'fail'``; a failed check does not
        stop the run stage from being attempted.
    """
    logger.info(f'***{name}***')
    status = [name] + ['fail'] * 5
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'

    # Stage 1: load the operator and compute the reference output.
    try:
        op = AutoTransformers(model_name=name, device='cpu')
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    # Stage 2: export to ONNX, first with one output and then with two.
    try:
        # The export does not create missing directories, so make sure the
        # target folder exists before writing into it.
        Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)
        inputs = op.tokenizer(test_txt, return_tensors='pt')
        try:
            _export_onnx(op, inputs, onnx_path, ['last_hidden_state'])
        except Exception as e:
            # Logged (rather than printed) so the retry shows up in the log
            # file next to the rest of this model's records.
            logger.warning(f'{e}\nTrying with 2 outputs...')
            _export_onnx(op, inputs, onnx_path, ['last_hidden_state', 'pooler_output'])
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    # Stage 3: structural validation of the saved graph.
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Re-loading and checking in memory would only repeat the same
            # failure; models above the 2GB protobuf limit have to be checked
            # by file path instead.
            onnx.checker.check_model(onnx_path)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        # Best effort: a failed check is recorded but inference is still tried.
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    # Stage 4 and 5: run the graph with onnxruntime and compare the outputs.
    try:
        sess = onnxruntime.InferenceSession(
            onnx_path,
            providers=onnxruntime.get_available_providers(),
        )
        np_inputs = op.tokenizer(test_txt, return_tensors='np')
        # Feed only the inputs the exported graph actually declares.
        feed = {graph_input.name: np_inputs[graph_input.name] for graph_input in sess.get_inputs()}
        out2 = sess.run(output_names=['last_hidden_state'], input_feed=feed)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        # NOTE(review): out2 is the list returned by sess.run, so this
        # comparison relies on numpy broadcasting against out1 - confirm the
        # shape returned by the operator.
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')

    return status


# Process every model and append its row to the CSV as soon as it is known;
# the results file is closed even if the run is interrupted.
try:
    for name in models:
        status = _convert_and_check(name)
        f.write(','.join(status) + '\n')
        f.flush()
finally:
    f.close()

print('Finished.')
Loading…
Reference in new issue