@@ -17,6 +17,7 @@ import os
 import torch
 import shutil
 from pathlib import Path
+from typing import Union
 
 from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
 
@@ -26,13 +27,15 @@ from towhee import register
 import warnings
 import logging
 from transformers import logging as t_logging
+
 from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
 from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
 
 log = logging.getLogger('run_op')
 
 warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
 
 
 # @accelerate
@@ -76,12 +79,14 @@ class AutoTransformers(NNOperator):
                  checkpoint_path: str = None,
                  tokenizer: object = None,
                  device: str = None,
+                 return_sentence_emb: bool = True
                  ):
         super().__init__()
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = device
         self.model_name = model_name
+        self.return_sentence_emb = return_sentence_emb
 
         if self.model_name:
             model_list = self.supported_model_names()
@@ -102,7 +107,11 @@ class AutoTransformers(NNOperator):
             log.warning('The operator is initialized without specified model.')
             pass
 
-    def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
+    def __call__(self, data: Union[str, list]) -> numpy.ndarray:
+        if isinstance(data, str):
+            txt = [data]
+        else:
+            txt = data
         try:
             inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
         except Exception as e:
@@ -113,11 +122,14 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
-        if return_sentence_emb:
+        if self.return_sentence_emb:
             outs = self.post_proc(outs, inputs)
-        features = outs.squeeze(0)
-        vec = features.cpu().detach().numpy()
-        return vec
+        features = outs.cpu().detach().numpy()
+        if isinstance(data, str):
+            features = features.squeeze(0)
+        else:
+            features = list(features)
+        return features
 
     @property
     def _model(self):