logo
Browse Source

Replace with transformers.onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
d838ba8357
  1. 18
      auto_transformers.py
  2. 4
      test_onnx.py

18
auto_transformers.py

@ -106,15 +106,19 @@ class AutoTransformers(NNOperator):
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
from transformers.convert_graph_to_onnx import convert
from transformers.onnx.features import FeaturesManager
from transformers.onnx import export
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self.model, feature='default')
onnx_config = model_onnx_config(self.model.config)
if os.path.isdir(path):
shutil.rmtree(path)
path = os.path.join(path, 'model.onnx')
convert(
model=self.model_name,
output=Path(path),
framework='pt',
opset=13
onnx_inputs, onnx_outputs = export(
self.tokenizer,
self.model,
config=onnx_config,
opset=13,
output=Path(path+'.onnx')
)
# todo: elif format == 'tensorrt':
else:

4
test_onnx.py

@ -43,7 +43,7 @@ status = None
for name in models:
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}/model.onnx'
onnx_path = f'saved/onnx/{saved_name}.onnx'
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
@ -78,7 +78,7 @@ for name in models:
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
inputs = op.tokenizer(test_txt, return_tensors='np')
out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs))
out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol):

Loading…
Cancel
Save