diff --git a/README.md b/README.md
index 3ce47ab..3bbdbb4 100644
--- a/README.md
+++ b/README.md
@@ -329,6 +329,13 @@ If None, the operator will use default tokenizer by `model_name` from Huggingfac
+***return_sentence_emb***: *bool*
+
+The flag to output a sentence embedding for each text, defaults to True.
+If False, the operator returns token embeddings for each text.
+
+
+
## Interface
The operator takes a piece of text in string as input.
@@ -339,9 +346,11 @@ and then return text embedding in ndarray.
**Parameters:**
-***txt***: *str*
+***data***: *Union[str, list]*
- The text in string.
+ The text in string or a list of texts.
+If data is string, the operator returns embedding(s) in ndarray.
+If data is a list, the operator returns embedding(s) in a list.
**Returns**:
@@ -350,9 +359,6 @@ and then return text embedding in ndarray.
The text embedding extracted by model.
-***return_sentence_emb***: *bool*
-
-The flag to output sentence embedding instead of token embeddings, defaults to False.
diff --git a/auto_transformers.py b/auto_transformers.py
index 12ff5b9..b6c1df6 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -17,6 +17,7 @@ import os
import torch
import shutil
from pathlib import Path
+from typing import Union
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
@@ -26,13 +27,15 @@ from towhee import register
import warnings
import logging
+from transformers import logging as t_logging
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
log = logging.getLogger('run_op')
-
warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
# @accelerate
@@ -76,12 +79,14 @@ class AutoTransformers(NNOperator):
checkpoint_path: str = None,
tokenizer: object = None,
device: str = None,
+ return_sentence_emb: bool = True
):
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
+ self.return_sentence_emb = return_sentence_emb
if self.model_name:
model_list = self.supported_model_names()
@@ -102,7 +107,11 @@ class AutoTransformers(NNOperator):
log.warning('The operator is initialized without specified model.')
pass
- def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
+ def __call__(self, data: Union[str, list]) -> numpy.ndarray:
+ if isinstance(data, str):
+ txt = [data]
+ else:
+ txt = data
try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
except Exception as e:
@@ -113,11 +122,14 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
- if return_sentence_emb:
+ if self.return_sentence_emb:
outs = self.post_proc(outs, inputs)
- features = outs.squeeze(0)
- vec = features.cpu().detach().numpy()
- return vec
+ features = outs.cpu().detach().numpy()
+ if isinstance(data, str):
+ features = features.squeeze(0)
+ else:
+ features = list(features)
+ return features
@property
def _model(self):