# logo
# You can not select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# Readme
# Files and versions
#
# 355 lines
# 13 KiB
#
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import os
import requests
import torch
import shutil
from pathlib import Path
from typing import Union
from collections import OrderedDict
# 2 years ago
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
from towhee.operator import NNOperator
from towhee import register
try:
    from towhee import accelerate
except ImportError:
    # Older towhee releases do not provide `accelerate`; fall back to a
    # pass-through decorator so `@accelerate` still works.
    def accelerate(func):
        """Return ``func`` unchanged (no-op replacement for towhee.accelerate)."""
        return func
import warnings
import logging
from transformers import logging as t_logging
# Keep the operator quiet at import time: silence Python warnings, ask
# TensorFlow's C++ backend to hide INFO/WARNING output, and restrict both the
# 'run_op' logger and transformers' logger to errors only.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
log = logging.getLogger('run_op')
log.setLevel(logging.ERROR)
t_logging.set_verbosity_error()
def create_model(model_name, checkpoint_path, device):
    """Build a transformers backbone ready for inference.

    Args:
        model_name: Hugging Face model name. Used directly when ``checkpoint_path``
            is empty, and as the architecture source when ``checkpoint_path`` is a
            raw PyTorch weights file.
        checkpoint_path: Optional. Either a directory containing ``config.json``
            (loaded via ``AutoModel.from_pretrained``) or a file readable by
            ``torch.load`` holding a state dict (or a whole pickled module).
        device: Device the model is moved to.

    Returns:
        The model on ``device``, in eval mode, with its ``pooler`` (if any) removed.
        If loading weights from a file fails, the error is logged and the model
        keeps its freshly initialised weights.
    """
    load_torch_weights = False
    if checkpoint_path:
        if os.path.isdir(checkpoint_path) and \
                os.path.exists(os.path.join(checkpoint_path, 'config.json')):
            model = AutoModel.from_pretrained(checkpoint_path)
        else:
            # AutoConfig.from_pretrained only yields a config object; build the
            # (randomly initialised) architecture from it and load weights below.
            model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
            load_torch_weights = True
    else:
        model = AutoModel.from_pretrained(model_name)
    model = model.to(device)
    # Only token embeddings are used downstream; the pooler is dropped.
    if hasattr(model, 'pooler') and model.pooler:
        model.pooler = None
    if load_torch_weights:
        try:
            state_dict = torch.load(checkpoint_path, map_location=device)
            # save_model(model_type='pytorch') pickles the whole module; accept it too.
            if isinstance(state_dict, torch.nn.Module):
                state_dict = state_dict.state_dict()
            # The pooler was removed above, so its weights would be unexpected keys.
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('pooler.')}
            model.load_state_dict(state_dict)
        except Exception as e:
            log.error(f'Failed to load weights from {checkpoint_path}: {e}')
    model.eval()
    return model
@accelerate
class Model:
    """Callable wrapper around the backbone built by ``create_model``.

    Calling it moves every positional and keyword argument onto ``self.device``,
    runs the backbone with ``return_dict=True`` and returns ``last_hidden_state``.
    """

    def __init__(self, model_name, checkpoint_path, device):
        self.device = device
        self.model = create_model(model_name, checkpoint_path, device)

    def __call__(self, *args, **kwargs):
        target = self.device
        positional = [arg.to(target) for arg in args]
        named = {key: value.to(target) for key, value in kwargs.items()}
        result = self.model(*positional, **named, return_dict=True)
        return result['last_hidden_state']
