|
@ -13,9 +13,12 @@ import psutil |
|
|
# full_models = AutoTransformers.supported_model_names() |
|
|
# full_models = AutoTransformers.supported_model_names() |
|
|
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
|
|
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
|
|
# models = [x for x in full_models if x not in checked_models] |
|
|
# models = [x for x in full_models if x not in checked_models] |
|
|
models = ['distilbert-base-cased'] |
|
|
|
|
|
|
|
|
models = ['bert-base-cased', 'distilbert-base-cased'] |
|
|
|
|
|
test_txt = '[UNK]' |
|
|
atol = 1e-3 |
|
|
atol = 1e-3 |
|
|
log_path = 'transformers_onnx.log' |
|
|
log_path = 'transformers_onnx.log' |
|
|
|
|
|
f = open('onnx.csv', 'w+') |
|
|
|
|
|
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') |
|
|
|
|
|
|
|
|
logger = logging.getLogger('transformers_onnx') |
|
|
logger = logging.getLogger('transformers_onnx') |
|
|
logger.setLevel(logging.DEBUG) |
|
|
logger.setLevel(logging.DEBUG) |
|
@ -30,24 +33,31 @@ ch.setFormatter(formatter) |
|
|
logger.addHandler(ch) |
|
|
logger.addHandler(ch) |
|
|
|
|
|
|
|
|
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
|
|
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
|
|
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 **3))}' |
|
|
|
|
|
f'/{round(psutil.virtual_memory().available / (1024.0 **3))}' |
|
|
|
|
|
f'/{round(psutil.virtual_memory().total / (1024.0 **3))} GB') |
|
|
|
|
|
|
|
|
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' |
|
|
|
|
|
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' |
|
|
|
|
|
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') |
|
|
logger.debug(f'cpu: {psutil.cpu_count()}') |
|
|
logger.debug(f'cpu: {psutil.cpu_count()}') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
status = None |
|
|
for name in models: |
|
|
for name in models: |
|
|
logger.info(f'***{name}***') |
|
|
logger.info(f'***{name}***') |
|
|
saved_name = name.replace('/', '-') |
|
|
saved_name = name.replace('/', '-') |
|
|
|
|
|
if status: |
|
|
|
|
|
f.write(','.join(status) + '\n') |
|
|
|
|
|
status = [name] + ['fail'] * 5 |
|
|
try: |
|
|
try: |
|
|
op = AutoTransformers(model_name=name) |
|
|
op = AutoTransformers(model_name=name) |
|
|
out1 = op('hello, world.') |
|
|
|
|
|
|
|
|
out1 = op(test_txt) |
|
|
logger.info('OP LOADED.') |
|
|
logger.info('OP LOADED.') |
|
|
|
|
|
status[1] = 'success' |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
logger.error(f'FAIL TO LOAD OP: {e}') |
|
|
logger.error(f'FAIL TO LOAD OP: {e}') |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
|
op.save_model(format='onnx') |
|
|
op.save_model(format='onnx') |
|
|
logger.info('ONNX SAVED.') |
|
|
logger.info('ONNX SAVED.') |
|
|
|
|
|
status[2] = 'success' |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
logger.error(f'FAIL TO SAVE ONNX: {e}') |
|
|
logger.error(f'FAIL TO SAVE ONNX: {e}') |
|
|
continue |
|
|
continue |
|
@ -59,21 +69,27 @@ for name in models: |
|
|
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) |
|
|
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) |
|
|
onnx.checker.check_model(saved_onnx) |
|
|
onnx.checker.check_model(saved_onnx) |
|
|
logger.info('ONNX CHECKED.') |
|
|
logger.info('ONNX CHECKED.') |
|
|
|
|
|
status[3] = 'success' |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
|
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', |
|
|
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
inputs = op.tokenizer('hello, world.', return_tensors='np') |
|
|
|
|
|
out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs)) |
|
|
|
|
|
|
|
|
inputs = op.tokenizer(test_txt, return_tensors='np') |
|
|
|
|
|
out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs)) |
|
|
logger.info('ONNX WORKED.') |
|
|
logger.info('ONNX WORKED.') |
|
|
|
|
|
status[4] = 'success' |
|
|
if numpy.allclose(out1, out2, atol=atol): |
|
|
if numpy.allclose(out1, out2, atol=atol): |
|
|
logger.info('Check accuracy: OK') |
|
|
logger.info('Check accuracy: OK') |
|
|
|
|
|
status[5] = 'success' |
|
|
else: |
|
|
else: |
|
|
logger.info(f'Check accuracy: atol is larger than {atol}.') |
|
|
logger.info(f'Check accuracy: atol is larger than {atol}.') |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
logger.error(f'FAIL TO RUN ONNX: {e}') |
|
|
logger.error(f'FAIL TO RUN ONNX: {e}') |
|
|
continue |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
if status: |
|
|
|
|
|
f.write(','.join(status) + '\n') |
|
|
|
|
|
|
|
|
print('Finished.') |
|
|
print('Finished.') |
|
|