From 0db0a18ad0f630419fcf17b37d115a9ac91c472c Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 6 Jan 2023 14:22:36 +0800 Subject: [PATCH] Fix for model without pad Signed-off-by: Jael Gu --- auto_transformers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index c254399..e5255de 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -103,9 +103,8 @@ class AutoTransformers(NNOperator): def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray: try: inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) - except Exception as e: - log.error(f'Invalid input for the tokenizer: {self.model_name}') - raise e + except Exception: + inputs = self.tokenizer(dummy_input, truncation=True, return_tensors='pt').to(self.device) try: outs = self.model(**inputs) except Exception as e: @@ -144,7 +143,10 @@ class AutoTransformers(NNOperator): raise AttributeError('Unsupported model_type.') dummy_input = '[CLS]' - inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary + try: + inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary + except Exception: + inputs = self.tokenizer(dummy_input, truncation=True, return_tensors='pt') if model_type == 'pytorch': torch.save(self._model, output_file) elif model_type == 'torchscript':