logo
Browse Source

Modify output format

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
0485da7399
  1. 16
      README.md
  2. 24
      auto_transformers.py

16
README.md

@ -329,6 +329,13 @@ If None, the operator will use default tokenizer by `model_name` from Huggingfac
<br /> <br />
***return_sentence_emb***: *bool*
The flag to output a sentence embedding for each text, defaults to True.
If False, the operator returns token embeddings for each text.
<br />
## Interface ## Interface
The operator takes a piece of text in string as input. The operator takes a piece of text in string as input.
@ -339,9 +346,11 @@ and then return text embedding in ndarray.
**Parameters:** **Parameters:**
***txt***: *str*
***data***: *Union[str, list]*
​ The text in string.
​ The text in string or a list of texts.
If data is string, the operator returns embedding(s) in ndarray.
If data is a list, the operator returns embedding(s) in a list.
**Returns**: **Returns**:
@ -350,9 +359,6 @@ and then return text embedding in ndarray.
​ The text embedding extracted by model. ​ The text embedding extracted by model.
***return_sentence_emb***: *bool*
The flag to output sentence embedding instead of token embeddings, defaults to False.
<br /> <br />

24
auto_transformers.py

@ -17,6 +17,7 @@ import os
import torch import torch
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Union
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
@ -26,13 +27,15 @@ from towhee import register
import warnings import warnings
import logging import logging
from transformers import logging as t_logging
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
log = logging.getLogger('run_op') log = logging.getLogger('run_op')
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
# @accelerate # @accelerate
@ -76,12 +79,14 @@ class AutoTransformers(NNOperator):
checkpoint_path: str = None, checkpoint_path: str = None,
tokenizer: object = None, tokenizer: object = None,
device: str = None, device: str = None,
return_sentence_emb: bool = True
): ):
super().__init__() super().__init__()
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
self.return_sentence_emb = return_sentence_emb
if self.model_name: if self.model_name:
model_list = self.supported_model_names() model_list = self.supported_model_names()
@ -102,7 +107,11 @@ class AutoTransformers(NNOperator):
log.warning('The operator is initialized without specified model.') log.warning('The operator is initialized without specified model.')
pass pass
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
def __call__(self, data: Union[str, list]) -> numpy.ndarray:
if isinstance(data, str):
txt = [data]
else:
txt = data
try: try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
except Exception as e: except Exception as e:
@ -113,11 +122,14 @@ class AutoTransformers(NNOperator):
except Exception as e: except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
if return_sentence_emb:
if self.return_sentence_emb:
outs = self.post_proc(outs, inputs) outs = self.post_proc(outs, inputs)
features = outs.squeeze(0)
vec = features.cpu().detach().numpy()
return vec
features = outs.cpu().detach().numpy()
if isinstance(data, str):
features = features.squeeze(0)
else:
features = list(features)
return features
@property @property
def _model(self): def _model(self):

Loading…
Cancel
Save