
Update for onnx test

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
commit 73fc63b013 (branch: main)
Jael Gu, 1 year ago
2 changed files:
  1.  9 changes: auto_transformers.py
  2. 84 changes: test_onnx.py

auto_transformers.py (9 changes)

@@ -108,15 +108,18 @@ class AutoTransformers(NNOperator):
             torch.onnx.export(self.model,
                               tuple(inputs.values()),
                               path,
-                              input_names=["input_ids", "token_type_ids", "attention_mask"],  # list(inputs.keys())
+                              input_names=list(inputs.keys()),
                               output_names=["last_hidden_state"],
-                              opset_version=13,
                               dynamic_axes={
                                   "input_ids": {0: "batch_size", 1: "input_length"},
                                   "token_type_ids": {0: "batch_size", 1: "input_length"},
                                   "attention_mask": {0: "batch_size", 1: "input_length"},
                                   "last_hidden_state": {0: "batch_size"},
-                              })
+                              },
+                              opset_version=13,
+                              do_constant_folding=True,
+                              enable_onnx_checker=True,
+                              )
         except Exception as e:
             print(e, '\nTrying with 2 outputs...')
             torch.onnx.export(self.model,
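Note: deriving input_names from list(inputs.keys()) makes the export robust for tokenizers that emit no token_type_ids (e.g. DistilBERT's), where the previous hard-coded three-name list would not match the traced inputs. do_constant_folding=True pre-computes static subgraphs at export time, and enable_onnx_checker=True validates the exported graph (this flag was later deprecated and removed from torch.onnx.export; recent PyTorch runs the check unconditionally). Below is a minimal, self-contained sketch of an equivalent export, with the model name, sample text, and output path chosen purely for illustration:

    # Hedged sketch: mirrors the export call in the hunk above, outside the operator.
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_name = 'distilbert-base-cased'          # assumption: any HF encoder works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()

    inputs = tokenizer('hello, world.', return_tensors='pt')  # dict of tensors

    # Mark batch and sequence dims as dynamic for every traced input.
    dynamic_axes = {name: {0: 'batch_size', 1: 'input_length'} for name in inputs}
    dynamic_axes['last_hidden_state'] = {0: 'batch_size'}

    torch.onnx.export(model,
                      tuple(inputs.values()),
                      'model.onnx',               # assumption: output path
                      input_names=list(inputs.keys()),
                      output_names=['last_hidden_state'],
                      dynamic_axes=dynamic_axes,
                      opset_version=13,
                      do_constant_folding=True)

The dict comprehension keeps dynamic_axes in sync with whatever inputs the tokenizer actually produces.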

test_onnx.py (84 changes)

@@ -1,43 +1,79 @@
 from auto_transformers import AutoTransformers
-import onnx
-import warnings
-warnings.filterwarnings('ignore')
+import numpy
+import onnx
+import onnxruntime
-f = open('onnx.csv', 'a+')
-f.write('model_name, run_op, save_onnx, check_onnx\n')
+from pathlib import Path
+import warnings
+import logging
+# full_models = AutoTransformers.supported_model_names()
+# checked_models = AutoTransformers.supported_model_names(format='onnx')
+# models = [x for x in full_models if x not in checked_models]
-models = ['funnel-transformer/large', 'funnel-transformer/medium', 'funnel-transformer/small', 'funnel-transformer/xlarge']
+models = ['distilbert-base-cased']
+warnings.filterwarnings('ignore')
+log_path = './transformers_onnx.log'
+logger = logging.getLogger()
+if Path(log_path).exists():
+    handler = logging.FileHandler(log_path)
+else:
+    handler = logging.StreamHandler()
+logger.handlers.append(handler)
+logging.basicConfig(filename=log_path,
+                    filemode='a',
+                    format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s',
+                    datefmt='%H:%M:%S',
+                    level=logging.INFO)
+logging.info("Test")
 for name in models:
-    f.write(f'{name},')
-    print(f'Model {name}:')
+    saved_name = name.replace('/', '-')
+    logging.info(f'\nConverting model {name} to {saved_name}:')
     try:
         op = AutoTransformers(model_name=name)
         out1 = op('hello, world.')
-        f.write('success,')
+        logging.info('OP LOADED.')
     except Exception as e:
-        f.write('fail')
-        print(f'Fail to load op for {name}: {e}')
-        pass
+        logging.error(f'FAIL TO LOAD OP: {e}')
+        continue
     try:
         op.save_model(format='onnx')
-        f.write('success,')
+        logging.info('ONNX SAVED.')
     except Exception as e:
+        logging.error(f'FAIL TO SAVE ONNX: {e}')
+        continue
+    try:
+        try:
+            onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logging.info('ONNX CHECKED.')
+    except Exception as e:
-        f.write('fail')
-        print(f'Fail to save onnx for {name}: {e}')
-        pass
+        logging.error(f'FAIL TO CHECK ONNX: {e}')
+        continue
     try:
-        saved_name = name.replace('/', '-')
-        onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
-        onnx.checker.check_model(onnx_model)
-        f.write('success')
+        sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
+                                            providers=onnxruntime.get_available_providers())
+        inputs = op.tokenizer('hello, world.', return_tensors='np')
+        out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
+        logging.info('ONNX WORKED.')
+        if numpy.allclose(out1, out2, atol=1e-3):
+            logging.info('Check accuracy: OK')
+        else:
+            logging.info('Check accuracy: atol is larger than 1e-3.')
     except Exception as e:
-        f.write('fail')
-        print(f'Fail to check onnx for {name}: {e}')
-        pass
-    f.write('\n')
+        logging.error(f'FAIL TO RUN ONNX: {e}')
+        continue
 print('Finished.')
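Note: the nested load/check in the hunk above is a fallback for large exports: onnx.load() embeds the weights in the ModelProto by default, which can fail when the export stores weights as external data files, so the except branch reloads with load_external_data=False and checks the graph structure alone. A hedged sketch of the same check-then-run verification as a standalone helper (the function name, default text, and tolerance are illustrative assumptions):

    # Hedged sketch: graph validity check plus runtime parity check, as in the test loop.
    import numpy
    import onnx
    import onnxruntime

    def verify_onnx(path, tokenizer, reference, text='hello, world.', atol=1e-3):
        try:
            onnx.checker.check_model(onnx.load(path))   # weights embedded in the proto
        except Exception:
            # Weights kept as external data: validate the graph structure alone.
            onnx.checker.check_model(onnx.load(path, load_external_data=False))
        sess = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers())
        inputs = tokenizer(text, return_tensors='np')
        out = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
        return numpy.allclose(reference, out[0], atol=atol)  # first (and only) output

sess.run returns a list of output arrays; comparing the first against the operator's reference output within atol mirrors the accuracy check in the loop above.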
