logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

280 lines
9.9 KiB

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy
from typing import Union, List
from pathlib import Path
import torch
from sentence_transformers import SentenceTransformer
from towhee.operator import NNOperator
try:
    from towhee import accelerate
except ImportError:
    # This towhee install does not provide ``accelerate``; fall back to a
    # no-op decorator so ``@accelerate`` below still works. Only ImportError is
    # caught so that interrupts and genuine errors are not silently swallowed.
    def accelerate(func):
        """No-op stand-in for ``towhee.accelerate``: return ``func`` unchanged."""
        return func
import os
import warnings
# Quiet third-party noise at import time: ignore every Python warning via the
# global warnings filter, and let only ERROR-and-above records through for
# both the sentence_transformers library logger and this operator's logger.
warnings.filterwarnings('ignore')
log = logging.getLogger('op_s_transformers')
for _quiet_logger in (logging.getLogger('sentence_transformers'), log):
    _quiet_logger.setLevel(logging.ERROR)
del _quiet_logger
class ConvertModel(torch.nn.Module):
    """
    Adapter that lets a SentenceTransformer be called with positional inputs.

    ONNX export feeds the model a tuple of positional tensors, while the
    wrapped model is called with a single feature dict. This wrapper rebuilds
    that dict from ``input_names`` and returns only the
    ``'sentence_embedding'`` entry of the model output.

    Args:
        model: The wrapped model; it is called with a feature dict and must
            return a mapping containing ``'sentence_embedding'``.
    """

    def __init__(self, model):
        super().__init__()
        self.net = model
        try:
            self.input_names = self.net.tokenizer.model_input_names
        except AttributeError:
            # No tokenizer attribute exposed: infer the input names from the
            # keys of a sample tokenization instead.
            self.input_names = list(self.net.tokenize(['test']).keys())

    def forward(self, *args, **kwargs):
        """
        Run the wrapped model and return its sentence embedding.

        Args:
            *args: Inputs in the order of ``self.input_names``.
            **kwargs: Inputs keyed by name; used as-is when no positional
                inputs are given.

        Returns:
            The ``'sentence_embedding'`` entry of the wrapped model's output.

        Raises:
            ValueError: If positional and keyword inputs are mixed, or the
                number of positional inputs differs from ``input_names``.
        """
        if args:
            # Explicit checks instead of ``assert`` so validation survives
            # ``python -O``.
            if kwargs:
                raise ValueError(
                    'Pass inputs either positionally or as keyword arguments, not both.'
                )
            if len(args) != len(self.input_names):
                raise ValueError(
                    f'Expected {len(self.input_names)} positional inputs '
                    f'({", ".join(self.input_names)}), got {len(args)}.'
                )
            kwargs = dict(zip(self.input_names, args))
        outs = self.net(kwargs)
        return outs['sentence_embedding']
@accelerate
class Model:
    """
    Inference wrapper around a ``SentenceTransformer``.

    Args:
        model_name: Model name or path passed to ``SentenceTransformer``.
        device: Torch device string the model is loaded on.
    """

    def __init__(self, model_name, device):
        self.device = device
        self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
        self.model.eval()

    def __call__(self, **features):
        """
        Run a forward pass and return the model's ``'sentence_embedding'``.

        Args:
            **features: Tokenizer outputs keyed by input name.

        Returns:
            The ``'sentence_embedding'`` entry of the model output.
        """
        # The module is called directly here (not via ``encode``), so tensor
        # inputs must be placed on the model's device explicitly; otherwise a
        # CUDA-loaded model receives CPU tensors.
        features = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        # No autograd graph is needed: the operator detaches the result, and
        # training works on the underlying SentenceTransformer directly.
        with torch.no_grad():
            outs = self.model(features)
        return outs['sentence_embedding']
