logo
Browse Source

Allow saving the model as ONNX

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
516372f339
  1. 37
      auto_transformers.py
  2. 36
      test_save.py

37
auto_transformers.py

@ -79,10 +79,13 @@ class AutoTransformers(NNOperator):
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
inputs = self.tokenizer('[CLS]', return_tensors='pt')
inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
@ -96,8 +99,36 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onxx':
pass # todo
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state"],
opset_version=10,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
"attention_mask": {0: "batch_size", 1: "input_length"},
"last_hidden_state": {0: "batch_size"},
})
except Exception:
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state", "pooler_output"],
opset_version=10,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
"attention_mask": {0: "batch_size", 1: "input_length"},
"last_hidden_state": {0: "batch_size"},
"pooler_outputs": {0: "batch_size"}
})
else:
log.error(f'Unsupported format "{format}".')

36
test_save.py

@ -1,23 +1,39 @@
from auto_transformers import AutoTransformers
import onnx
import torch
f = open('onnx.csv', 'a+')
f.write('model_name, run op, save_onnx, check_onnx\n')
models = [
'bert-base-cased',
'distilbert-base-cased',
'distilgpt2',
'google/fnet-base'
'distilbert-base-cased'
]
for name in models:
line = f'{name}, '
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
op.save_model(format='torchscript')
op.model = torch.jit.load(name.replace('/', '-') + '.pt')
out2 = op('hello, world.')
assert (out1 == out2).all()
print(f'[SUCCESS] Saved torchscript for model "{name}"')
line += 'success, '
except Exception as e:
print(f'[ERROR] Fail for model "{name}": {e}.')
line += 'fail, '
print(f'Fail to load op for {name}: {e}.')
continue
try:
op.save_model(format='onnx')
line += 'success, '
except Exception as e:
line += 'fail, '
print(f'Fail to save onnx for {name}: {e}.')
continue
try:
saved_name = name.replace('/', '-')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx.checker.check_model(onnx_model)
line += 'success'
except Exception as e:
line += 'fail'
print(f'Fail to check onnx for {name}: {e}.')
continue
line += '\n'
f.write(line)

Loading…
Cancel
Save