|
@ -96,6 +96,8 @@ class AutoTransformers(NNOperator): |
|
|
raise e |
|
|
raise e |
|
|
else: |
|
|
else: |
|
|
self.tokenizer = tokenizer |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
if not self.tokenizer.pad_token: |
|
|
|
|
|
self.tokenizer.pad_token = '[PAD]' |
|
|
else: |
|
|
else: |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
pass |
|
|
pass |
|
@ -103,8 +105,9 @@ class AutoTransformers(NNOperator): |
|
|
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray: |
|
|
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray: |
|
|
try: |
|
|
try: |
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) |
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) |
|
|
except Exception: |
|
|
|
|
|
inputs = self.tokenizer(txt, truncation=True, return_tensors='pt').to(self.device) |
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
log.error(f'Fail to tokenize inputs: {e}') |
|
|
|
|
|
raise e |
|
|
try: |
|
|
try: |
|
|
outs = self.model(**inputs) |
|
|
outs = self.model(**inputs) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
@ -143,10 +146,7 @@ class AutoTransformers(NNOperator): |
|
|
raise AttributeError('Unsupported model_type.') |
|
|
raise AttributeError('Unsupported model_type.') |
|
|
|
|
|
|
|
|
dummy_input = '[CLS]' |
|
|
dummy_input = '[CLS]' |
|
|
try: |
|
|
|
|
|
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary |
|
|
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary |
|
|
except Exception: |
|
|
|
|
|
inputs = self.tokenizer(dummy_input, truncation=True, return_tensors='pt') |
|
|
|
|
|
if model_type == 'pytorch': |
|
|
if model_type == 'pytorch': |
|
|
torch.save(self._model, output_file) |
|
|
torch.save(self._model, output_file) |
|
|
elif model_type == 'torchscript': |
|
|
elif model_type == 'torchscript': |
|
|