@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
    """
    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.

    Each input text is tokenized and run through the backbone. The token embeddings
    are then mean-pooled over the attention mask into one sentence vector.

    Args:
        model_name (`str`):
            The model name to load a pretrained model from transformers. Names found in
            the module-level ``s_list`` are loaded from the ``sentence-transformers/`` namespace.
        checkpoint_path (`str`):
            The local checkpoint path.
        tokenizer (`object`):
            The tokenizer to tokenize input text as model inputs.
        device (`str`):
            The device to run on. Defaults to ``'cuda'`` when available, otherwise ``'cpu'``.
    """

    def __init__(self,
                 model_name: str = None,
                 checkpoint_path: str = None,
                 tokenizer: object = None,
                 device: str = None,
                 ):
        super().__init__()
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Short sentence-transformers names are expanded to their hub namespace.
        if model_name in s_list:
            self.model_name = 'sentence-transformers/' + model_name
        else:
            self.model_name = model_name
        self.checkpoint_path = checkpoint_path
        if self.model_name:
            self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # __call__ pads batches, which needs a pad token; add one if missing.
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = '[PAD]'
        else:
            # Neither self.model nor self.tokenizer is created here, so __call__ cannot run.
            log.warning('The operator is initialized without specified model.')

    def __call__(self, data: Union[str, list]) -> Union[numpy.ndarray, list]:
        """Embed one text or a batch of texts.

        Args:
            data: A single string or a list of strings.

        Returns:
            For a string, one numpy vector. For a list, a list of numpy vectors,
            one per input text.
        """
        txt = [data] if isinstance(data, str) else data
        try:
            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
        except Exception as e:
            log.error(f'Fail to tokenize inputs: {e}')
            raise
        try:
            # Inference only: skip building an autograd graph on every call.
            with torch.no_grad():
                outs = self.model(**inputs).to('cpu')
        except Exception:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise
        outs = self.post_proc(outs, inputs)
        features = outs.detach().numpy()
        if isinstance(data, str):
            features = features.squeeze(0)
        else:
            features = list(features)
        return features

    @property
    def _model(self):
        """The underlying transformers model held by the ``Model`` wrapper."""
        return self.model.model

    @property
    def model_config(self):
        """Config for ``self.model_name``, fetched via ``AutoConfig`` on every access."""
        return AutoConfig.from_pretrained(self.model_name)

    @property
    def onnx_config(self):
        """Input/output names and dynamic axes used for ONNX export.

        Uses transformers' ONNX feature registry when it supports the model. Otherwise
        it falls back to (batch_size, sequence_length) dynamic axes for every tokenizer
        input and for ``last_hidden_state``.
        """
        from transformers.onnx.features import FeaturesManager
        try:
            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            old_config = model_onnx_config(self.model_config)
            onnx_config = {
                'inputs': dict(old_config.inputs),
                'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
            }
        except Exception:
            input_dict = {k: {0: 'batch_size', 1: 'sequence_length'}
                          for k in self.tokenizer.model_input_names}
            onnx_config = {
                'inputs': input_dict,
                'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
            }
        return onnx_config

    def post_proc(self, token_embeddings, inputs):
        """Mean-pool token embeddings over the attention mask.

        Padded positions (mask 0) are excluded. The clamp avoids division by zero
        when a row's mask is all zeros.
        """
        attention_mask = inputs['attention_mask']
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sentence_embs = torch.sum(
            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        """Export the backbone as ``pytorch``, ``torchscript`` or ``onnx``.

        Args:
            model_type: One of ``'pytorch'``, ``'torchscript'``, ``'onnx'``.
            output_file: Destination path without extension. ``'default'`` means
                ``<this module's dir>/saved/<model_type>/<model name with '/' -> '-'>``.
                The extension (``.pt`` or ``.onnx``) is appended.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: If ``model_type`` is not supported.
            RuntimeError: If the TorchScript conversion or save fails.
        """
        if output_file == 'default':
            output_file = str(Path(__file__).parent)
            output_file = os.path.join(output_file, 'saved', model_type)
            os.makedirs(output_file, exist_ok=True)
            name = self.model_name.replace('/', '-')
            output_file = os.path.join(output_file, name)
        if model_type in ['pytorch', 'torchscript']:
            output_file = output_file + '.pt'
        elif model_type == 'onnx':
            output_file = output_file + '.onnx'
        else:
            # TODO: tensorrt export.
            raise AttributeError('Unsupported model_type.')
        dummy_input = 'test sentence'
        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')
        if model_type == 'pytorch':
            torch.save(self._model, output_file)
        elif model_type == 'torchscript':
            # Tracing needs example inputs on the same device as the model.
            example_inputs = tuple(t.to(self.device) for t in inputs.values())
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    # Fall back to tracing when scripting fails.
                    jit_model = torch.jit.trace(self._model, example_inputs, strict=False)
                torch.jit.save(jit_model, output_file)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # model_type == 'onnx'
            # Reading onnx_config re-queries transformers each time; read it once.
            onnx_config = self.onnx_config
            dynamic_axes = {**onnx_config['inputs'], **onnx_config['outputs']}
            if hasattr(self._model.config, 'use_cache'):
                self._model.config.use_cache = False
            try:
                torch.onnx.export(
                    self._model.to('cpu'),
                    tuple(inputs.values()),
                    output_file,
                    input_names=list(onnx_config['inputs'].keys()),
                    output_names=list(onnx_config['outputs'].keys()),
                    dynamic_axes=dynamic_axes,
                    opset_version=torch.onnx.constant_folding_opset_versions[-1] if hasattr(
                        torch.onnx, 'constant_folding_opset_versions') else 14,
                    do_constant_folding=True,
                )
            finally:
                # Export runs on CPU (the dummy inputs are CPU tensors). Move the model
                # back so later __call__s still run on self.device.
                self._model.to(self.device)
        return Path(output_file).resolve()

    @property
    def supported_formats(self):
        """Export formats advertised by this operator."""
        return ['onnx']

    @staticmethod
    def supported_model_names(format: str = None):
        """List known-good model names, optionally filtered for an export format.

        Args:
            format: ``None`` for all models, or one of ``'pytorch'``,
                ``'torchscript'``, ``'onnx'``.

        Returns:
            A sorted list of model names. An empty list (with an error logged)
            for an unknown format.
        """
        add_models = [
            'bert-base-uncased',
            'bert-large-uncased',
            'bert-large-uncased-whole-word-masking',
            'distilbert-base-uncased',
            'facebook/bart-large',
            'gpt2-xl',
            'microsoft/deberta-xlarge',
            'microsoft/deberta-xlarge-mnli',
        ]
        full_list = sorted(s_list + add_models)
        if format is None:
            return full_list
        # Models known not to work for each export format.
        # TODO: 'tensorrt'.
        excluded_by_format = {
            'pytorch': [],
            'torchscript': [],
            'onnx': ['gpt2-xl'],
        }
        if format not in excluded_by_format:
            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".')
            return []
        to_remove = set(excluded_by_format[format])
        assert to_remove.issubset(set(full_list))
        return [name for name in full_list if name not in to_remove]

    def train(self, training_config=None,
              train_dataset=None,
              eval_dataset=None,
              resume_checkpoint_path=None, **kwargs):
        """Fine-tune the backbone with a language-modeling head via the HF trainer helpers.

        ``training_config``, ``train_dataset``, ``eval_dataset`` and
        ``resume_checkpoint_path`` are accepted for interface compatibility but are not
        used here.

        Keyword Args:
            task: ``'mlm'`` (default) or ``'clm'``.
            data_args: Forwarded to the trainer helper.
            training_args: Forwarded to the trainer helper.
            prepare_model_weights_f: Optional ``f(backbone, model_with_head, **kwargs)``
                that returns the head model to train.
            Remaining kwargs are forwarded to ``prepare_model_weights_f`` and the helper.

        Raises:
            ValueError: If ``task`` is not ``'mlm'`` or ``'clm'``.
        """
        from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
        from .train_clm_with_hf_trainer import train_clm_with_hf_trainer

        task = kwargs.pop('task', None)
        data_args = kwargs.pop('data_args', None)
        training_args = kwargs.pop('training_args', None)
        prepare_model_weights_f = kwargs.pop('prepare_model_weights_f', None)
        if task is None:
            task = 'mlm'
        task_handlers = {
            'mlm': (AutoModelForMaskedLM, train_mlm_with_hf_trainer),
            'clm': (AutoModelForCausalLM, train_clm_with_hf_trainer),
        }
        if task not in task_handlers:
            raise ValueError(f'Unsupported task "{task}". Expected "mlm" or "clm".')
        head_cls, train_fn = task_handlers[task]
        model_with_head = head_cls.from_pretrained(self.model_name)
        if prepare_model_weights_f is not None:
            model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs)
        train_fn(
            model_with_head,
            self.tokenizer,
            data_args,
            training_args,
            **kwargs
        )