class STransformers(NNOperator):
    """
    Operator using pretrained Sentence Transformers.

    Args:
        model_name: Name of the sentence-transformers model to load. When
            omitted, no model is loaded (a warning is logged) and the operator
            cannot embed text until ``self.model`` is set.
        device: Torch device string. Defaults to ``'cuda'`` when CUDA is
            available, otherwise ``'cpu'``.
    """

    def __init__(self, model_name: str = None, device: str = None):
        self.model_name = model_name
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        if self.model_name:
            self.model = Model(model_name=self.model_name, device=self.device)
        else:
            log.warning('The operator is initialized without specified model.')

    def __call__(self, txt: Union[List[str], str]):
        """
        Embed one sentence or a list of sentences.

        Args:
            txt: A single string or a list of strings.

        Returns:
            For a single string, its embedding with the leading batch axis
            squeezed away; for a list, a list of per-sentence numpy arrays in
            input order. (Presumably the model output is (batch, dim) - confirm.)
        """
        single = isinstance(txt, str)
        sentences = [txt] if single else txt
        inputs = self.tokenize(sentences)
        # Inputs are passed to ``self.model`` exactly as produced by
        # ``tokenize``; this method does not move them to ``self.device``.
        embs = self.model(**inputs).detach().cpu().numpy()
        if single:
            return embs.squeeze(0)
        return list(embs)

    @property
    def supported_formats(self):
        """Export formats this operator advertises."""
        return ['onnx']

    def tokenize(self, x):
        """
        Tokenize a list of sentences into PyTorch model inputs.

        Uses the loaded SentenceTransformer's own ``tokenize`` when possible.
        If that fails (for any reason), falls back to a Hugging Face
        ``AutoTokenizer`` for ``model_name``, padding and truncating to
        ``max_seq_length``.
        """
        try:
            outs = self._model.tokenize(x)
        except Exception:
            tokenizer = self._get_fallback_tokenizer()
            outs = tokenizer(
                x,
                padding=True,
                truncation='longest_first',
                max_length=self.max_seq_length,
                return_tensors='pt',
            )
        return outs

    def _get_fallback_tokenizer(self):
        """Load the fallback ``AutoTokenizer`` once per ``model_name`` and cache it."""
        cached = getattr(self, '_fallback_tokenizer', None)
        if cached is not None and cached[0] == self.model_name:
            return cached[1]
        from transformers import AutoTokenizer
        try:
            # Try the ``sentence-transformers/`` hub namespace first, then the
            # bare name.
            tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self._fallback_tokenizer = (self.model_name, tokenizer)
        return tokenizer

    @property
    def max_seq_length(self):
        """
        Maximum sequence length read from the local sentence-transformers cache.

        Reads ``max_seq_length`` from
        ``<torch home>/sentence_transformers/sentence-transformers_<model_name>/sentence_bert_config.json``
        or, if that file is absent, ``max_position_embeddings`` from
        ``<torch home>/sentence_transformers/<model_name>/config.json``.

        Returns:
            The configured length, or ``None`` when the key or the config file
            is missing, leaving the limit to the tokenizer that consumes it.
        """
        import json
        from torch.hub import _get_torch_home
        sbert_cache = os.path.join(_get_torch_home(), 'sentence_transformers')
        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
        key = 'max_seq_length'
        if not os.path.exists(cfg_path):
            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
            key = 'max_position_embeddings'
        try:
            with open(cfg_path, encoding='utf-8') as f:
                cfg = json.load(f)
        except FileNotFoundError:
            # No cached config (e.g. model cached elsewhere): don't crash the
            # tokenize fallback over a missing optional limit.
            return None
        return cfg.get(key)

    @property
    def _model(self):
        """The underlying ``SentenceTransformer`` held by ``self.model``."""
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """
        Export the underlying SentenceTransformer to a file.

        Args:
            format: ``'pytorch'`` (``torch.save`` of the module),
                ``'torchscript'`` (``torch.jit.script``, falling back to
                ``torch.jit.trace``) or ``'onnx'``.
            path: Output directory, created if needed. ``'default'`` means
                ``saved/<format>`` next to this source file. The file is named
                after ``model_name`` with ``/`` replaced by ``-``.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: If ``format`` is not supported.
            RuntimeError: If torchscript or onnx export fails.
        """
        suffixes = {'pytorch': '.pt', 'torchscript': '.pt', 'onnx': '.onnx'}
        # Validate before touching the filesystem.
        if format not in suffixes:
            raise AttributeError(f'Invalid format {format}.')
        if path == 'default':
            path = os.path.join(str(Path(__file__).parent), 'saved', format)
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, self.model_name.replace('/', '-') + suffixes[format])

        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'torchscript':
            dummy_input = self.tokenize(['[CLS]'])
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # format == 'onnx'
            dummy_input = self.tokenize(['[CLS]'])
            new_model = ConvertModel(self._model)
            input_names = list(dummy_input.keys())
            # The export passes positional args in dummy_input's key order, so
            # make the wrapper map them back to names in that same order; its
            # default names come from the tokenizer and may not match.
            new_model.input_names = input_names
            dynamic_axes = {}
            for input_name, input_value in dummy_input.items():
                if len(input_value.shape) == 1:
                    dynamic_axes[input_name] = {0: 'batch_size'}
                else:
                    dynamic_axes[input_name] = {0: 'batch_size', 1: 'sequence_length'}
            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
            try:
                torch.onnx.export(
                    new_model,
                    tuple(dummy_input.values()),
                    path,
                    input_names=input_names,
                    output_names=['output_0'],
                    opset_version=13,
                    dynamic_axes=dynamic_axes,
                    do_constant_folding=True,
                )
            except Exception as e:
                log.error(f'Fail to save as onnx: {e}.')
                raise RuntimeError(f'Fail to save as onnx: {e}.') from e
        return Path(path).resolve()

    @staticmethod
    def supported_model_names(format: str = None):
        """
        List supported model names, optionally filtered by export format.

        Args:
            format: ``None`` for every model, or ``'pytorch'`` / ``'onnx'`` to
                drop models known not to export in that format.

        Returns:
            Sorted list of model names. For an unrecognized format an error is
            logged and an empty list is returned.
        """
        full_list = sorted([
            'sentence-t5-xxl',
            'sentence-t5-xl',
            'sentence-t5-large',
            'sentence-t5-base',
            'all-distilroberta-v1',
            'gtr-t5-xxl',
            'gtr-t5-large',
            'gtr-t5-xl',
            'all-MiniLM-L12-v1',
            'all-MiniLM-L12-v2',
            'all-MiniLM-L6-v1',
            'all-MiniLM-L6-v2',
            'all-mpnet-base-v1',
            'all-mpnet-base-v2',
            'all-roberta-large-v1',
            'bert-base-nli-mean-tokens',
            'gtr-t5-base',
            'distiluse-base-multilingual-cased-v1',
            'distiluse-base-multilingual-cased-v2',
            'msmarco-bert-base-dot-v5',
            'msmarco-distilbert-base-tas-b',
            'msmarco-distilbert-base-v4',
            'msmarco-distilbert-dot-v5',
            'multi-qa-distilbert-cos-v1',
            'multi-qa-distilbert-dot-v1',
            'multi-qa-MiniLM-L6-cos-v1',
            'multi-qa-MiniLM-L6-dot-v1',
            'multi-qa-mpnet-base-cos-v1',
            'multi-qa-mpnet-base-dot-v1',
            'paraphrase-albert-small-v2',
            'paraphrase-distilroberta-base-v2',
            'average_word_embeddings_komninos',
            'paraphrase-MiniLM-L12-v2',
            'paraphrase-MiniLM-L3-v2',
            'average_word_embeddings_glove.6B.300d',
            'paraphrase-MiniLM-L6-v2',
            'paraphrase-mpnet-base-v2',
            'paraphrase-multilingual-MiniLM-L12-v2',
            'paraphrase-multilingual-mpnet-base-v2',
            'paraphrase-TinyBERT-L6-v2',
        ])
        if format is None:
            return full_list
        # Models that cannot be exported in a given format.
        excluded = {
            'pytorch': [],
            'onnx': ['gtr-t5-xxl', 'sentence-t5-xxl'],
        }
        if format not in excluded:
            log.error(f'Invalid or unsupported format "{format}".')
            return []
        to_remove = set(excluded[format])
        assert to_remove.issubset(full_list)
        # Filter instead of set difference so the sorted order is preserved.
        return [name for name in full_list if name not in to_remove]

    def train(self, training_config=None, **kwargs):
        """
        Fine-tune the underlying SentenceTransformer on an STS task.

        Delegates to ``train_sts`` from the sibling ``train_sts_task`` module;
        extra ``kwargs`` are accepted but not used.
        """
        from .train_sts_task import train_sts
        train_sts(self._model, training_config)
if __name__ == '__main__':
    # Demo: fine-tune a sentence-transformers model on the STS benchmark.
    from sentence_transformers import util

    op = STransformers(model_name='nli-distilroberta-base-v2')

    sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
    # Download the STS benchmark archive only if it is not already present.
    if not os.path.exists(sts_dataset_path):
        util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

    training_config = dict(
        sts_dataset_path=sts_dataset_path,
        train_batch_size=16,
        num_epochs=4,
        model_save_path='./output',
    )
    op.train(training_config)