|
|
@ -100,9 +100,9 @@ class AutoTransformers(NNOperator): |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
|
|
|
|
|
def __call__(self, txt: str) -> numpy.ndarray: |
|
|
|
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray: |
|
|
|
try: |
|
|
|
inputs = self.tokenizer(txt, return_tensors="pt").to(self.device) |
|
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|
|
|
raise e |
|
|
@ -111,11 +111,9 @@ class AutoTransformers(NNOperator): |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
|
raise e |
|
|
|
try: |
|
|
|
features = outs.squeeze(0) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
|
raise e |
|
|
|
if return_sentence_emb: |
|
|
|
outs = self.post_proc(outs, inputs) |
|
|
|
features = outs.squeeze(0) |
|
|
|
vec = features.cpu().detach().numpy() |
|
|
|
return vec |
|
|
|
|
|
|
@ -123,6 +121,14 @@ class AutoTransformers(NNOperator): |
|
|
|
def _model(self): |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
def post_proc(self, token_embeddings, inputs):
    """
    Pool token embeddings into one vector per sequence using a masked mean.

    Both the embeddings and the attention mask are moved to ``self.device``.
    The mask is given a trailing axis and expanded to the embeddings' size,
    so positions whose mask value is 0 contribute nothing to the sum.

    Args:
        token_embeddings: tensor of per-token embeddings; dimension 1 is the
            one reduced here (presumably laid out as (batch, tokens, dim) --
            confirm against the caller).
        inputs: mapping that provides an ``'attention_mask'`` tensor with the
            same leading dimensions as ``token_embeddings``.

    Returns:
        Tensor holding the mask-weighted mean over dimension 1.
    """
    target_device = self.device
    hidden_states = token_embeddings.to(target_device)
    mask = inputs['attention_mask'].to(target_device)

    # Broadcast the mask over the embedding axis so it can weight every feature.
    weights = mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_sum = torch.sum(hidden_states * weights, 1)
    # Clamp the token count so a fully masked row does not divide by zero.
    token_count = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / token_count
|
|
|
|
|
|
|
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |
|
|
|
if output_file == 'default': |
|
|
|
output_file = str(Path(__file__).parent) |
|
|
@ -138,7 +144,7 @@ class AutoTransformers(NNOperator): |
|
|
|
raise AttributeError('Unsupported model_type.') |
|
|
|
|
|
|
|
dummy_input = '[CLS]' |
|
|
|
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary |
|
|
|
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary |
|
|
|
if model_type == 'pytorch': |
|
|
|
torch.save(self._model, output_file) |
|
|
|
elif model_type == 'torchscript': |
|
|
|