transformers
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
334 lines
12 KiB
334 lines
12 KiB
# Copyright 2021 Zilliz. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import numpy
|
|
import os
|
|
import requests
|
|
import torch
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Union
|
|
from collections import OrderedDict
|
|
|
|
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForCausalLM
|
|
|
|
from towhee.operator import NNOperator
|
|
from towhee import register
|
|
# from towhee.dc2 import accelerate
|
|
|
|
import warnings
|
|
import logging
|
|
from transformers import logging as t_logging
|
|
|
|
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
|
|
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
|
|
|
|
# Quiet third-party output as soon as this module is imported.
# NOTE: warnings.filterwarnings('ignore') affects the whole Python process,
# not just this operator.
warnings.filterwarnings('ignore')
# Suppress TensorFlow C++ INFO/WARNING log lines (TF_CPP_MIN_LOG_LEVEL=2).
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
# Only errors from the transformers library logger.
t_logging.set_verbosity_error()

# Operator-level logger; anything below ERROR (e.g. log.warning) is dropped.
log = logging.getLogger('run_op')
log.setLevel(logging.ERROR)
|
|
|
|
|
|
def create_model(model_name, checkpoint_path, device):
    """Load a transformers base model for embedding extraction.

    Args:
        model_name: Model id or local path passed to ``AutoModel.from_pretrained``.
        checkpoint_path: Optional path to a state dict that overrides the
            pretrained weights. Falsy values skip checkpoint loading.
        device: Device the model (and the checkpoint tensors) are moved to.

    Returns:
        The model, switched to eval mode. If the checkpoint cannot be loaded,
        the failure (including its cause) is logged and the pretrained weights
        are kept.
    """
    model = AutoModel.from_pretrained(model_name).to(device)
    # The pooler output is never used: callers read `last_hidden_state` only,
    # so the pooler module is removed.
    if hasattr(model, 'pooler') and model.pooler:
        model.pooler = None
    if checkpoint_path:
        try:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        except Exception as e:
            # Best effort: keep the pretrained weights, but report *why* the
            # checkpoint was rejected instead of silently discarding the error.
            log.error(f'Fail to load weights from {checkpoint_path}: {e}')
    model.eval()
    return model
|
|
|
|
|
|
# @accelerate
|
|
class Model:
    """Callable wrapper that returns a transformers model's token-level hidden states.

    Args:
        model_name: Name or path forwarded to ``create_model``.
        checkpoint_path: Optional state-dict path forwarded to ``create_model``.
        device: Device forwarded to ``create_model``.
    """

    def __init__(self, model_name, checkpoint_path, device):
        self.model = create_model(model_name, checkpoint_path, device)

    def __call__(self, *args, **kwargs):
        """Run the wrapped model with ``return_dict=True`` and return ``last_hidden_state``."""
        output = self.model(*args, **kwargs, return_dict=True)
        hidden = output['last_hidden_state']
        return hidden
