# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import Union, List

import numpy
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoConfig

from towhee.operator import NNOperator

try:
    from towhee import accelerate
except Exception:
    # Fall back to a no-op decorator when towhee's accelerate is unavailable.
    def accelerate(func):
        return func

warnings.filterwarnings('ignore')
logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
log = logging.getLogger('op_s_transformers')
log.setLevel(logging.ERROR)


class ConvertModel(torch.nn.Module):
    """Adapter that exposes a SentenceTransformer through positional tensor
    inputs, as required for tracing and ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.net = model
        try:
            self.input_names = self.net.tokenizer.model_input_names
        except AttributeError:
            self.input_names = list(self.net.tokenize(['test']).keys())

    def forward(self, *args, **kwargs):
        if args:
            assert kwargs == {}, 'Accepts either positional args or kwargs as inputs, not both.'
            assert len(args) == len(self.input_names)
            for k, v in zip(self.input_names, args):
                kwargs[k] = v
        outs = self.net(kwargs)
        return outs['sentence_embedding']


@accelerate
class Model:
    """Runs a SentenceTransformer on the configured device. The towhee
    `accelerate` decorator (or the no-op fallback above) lets an accelerated
    runtime be plugged in when available."""

    def __init__(self, model_name, device):
        self.device = device
        self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
        self.model.eval()

    def __call__(self, *_, **kwargs):
        new_kwargs = {}
        for k, v in kwargs.items():
            new_kwargs[k] = v.to(self.device)
        outs = self.model(new_kwargs)
        return outs['sentence_embedding']


class STransformers(NNOperator):
    """
    Operator using pretrained Sentence Transformers
    """
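    # Minimal usage sketch (the model name below is illustrative; see
    # supported_model_names() for the full list):
    #
    #   op = STransformers(model_name='all-MiniLM-L6-v2')
    #   emb = op('Hello, world!')        # str  -> single embedding (numpy array)
    #   embs = op(['Hello', 'world'])    # list -> list of embeddings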

    def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False):
        super().__init__()
        self.model_name = model_name
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        if self.model_name:
            self.model = Model(model_name=self.model_name, device=self.device)
        else:
            log.warning('The operator is initialized without a specified model.')
        self._tokenize = self.get_tokenizer()
        self.return_usage = return_usage

    def __call__(self, txt: Union[List[str], str]):
        """Encode a single sentence or a list of sentences into embedding(s)."""
        if isinstance(txt, str):
            sentences = [txt]
        else:
            sentences = txt
        inputs = self._tokenize(sentences)
        num_tokens = int(torch.count_nonzero(inputs['input_ids']))
        embs = self.model(**inputs).cpu().detach().numpy()
        if isinstance(txt, str):
            embs = embs.squeeze(0)
        else:
            embs = list(embs)
        if self.return_usage:
            return {'data': embs, 'token_usage': num_tokens}
        return embs

    @property
    def supported_formats(self):
        return ['onnx']

    def get_tokenizer(self):
        # Prefer the hub path under the 'sentence-transformers' organization;
        # fall back to treating model_name as a full model id or local path.
        try:
            tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
            conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            conf = AutoConfig.from_pretrained(self.model_name)
        return partial(tokenizer,
                       padding=True,
                       truncation='longest_first',
                       max_length=conf.max_position_embeddings,
                       return_tensors='pt')

    # @property
    # def max_seq_length(self):
    #     import json
    #     from torch.hub import _get_torch_home
    #     torch_cache = _get_torch_home()
    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
    #     if not os.path.exists(cfg_path):
    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
    #         k = 'max_position_embeddings'
    #     else:
    #         k = 'max_seq_length'
    #     with open(cfg_path) as f:
    #         cfg = json.load(f)
    #     if k in cfg:
    #         max_seq_len = cfg[k]
    #     else:
    #         max_seq_len = None
    #     return max_seq_len

    @property
    def _model(self):
        # Underlying SentenceTransformer held by the (possibly accelerated) Model wrapper.
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        if path == 'default':
            path = str(Path(__file__).parent)
            path = os.path.join(path, 'saved', format)
            os.makedirs(path, exist_ok=True)
            name = self.model_name.replace('/', '-')
            path = os.path.join(path, name)
            if format in ['pytorch', 'torchscript']:
                path = path + '.pt'
            elif format == 'onnx':
                path = path + '.onnx'
            else:
                raise AttributeError(f'Invalid format {format}.')
        dummy_text = ['[CLS]']
        dummy_input = self._tokenize(dummy_text)
        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'torchscript':
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Failed to save as torchscript: {e}.')
                raise RuntimeError(f'Failed to save as torchscript: {e}.')
        elif format == 'onnx':
            new_model = ConvertModel(self._model)
            input_names = list(dummy_input.keys())
            # Mark batch (and sequence) dimensions as dynamic so the exported graph
            # accepts variable batch sizes and sentence lengths.
            dynamic_axes = {}
            for i_n, i_v in dummy_input.items():
                if len(i_v.shape) == 1:
                    dynamic_axes[i_n] = {0: 'batch_size'}
                else:
                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
            try:
                torch.onnx.export(new_model.to('cpu'),
                                  tuple(dummy_input.values()),
                                  path,
                                  input_names=input_names,
                                  output_names=['output_0'],
                                  opset_version=13,
                                  dynamic_axes=dynamic_axes,
                                  do_constant_folding=True)
            except Exception as e:
                log.error(f'Failed to save as onnx: {e}.')
                raise RuntimeError(f'Failed to save as onnx: {e}.')
        # todo: elif format == 'tensorrt':
        else:
            log.error(f'Unsupported format "{format}".')
        return Path(path).resolve()

    @staticmethod
    def supported_model_names(format: str = None):
        full_list = [
            'clip-ViT-B-32-multilingual-v1',
            'sentence-t5-xxl',
            'sentence-t5-xl',
            'sentence-t5-large',
            'sentence-t5-base',
            'all-distilroberta-v1',
            'gtr-t5-xxl',
            'gtr-t5-large',
            'gtr-t5-xl',
            'all-MiniLM-L12-v1',
            'all-MiniLM-L12-v2',
            'all-MiniLM-L6-v1',
            'all-MiniLM-L6-v2',
            'all-mpnet-base-v1',
            'all-mpnet-base-v2',
            'all-roberta-large-v1',
            'bert-base-nli-mean-tokens',
            'gtr-t5-base',
            'distiluse-base-multilingual-cased-v1',
            'distiluse-base-multilingual-cased-v2',
            'msmarco-bert-base-dot-v5',
            'msmarco-distilbert-base-tas-b',
            'msmarco-distilbert-base-v4',
            'msmarco-distilbert-dot-v5',
            'multi-qa-distilbert-cos-v1',
            'multi-qa-distilbert-dot-v1',
            'multi-qa-MiniLM-L6-cos-v1',
            'multi-qa-MiniLM-L6-dot-v1',
            'multi-qa-mpnet-base-cos-v1',
            'multi-qa-mpnet-base-dot-v1',
            'paraphrase-albert-small-v2',
            'paraphrase-distilroberta-base-v2',
            'average_word_embeddings_komninos',
            'paraphrase-MiniLM-L12-v2',
            'paraphrase-MiniLM-L3-v2',
            'average_word_embeddings_glove.6B.300d',
            'paraphrase-MiniLM-L6-v2',
            'paraphrase-mpnet-base-v2',
            'paraphrase-multilingual-MiniLM-L12-v2',
            'paraphrase-multilingual-mpnet-base-v2',
            'paraphrase-TinyBERT-L6-v2'
        ]
        full_list.sort()
        if format is None:
            model_list = full_list
        elif format == 'pytorch':
            to_remove = []
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        elif format == 'onnx':
            to_remove = ['gtr-t5-xxl', 'sentence-t5-xxl']
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        else:
            log.error(f'Invalid or unsupported format "{format}".')
            raise ValueError(f'Invalid or unsupported format "{format}".')
        return model_list

    def train(self, training_config=None, **kwargs):
        # Fine-tune the underlying SentenceTransformer on an STS task.
        from .train_sts_task import train_sts
        train_sts(self._model, training_config)


if __name__ == '__main__':
    from sentence_transformers import util

    op = STransformers(model_name='nli-distilroberta-base-v2')

    # Check if the dataset exists. If not, download and extract it.
    sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
    if not os.path.exists(sts_dataset_path):
        util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

    training_config = {
        'sts_dataset_path': sts_dataset_path,
        'train_batch_size': 16,
        'num_epochs': 4,
        'model_save_path': './output'
    }
    op.train(training_config)
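    # Optional follow-up (illustrative): export the fine-tuned model, e.g. to ONNX.
    # op.save_model(format='onnx')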