# Short sentence-transformers model names. AutoTransformers expands any of these
# to 'sentence-transformers/<name>'; order is preserved for supported_model_names.
s_list = [
    # paraphrase family
    'paraphrase-MiniLM-L3-v2',
    'paraphrase-MiniLM-L6-v2',
    'paraphrase-MiniLM-L12-v2',
    'paraphrase-distilroberta-base-v2',
    'paraphrase-TinyBERT-L6-v2',
    'paraphrase-mpnet-base-v2',
    'paraphrase-albert-small-v2',
    'paraphrase-multilingual-mpnet-base-v2',
    'paraphrase-multilingual-MiniLM-L12-v2',
    # multilingual distiluse
    'distiluse-base-multilingual-cased-v1',
    'distiluse-base-multilingual-cased-v2',
    # general-purpose "all-*"
    'all-distilroberta-v1',
    'all-MiniLM-L6-v1',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v1',
    'all-MiniLM-L12-v2',
    'all-mpnet-base-v1',
    'all-mpnet-base-v2',
    'all-roberta-large-v1',
    # question answering
    'multi-qa-MiniLM-L6-dot-v1',
    'multi-qa-MiniLM-L6-cos-v1',
    'multi-qa-distilbert-dot-v1',
    'multi-qa-distilbert-cos-v1',
    'multi-qa-mpnet-base-dot-v1',
    'multi-qa-mpnet-base-cos-v1',
    # MS MARCO retrieval and older NLI models
    'msmarco-distilbert-dot-v5',
    'msmarco-bert-base-dot-v5',
    'msmarco-distilbert-base-tas-b',
    'bert-base-nli-mean-tokens',
    'msmarco-distilbert-base-v4',
]