|
|
|
|
|
|
@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
    """
    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.

    Sentence embeddings are the attention-mask-weighted mean of the model's
    ``last_hidden_state``.

    Args:
        model_name (`str`):
            The model name to load a pretrained model from transformers. Short
            sentence-transformers names listed in ``s_list`` are expanded to
            ``sentence-transformers/<name>``.
        checkpoint_path (`str`):
            The local checkpoint (state dict) path used to override the pretrained weights.
        tokenizer (`object`):
            The tokenizer to tokenize input text as model inputs. Defaults to
            ``AutoTokenizer.from_pretrained(model_name)``.
        device (`str`):
            Device to run on. Defaults to ``'cuda'`` when available, otherwise ``'cpu'``.
        norm (`bool`):
            Whether to L2-normalize the output embeddings.
    """

    def __init__(self,
                 model_name: str = None,
                 checkpoint_path: str = None,
                 tokenizer: object = None,
                 device: str = None,
                 norm: bool = False
                 ):
        super().__init__()
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if model_name in s_list:
            self.model_name = 'sentence-transformers/' + model_name
        else:
            self.model_name = model_name
        self.norm = norm
        self.checkpoint_path = checkpoint_path

        if self.model_name:
            self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if not self.tokenizer.pad_token:
                # Batched inputs need padding. NOTE(review): '[PAD]' may not be in every
                # vocabulary; padded positions are masked out in post_proc, but confirm the
                # resulting pad id is valid for such tokenizers.
                self.tokenizer.pad_token = '[PAD]'
        else:
            # NOTE: `log` is set to ERROR level, so this warning is not emitted by default.
            log.warning('The operator is initialized without specified model.')

    def __call__(self, data: Union[str, list]) -> numpy.ndarray:
        """Embed one sentence or a batch of sentences.

        Args:
            data: A single string or a list of strings.

        Returns:
            For a string, a 1-D numpy array. For a list, a list of 1-D numpy
            arrays, one per input sentence (an empty list for an empty input).
        """
        single = isinstance(data, str)
        if single:
            txt = [data]
        else:
            txt = data
            if isinstance(txt, list) and not txt:
                # Nothing to embed; avoid running the tokenizer/model on an empty batch.
                return []
        try:
            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device)
        except Exception as e:
            log.error(f'Fail to tokenize inputs: {e}')
            raise
        # Inference only: the result is detached and converted to numpy below,
        # so autograd bookkeeping would be wasted memory.
        with torch.no_grad():
            try:
                outs = self.model(**inputs)
            except Exception:
                log.error(f'Invalid input for the model: {self.model_name}')
                raise
            outs = self.post_proc(outs, inputs)
            if self.norm:
                outs = torch.nn.functional.normalize(outs, dim=1)
        features = outs.cpu().detach().numpy()
        if single:
            return features.squeeze(0)
        return list(features)

    @property
    def _model(self):
        """The raw transformers model held by the ``Model`` wrapper."""
        return self.model.model

    @property
    def model_config(self):
        """A freshly loaded ``AutoConfig`` for ``self.model_name`` (re-read on every access)."""
        return AutoConfig.from_pretrained(self.model_name)

    @property
    def onnx_config(self):
        """Input/output names and dynamic axes used for ONNX export.

        Uses the transformers ONNX feature registry when it is importable and supports
        the model; otherwise falls back to the tokenizer's input names with dynamic
        batch/sequence axes.
        """
        try:
            # Imported inside the try so that transformers installs without the
            # legacy `transformers.onnx` package use the generic fallback below.
            from transformers.onnx.features import FeaturesManager
            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            old_config = model_onnx_config(self.model_config)
            onnx_config = {
                'inputs': dict(old_config.inputs),
                'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']},
            }
        except Exception:
            dynamic = {0: 'batch_size', 1: 'sequence_length'}
            onnx_config = {
                'inputs': {k: dict(dynamic) for k in self.tokenizer.model_input_names},
                'outputs': {'last_hidden_state': dict(dynamic)},
            }
        return onnx_config

    def post_proc(self, token_embeddings, inputs):
        """Mean-pool token embeddings into one vector per sentence, ignoring padding.

        Args:
            token_embeddings: Token-level hidden states; the mask expansion below
                requires shape (batch, seq_len, hidden).
            inputs: Tokenizer output containing ``attention_mask`` of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, hidden).
        """
        token_embeddings = token_embeddings.to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # The clamp avoids dividing by zero for rows whose mask is all zeros.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        """Export the underlying transformers model to disk.

        Args:
            model_type: One of ``'pytorch'``, ``'torchscript'`` or ``'onnx'``.
            output_file: Output path *without* extension; the extension is derived
                from ``model_type``. ``'default'`` writes under
                ``<this directory>/saved/<model_type>/`` using the model name with
                ``/`` replaced by ``-``.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: If ``model_type`` is not supported.
            RuntimeError: If TorchScript conversion or saving fails.
        """
        suffixes = {'pytorch': '.pt', 'torchscript': '.pt', 'onnx': '.onnx'}
        # Validate first so an unsupported type does not leave an empty directory behind.
        if model_type not in suffixes:
            raise AttributeError('Unsupported model_type.')

        if output_file == 'default':
            output_dir = os.path.join(str(Path(__file__).parent), 'saved', model_type)
            os.makedirs(output_dir, exist_ok=True)
            output_file = os.path.join(output_dir, self.model_name.replace('/', '-'))
        output_file = output_file + suffixes[model_type]

        dummy_input = 'test sentence'
        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
        # NOTE(review): torchscript/onnx exports below pass the dummy inputs positionally in
        # tokenizer order, which may differ from the model's forward() argument order
        # (e.g. token_type_ids vs attention_mask) — confirm per architecture.
        if model_type == 'pytorch':
            torch.save(self._model, output_file)
        elif model_type == 'torchscript':
            example_inputs = list(inputs.values())
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    # Fall back to tracing when scripting fails.
                    jit_model = torch.jit.trace(self._model, example_inputs, strict=False)
                torch.jit.save(jit_model, output_file)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # 'onnx'
            # Read the (expensive) property once instead of once per use.
            onnx_config = self.onnx_config
            dynamic_axes = {**onnx_config['inputs'], **onnx_config['outputs']}
            config = self._model.config
            # Presumably disabled so cache outputs are kept out of the exported graph;
            # the original value is restored so export does not alter the live model.
            restore_use_cache = hasattr(config, 'use_cache')
            if restore_use_cache:
                original_use_cache = config.use_cache
                config.use_cache = False
            try:
                torch.onnx.export(
                    self._model,
                    tuple(inputs.values()),
                    output_file,
                    input_names=list(onnx_config['inputs'].keys()),
                    output_names=list(onnx_config['outputs'].keys()),
                    dynamic_axes=dynamic_axes,
                    opset_version=torch.onnx.constant_folding_opset_versions[-1] if hasattr(
                        torch.onnx, 'constant_folding_opset_versions') else 14,
                    do_constant_folding=True,
                )
            finally:
                if restore_use_cache:
                    config.use_cache = original_use_cache
        # todo: elif model_type == 'tensorrt':
        return Path(output_file).resolve()

    @property
    def supported_formats(self):
        """Export formats advertised by this operator."""
        return ['onnx']

    @staticmethod
    def supported_model_names(format: str = None):
        """Return the sorted list of model names supported for ``format``.

        Args:
            format: ``None`` for all models, or one of ``'pytorch'``,
                ``'torchscript'``, ``'onnx'``.

        Raises:
            ValueError: If ``format`` is not one of the values above.
        """
        add_models = [
            'bert-base-uncased',
            'bert-large-uncased',
            'bert-large-uncased-whole-word-masking',
            'distilbert-base-uncased',
            'facebook/bart-large',
            'gpt2-xl',
            'microsoft/deberta-xlarge',
            'microsoft/deberta-xlarge-mnli',
        ]
        full_list = sorted(s_list + add_models)
        # Models excluded from each export format.
        excluded = {
            'pytorch': (),
            'torchscript': (),
            'onnx': ('gpt2-xl',),
        }
        # todo: add 'tensorrt'
        if format is None:
            return full_list
        if format not in excluded:
            msg = f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".'
            log.error(msg)
            raise ValueError(msg)
        # Filtering (rather than set difference) keeps the result sorted and deterministic.
        return [name for name in full_list if name not in excluded[format]]

    def train(self, training_config=None,
              train_dataset=None,
              eval_dataset=None,
              resume_checkpoint_path=None, **kwargs):
        """Fine-tune the model with a masked-LM or causal-LM head.

        ``training_config``, ``train_dataset``, ``eval_dataset`` and
        ``resume_checkpoint_path`` are accepted but not used by this implementation.

        Keyword Args:
            task: ``'mlm'`` (default when omitted/None) or ``'clm'``.
            data_args: Forwarded to the trainer function.
            training_args: Forwarded to the trainer function.
            prepare_model_weights_f: Optional ``f(base_model, model_with_head, **kwargs)``
                returning the head model to train.
            Remaining kwargs are forwarded to ``prepare_model_weights_f`` and to the
            trainer function.

        Raises:
            ValueError: If ``task`` is not ``'mlm'``, ``'clm'`` or None.
        """
        task = kwargs.pop('task', None)
        data_args = kwargs.pop('data_args', None)
        training_args = kwargs.pop('training_args', None)
        prepare_model_weights_f = kwargs.pop('prepare_model_weights_f', None)

        if task == 'mlm' or task is None:
            head_cls, train_fn = AutoModelForMaskedLM, train_mlm_with_hf_trainer
        elif task == 'clm':
            head_cls, train_fn = AutoModelForCausalLM, train_clm_with_hf_trainer
        else:
            # Previously an unknown task silently did nothing.
            raise ValueError(f'Unsupported task "{task}". Supported tasks: "mlm", "clm".')

        model_with_head = head_cls.from_pretrained(self.model_name)
        if prepare_model_weights_f is not None:
            model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs)

        train_fn(
            model_with_head,
            self.tokenizer,
            data_args,
            training_args,
            **kwargs
        )
|
|
|
|
# Short names of sentence-transformers models. AutoTransformers expands any of
# these to "sentence-transformers/<name>" before loading from the hub.
# Kept as a list: it is concatenated with another list elsewhere in this file.
s_list = [
    # paraphrase models
    'paraphrase-MiniLM-L3-v2',
    'paraphrase-MiniLM-L6-v2',
    'paraphrase-MiniLM-L12-v2',
    'paraphrase-distilroberta-base-v2',
    'paraphrase-TinyBERT-L6-v2',
    'paraphrase-mpnet-base-v2',
    'paraphrase-albert-small-v2',
    'paraphrase-multilingual-mpnet-base-v2',
    'paraphrase-multilingual-MiniLM-L12-v2',
    # multilingual distiluse models
    'distiluse-base-multilingual-cased-v1',
    'distiluse-base-multilingual-cased-v2',
    # "all-*" general-purpose models
    'all-distilroberta-v1',
    'all-MiniLM-L6-v1',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v1',
    'all-MiniLM-L12-v2',
    'all-mpnet-base-v1',
    'all-mpnet-base-v2',
    'all-roberta-large-v1',
    # multi-qa models
    'multi-qa-MiniLM-L6-dot-v1',
    'multi-qa-MiniLM-L6-cos-v1',
    'multi-qa-distilbert-dot-v1',
    'multi-qa-distilbert-cos-v1',
    'multi-qa-mpnet-base-dot-v1',
    'multi-qa-mpnet-base-cos-v1',
    # msmarco / legacy models
    'msmarco-distilbert-dot-v5',
    'msmarco-bert-base-dot-v5',
    'msmarco-distilbert-base-tas-b',
    'bert-base-nli-mean-tokens',
    'msmarco-distilbert-base-v4',
]
|