From 0485da739901e8b7d89749096e7ca799dea3758b Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 6 Jan 2023 16:53:12 +0800 Subject: [PATCH] Modify output format Signed-off-by: Jael Gu --- README.md | 16 +++++++++++----- auto_transformers.py | 24 ++++++++++++++++++------ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 3ce47ab..3bbdbb4 100644 --- a/README.md +++ b/README.md @@ -329,6 +329,13 @@ If None, the operator will use default tokenizer by `model_name` from Huggingfac
+***return_sentence_emb***: *bool* + +Whether to output a single sentence embedding for each text; defaults to True. +If False, the operator returns token embeddings for each text. + +
+ ## Interface The operator takes a piece of text in string as input. @@ -339,9 +346,11 @@ and then return text embedding in ndarray. **Parameters:** -***txt***: *str* +***data***: *Union[str, list]* -​ The text in string. +​ The text as a string, or a list of texts. +If data is a string, the operator returns the embedding(s) as an ndarray. +If data is a list, the operator returns the embedding(s) as a list. **Returns**: @@ -350,9 +359,6 @@ and then return text embedding in ndarray. ​ The text embedding extracted by model. -***return_sentence_emb***: *bool* - -The flag to output sentence embedding instead of token embeddings, defaults to False.
diff --git a/auto_transformers.py b/auto_transformers.py index 12ff5b9..b6c1df6 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -17,6 +17,7 @@ import os import torch import shutil from pathlib import Path +from typing import Union from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM @@ -26,13 +27,15 @@ from towhee import register import warnings import logging +from transformers import logging as t_logging from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer from .train_clm_with_hf_trainer import train_clm_with_hf_trainer log = logging.getLogger('run_op') - warnings.filterwarnings('ignore') +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +t_logging.set_verbosity_error() # @accelerate @@ -76,12 +79,14 @@ class AutoTransformers(NNOperator): checkpoint_path: str = None, tokenizer: object = None, device: str = None, + return_sentence_emb: bool = True ): super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name + self.return_sentence_emb = return_sentence_emb if self.model_name: model_list = self.supported_model_names() @@ -102,7 +107,11 @@ class AutoTransformers(NNOperator): log.warning('The operator is initialized without specified model.') pass - def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray: + def __call__(self, data: Union[str, list]) -> numpy.ndarray: + if isinstance(data, str): + txt = [data] + else: + txt = data try: inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) except Exception as e: @@ -113,11 +122,14 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e - if return_sentence_emb: + if self.return_sentence_emb: outs = self.post_proc(outs, inputs) - features = outs.squeeze(0) - vec = features.cpu().detach().numpy() - return vec + features = 
outs.cpu().detach().numpy() + if isinstance(data, str): + features = features.squeeze(0) + else: + features = list(features) + return features @property def _model(self):