transformers
copied
Jael Gu
2 years ago
2 changed files with 66 additions and 27 deletions
@ -1,43 +1,79 @@ |
|||||
from auto_transformers import AutoTransformers |
from auto_transformers import AutoTransformers |
||||
import onnx |
|
||||
import warnings |
|
||||
|
|
||||
warnings.filterwarnings('ignore') |
|
||||
|
import numpy |
||||
|
import onnx |
||||
|
import onnxruntime |
||||
|
|
||||
f = open('onnx.csv', 'a+') |
|
||||
f.write('model_name, run_op, save_onnx, check_onnx\n') |
|
||||
|
from pathlib import Path |
||||
|
import warnings |
||||
|
import logging |
||||
|
|
||||
# full_models = AutoTransformers.supported_model_names() |
# full_models = AutoTransformers.supported_model_names() |
||||
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
||||
# models = [x for x in full_models if x not in checked_models] |
# models = [x for x in full_models if x not in checked_models] |
||||
models = ['funnel-transformer/large', 'funnel-transformer/medium', 'funnel-transformer/small', 'funnel-transformer/xlarge'] |
|
||||
|
models = ['distilbert-base-cased'] |
||||
|
|
||||
|
warnings.filterwarnings('ignore') |
||||
|
|
||||
|
log_path = './transformers_onnx.log' |
||||
|
|
||||
|
logger = logging.getLogger() |
||||
|
|
||||
|
if Path(log_path).exists: |
||||
|
handler = logging.FileHandler(log_path) |
||||
|
else: |
||||
|
handler = logging.StreamHandler() |
||||
|
logger.handlers.append(handler) |
||||
|
|
||||
|
logging.basicConfig(filename=log_path, |
||||
|
filemode='a', |
||||
|
format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s', |
||||
|
datefmt='%H:%M:%S', |
||||
|
level=logging.info) |
||||
|
|
||||
|
|
||||
|
logging.info("Test") |
||||
|
|
||||
for name in models: |
for name in models: |
||||
f.write(f'{name},') |
|
||||
|
print(f'Model {name}:') |
||||
|
saved_name = name.replace('/', '-') |
||||
|
logging.info(f'\nConverting model {name} to {saved_name}:') |
||||
try: |
try: |
||||
op = AutoTransformers(model_name=name) |
op = AutoTransformers(model_name=name) |
||||
out1 = op('hello, world.') |
out1 = op('hello, world.') |
||||
f.write('success,') |
|
||||
|
logging.info('OP LOADED.') |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail') |
|
||||
print(f'Fail to load op for {name}: {e}') |
|
||||
pass |
|
||||
|
logging.error(f'FAIL TO LOAD OP: {e}') |
||||
|
continue |
||||
try: |
try: |
||||
op.save_model(format='onnx') |
op.save_model(format='onnx') |
||||
f.write('success,') |
|
||||
|
logging.info('ONNX SAVED.') |
||||
|
except Exception as e: |
||||
|
logging.info(f'FAIL TO SAVE ONNX: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
try: |
||||
|
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
||||
|
onnx.checker.check_model(onnx_model) |
||||
|
except Exception: |
||||
|
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) |
||||
|
onnx.checker.check_model(saved_onnx) |
||||
|
logging.info('ONNX CHECKED.') |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail') |
|
||||
print(f'Fail to save onnx for {name}: {e}') |
|
||||
pass |
|
||||
|
logging.error(f'FAIL TO CHECK ONNX: {e}') |
||||
|
continue |
||||
try: |
try: |
||||
saved_name = name.replace('/', '-') |
|
||||
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) |
|
||||
onnx.checker.check_model(onnx_model) |
|
||||
f.write('success') |
|
||||
|
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', |
||||
|
providers=onnxruntime.get_available_providers()) |
||||
|
inputs = op.tokenizer('hello, world.', return_tensors='np') |
||||
|
out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs)) |
||||
|
logging.info('ONNX WORKED.') |
||||
|
if numpy.allclose(out1, out2, atol=1e-3): |
||||
|
logging.info('Check accuracy: OK') |
||||
|
else: |
||||
|
logging.info('Check accuracy: atol is larger than 1e-3.') |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail') |
|
||||
print(f'Fail to check onnx for {name}: {e}') |
|
||||
pass |
|
||||
|
|
||||
f.write('\n') |
|
||||
|
logging.error(f'FAIL TO RUN ONNX: {e}') |
||||
|
continue |
||||
|
|
||||
print('Finished.') |
print('Finished.') |
||||
|
Loading…
Reference in new issue