diff --git a/auto_transformers.py b/auto_transformers.py index 6f199f1..8b56a7f 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -188,7 +188,7 @@ class AutoTransformers(NNOperator): raise AttributeError('Unsupported model_type.') dummy_input = 'test sentence' - inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary + inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device) if model_type == 'pytorch': torch.save(self._model, output_file) elif model_type == 'torchscript':