|
|
|
from auto_transformers import AutoTransformers
|
|
|
|
|
|
|
|
import numpy
|
|
|
|
import onnx
|
|
|
|
import onnxruntime
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
import warnings
|
|
|
|
import logging
|
|
|
|
|
|
|
|
# Model selection.
#
# To run against every supported model that has not yet been verified for
# ONNX export, replace the hard-coded list below with:
#
#   full_models = AutoTransformers.supported_model_names()
#   checked_models = AutoTransformers.supported_model_names(format='onnx')
#   models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased']

# Silence every warning category for the whole run so that only the log
# messages emitted by this script show up in the output.
warnings.filterwarnings('ignore')
# Destination file for the conversion report. It is opened in append mode so
# successive runs accumulate in a single log.
log_path = './transformers_onnx.log'

# Root logger: the module-level logging.info()/logging.error() calls used by
# the rest of the script are dispatched through it.
logger = logging.getLogger()

# Always log to the file. The previous `if Path(log_path).exists:` tested the
# bound method object (always truthy) instead of calling it, so the
# StreamHandler branch was unreachable and a FileHandler was created on every
# run; that effective behavior is kept.
handler = logging.FileHandler(log_path, mode='a')
handler.setFormatter(logging.Formatter(
    fmt='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s',
    datefmt='%H:%M:%S',
))
logger.addHandler(handler)

# Configure the root logger directly instead of via logging.basicConfig():
# basicConfig() does nothing once the root logger already has a handler, so
# the requested format and level were never applied and INFO records were
# dropped. `logging.info` is also a function, not a level; the level constant
# is `logging.INFO`.
logger.setLevel(logging.INFO)

logging.info("Test")
# For every selected model run four stages: load the operator, export it to
# ONNX, validate the exported file, then run it with ONNX Runtime and compare
# against the operator's own output. A failure in any stage is logged at
# ERROR level and the script moves on to the next model.
for name in models:
    print(f'Model {name}:')
    # '/' is replaced so the model name can be used as a flat file name.
    saved_name = name.replace('/', '-')
    # NOTE(review): assumes save_model(format='onnx') writes to this path;
    # confirm against the AutoTransformers implementation.
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    logging.info(f'\nConverting model {name} to {saved_name}:')

    # Stage 1: build the operator and keep a reference output for the
    # accuracy comparison in stage 4.
    try:
        op = AutoTransformers(model_name=name)
        out1 = op('hello, world.')
        logging.info('OP LOADED.')
    except Exception as e:
        logging.error(f'FAIL TO LOAD OP: {e}')
        continue

    # Stage 2: export to ONNX.
    try:
        op.save_model(format='onnx')
        logging.info('ONNX SAVED.')
    except Exception as e:
        # Logged as an error like every other stage failure (it used to be
        # logged at INFO level, which hid it among the progress messages).
        logging.error(f'FAIL TO SAVE ONNX: {e}')
        continue

    # Stage 3: validate the exported graph. If the check on the normally
    # loaded model raises, retry on a model loaded without external data.
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logging.info('ONNX CHECKED.')
    except Exception as e:
        logging.error(f'FAIL TO CHECK ONNX: {e}')
        continue

    # Stage 4: run the exported model and compare it with the reference
    # output from stage 1.
    try:
        sess = onnxruntime.InferenceSession(
            onnx_path,
            providers=onnxruntime.get_available_providers(),
        )
        inputs = op.tokenizer('hello, world.', return_tensors='np')
        # NOTE(review): out2 is whatever sess.run returns for the single
        # requested output; confirm its shape lines up with out1 under
        # numpy broadcasting.
        out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
        logging.info('ONNX WORKED.')
        if numpy.allclose(out1, out2, atol=1e-3):
            logging.info('Check accuracy: OK')
        else:
            logging.info('Check accuracy: atol is larger than 1e-3.')
    except Exception as e:
        logging.error(f'FAIL TO RUN ONNX: {e}')

print('Finished.')