Browse Source
Replace transformers.convert_graph_to_onnx with transformers.onnx for ONNX export
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with 13 additions and 9 deletions
-
auto_transformers.py
-
test_onnx.py
|
|
@ -106,15 +106,19 @@ class AutoTransformers(NNOperator): |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
from transformers.convert_graph_to_onnx import convert |
|
|
|
from transformers.onnx.features import FeaturesManager |
|
|
|
from transformers.onnx import export |
|
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
|
self.model, feature='default') |
|
|
|
onnx_config = model_onnx_config(self.model.config) |
|
|
|
if os.path.isdir(path): |
|
|
|
shutil.rmtree(path) |
|
|
|
path = os.path.join(path, 'model.onnx') |
|
|
|
convert( |
|
|
|
model=self.model_name, |
|
|
|
output=Path(path), |
|
|
|
framework='pt', |
|
|
|
opset=13 |
|
|
|
onnx_inputs, onnx_outputs = export( |
|
|
|
self.tokenizer, |
|
|
|
self.model, |
|
|
|
config=onnx_config, |
|
|
|
opset=13, |
|
|
|
output=Path(path+'.onnx') |
|
|
|
) |
|
|
|
# todo: elif format == 'tensorrt': |
|
|
|
else: |
|
|
|
|
|
@ -43,7 +43,7 @@ status = None |
|
|
|
for name in models: |
|
|
|
logger.info(f'***{name}***') |
|
|
|
saved_name = name.replace('/', '-') |
|
|
|
onnx_path = f'saved/onnx/{saved_name}/model.onnx' |
|
|
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|
|
|
if status: |
|
|
|
f.write(','.join(status) + '\n') |
|
|
|
status = [name] + ['fail'] * 5 |
|
|
@ -78,7 +78,7 @@ for name in models: |
|
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
|
providers=onnxruntime.get_available_providers()) |
|
|
|
inputs = op.tokenizer(test_txt, return_tensors='np') |
|
|
|
out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs)) |
|
|
|
out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs)) |
|
|
|
logger.info('ONNX WORKED.') |
|
|
|
status[4] = 'success' |
|
|
|
if numpy.allclose(out1, out2, atol=atol): |
|
|
|