diff --git a/auto_transformers.py b/auto_transformers.py index 056a98c..aca2eba 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -206,6 +206,8 @@ class AutoTransformers(NNOperator): dynamic_axes[k] = v for k, v in self.onnx_config['outputs'].items(): dynamic_axes[k] = v + if hasattr(self._model.config, 'use_cache'): + self._model.config.use_cache = False torch.onnx.export( self._model, tuple(inputs.values()), diff --git a/test_onnx.py b/test_onnx.py index 58134d4..a0fa6d6 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -21,7 +21,7 @@ t_logging.set_verbosity_error() # full_models = op.supported_model_names() # checked_models = AutoTransformers.supported_model_names(format='onnx') # models = [x for x in full_models if x not in checked_models] -models = ['distilbert-base-cased', 'sentence-transformers/paraphrase-albert-small-v2'] +models = ['distilbert-base-cased', 'paraphrase-albert-small-v2'] test_txt = 'hello, world.' atol = 1e-3 log_path = 'transformers_onnx.log'