logo
Browse Source

Disable use_cache

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
674d9a3b69
  1. 2
      auto_transformers.py
  2. 2
      test_onnx.py

2
auto_transformers.py

@ -206,6 +206,8 @@ class AutoTransformers(NNOperator):
dynamic_axes[k] = v dynamic_axes[k] = v
for k, v in self.onnx_config['outputs'].items(): for k, v in self.onnx_config['outputs'].items():
dynamic_axes[k] = v dynamic_axes[k] = v
if hasattr(self._model.config, 'use_cache'):
self._model.config.use_cache = False
torch.onnx.export( torch.onnx.export(
self._model, self._model,
tuple(inputs.values()), tuple(inputs.values()),

2
test_onnx.py

@ -21,7 +21,7 @@ t_logging.set_verbosity_error()
# full_models = op.supported_model_names() # full_models = op.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx') # checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models] # models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased', 'sentence-transformers/paraphrase-albert-small-v2']
models = ['distilbert-base-cased', 'paraphrase-albert-small-v2']
test_txt = 'hello, world.' test_txt = 'hello, world.'
atol = 1e-3 atol = 1e-3
log_path = 'transformers_onnx.log' log_path = 'transformers_onnx.log'

Loading…
Cancel
Save