From 6279bd4d80693d09ed85d4c0f6db2e171389c30b Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Fri, 6 Jan 2023 10:50:39 +0800
Subject: [PATCH] Support sentence embedding

Signed-off-by: Jael Gu
---
 auto_transformers.py | 22 ++++++++++++++--------
 test_onnx.py         |  7 +++++++
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index a495c36..c254399 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -100,9 +100,9 @@ class AutoTransformers(NNOperator):
             log.warning('The operator is initialized without specified model.')
             pass
 
-    def __call__(self, txt: str) -> numpy.ndarray:
+    def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
         try:
-            inputs = self.tokenizer(txt, return_tensors="pt").to(self.device)
+            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
         except Exception as e:
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
@@ -111,11 +111,9 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
-        try:
-            features = outs.squeeze(0)
-        except Exception as e:
-            log.error(f'Fail to extract features by model: {self.model_name}')
-            raise e
+        if return_sentence_emb:
+            outs = self.post_proc(outs, inputs)
+        features = outs.squeeze(0)
         vec = features.cpu().detach().numpy()
         return vec
 
@@ -123,6 +121,14 @@
     def _model(self):
         return self.model.model
 
+    def post_proc(self, token_embeddings, inputs):
+        token_embeddings = token_embeddings.to(self.device)
+        attention_mask = inputs['attention_mask'].to(self.device)
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        sentence_embs = torch.sum(
+            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sentence_embs
+
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         if output_file == 'default':
             output_file = str(Path(__file__).parent)
@@ -138,7 +144,7 @@
             raise AttributeError('Unsupported model_type.')
 
         dummy_input = '[CLS]'
-        inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
+        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary
         if model_type == 'pytorch':
             torch.save(self._model, output_file)
         elif model_type == 'torchscript':
diff --git a/test_onnx.py b/test_onnx.py
index 662045c..777636d 100644
--- a/test_onnx.py
+++ b/test_onnx.py
@@ -9,6 +9,13 @@ import logging
 import platform
 import psutil
 
+import warnings
+from transformers import logging as t_logging
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+warnings.filterwarnings("ignore")
+t_logging.set_verbosity_error()
+
 # full_models = AutoTransformers.supported_model_names()
 # checked_models = AutoTransformers.supported_model_names(format='onnx')
 # models = [x for x in full_models if x not in checked_models]
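
Note (not part of the patch): the new post_proc method computes an attention-mask-weighted mean over token embeddings, i.e. padded positions are zeroed out before averaging and the denominator is clamped to avoid division by zero. Below is a minimal standalone sketch of that same pooling step on dummy tensors; the shapes and values are illustrative assumptions, not taken from the operator.

    import torch

    # Dummy batch of token embeddings: (batch, seq_len, hidden)
    token_embeddings = torch.randn(2, 4, 8)
    # Attention mask: 1 for real tokens, 0 for padding
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])

    # Expand the mask to the embedding size, zero out padded tokens,
    # then average over the sequence dimension.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sentence_embs = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
    print(sentence_embs.shape)  # torch.Size([2, 8])

With return_sentence_emb=True, __call__ applies this pooling to the model output before squeezing and converting to numpy, so the operator returns one fixed-size vector per input text instead of per-token embeddings.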