logo
Browse Source

Update model list for onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
f5865f2021
  1. 14
      auto_transformers.py
  2. 18
      test_onnx.py

14
auto_transformers.py

@ -107,20 +107,21 @@ class AutoTransformers(NNOperator):
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state"],
opset_version=12,
opset_version=13,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
"attention_mask": {0: "batch_size", 1: "input_length"},
"last_hidden_state": {0: "batch_size"},
})
except Exception:
except Exception as e:
print(e, '\nTrying with 2 outputs...')
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state", "pooler_output"],
opset_version=12,
opset_version=13,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
@ -321,16 +322,11 @@ class AutoTransformers(NNOperator):
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = [
'albert-xlarge-v1',
'albert-xlarge-v2',
'albert-xxlarge-v1',
'albert-xxlarge-v2',
'allenai/led-base-16384',
'ctrl',
'distilgpt2',
'EleutherAI/gpt-j-6B',
'EleutherAI/gpt-neo-1.3B',
'funnel-transformer/intermediate',
'funnel-transformer/large',
'funnel-transformer/medium',
'funnel-transformer/small',
@ -338,8 +334,6 @@ class AutoTransformers(NNOperator):
'google/bigbird-pegasus-large-arxiv',
'google/bigbird-pegasus-large-bigpatent',
'google/bigbird-pegasus-large-pubmed',
'google/canine-c',
'google/canine-s',
'google/fnet-base',
'google/fnet-large',
'google/reformer-crime-and-punishment',

18
test_onnx.py

@ -1,11 +1,16 @@
from auto_transformers import AutoTransformers
import onnx
import warnings
warnings.filterwarnings('ignore')
f = open('onnx.csv', 'a+')
f.write('model_name, run_op, save_onnx, check_onnx\n')
# models = AutoTransformers.supported_model_names()[:1]
models = ['bert-base-cased', 'distilbert-base-cased']
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['funnel-transformer/large', 'funnel-transformer/medium', 'funnel-transformer/small', 'funnel-transformer/xlarge']
for name in models:
f.write(f'{name},')
@ -16,22 +21,23 @@ for name in models:
except Exception as e:
f.write('fail')
print(f'Fail to load op for {name}: {e}')
continue
pass
try:
op.save_model(format='onnx')
f.write('success,')
except Exception as e:
f.write('fail')
print(f'Fail to save onnx for {name}: {e}')
continue
pass
try:
saved_name = name.replace('/', '-')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
onnx.checker.check_model(onnx_model)
f.write('success')
except Exception as e:
f.write('fail')
print(f'Fail to check onnx for {name}: {e}')
continue
pass
f.write('\n')
print('Finished.')

Loading…
Cancel
Save