From e1a1abd518fe6d42fd70110360c3c8c06767e8f4 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 3 Feb 2023 10:11:54 +0800 Subject: [PATCH] Fix onnx export Signed-off-by: Jael Gu --- auto_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_transformers.py b/auto_transformers.py index 6f199f1..8b56a7f 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -188,7 +188,7 @@ class AutoTransformers(NNOperator): raise AttributeError('Unsupported model_type.') dummy_input = 'test sentence' - inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary + inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device) if model_type == 'pytorch': torch.save(self._model, output_file) elif model_type == 'torchscript':