logo
Browse Source

Modify output format

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
0485da7399
  1. 16
      README.md
  2. 24
      auto_transformers.py

16
README.md

@ -329,6 +329,13 @@ If None, the operator will use default tokenizer by `model_name` from Huggingfac
<br />
***return_sentence_emb***: *bool*
The flag to output a sentence embedding for each text, defaults to True.
If False, the operator returns token embeddings for each text.
<br />
## Interface
The operator takes a piece of text in string as input.
@ -339,9 +346,11 @@ and then return text embedding in ndarray.
**Parameters:**
***txt***: *str*
***data***: *Union[str, list]*
​ The text in string.
​ The text in string or a list of texts.
If data is string, the operator returns embedding(s) in ndarray.
If data is a list, the operator returns embedding(s) in a list.
**Returns**:
@ -350,9 +359,6 @@ and then return text embedding in ndarray.
​ The text embedding extracted by model.
***return_sentence_emb***: *bool*
The flag to output sentence embedding instead of token embeddings, defaults to False.
<br />

24
auto_transformers.py

@ -17,6 +17,7 @@ import os
import torch
import shutil
from pathlib import Path
from typing import Union
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
@ -26,13 +27,15 @@ from towhee import register
import warnings
import logging
from transformers import logging as t_logging
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
# @accelerate
@ -76,12 +79,14 @@ class AutoTransformers(NNOperator):
checkpoint_path: str = None,
tokenizer: object = None,
device: str = None,
return_sentence_emb: bool = True
):
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.return_sentence_emb = return_sentence_emb
if self.model_name:
model_list = self.supported_model_names()
@ -102,7 +107,11 @@ class AutoTransformers(NNOperator):
log.warning('The operator is initialized without specified model.')
pass
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
def __call__(self, data: Union[str, list]) -> numpy.ndarray:
if isinstance(data, str):
txt = [data]
else:
txt = data
try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
except Exception as e:
@ -113,11 +122,14 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
if return_sentence_emb:
if self.return_sentence_emb:
outs = self.post_proc(outs, inputs)
features = outs.squeeze(0)
vec = features.cpu().detach().numpy()
return vec
features = outs.cpu().detach().numpy()
if isinstance(data, str):
features = features.squeeze(0)
else:
features = list(features)
return features
@property
def _model(self):

Loading…
Cancel
Save