logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

218 lines
7.8 KiB

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy
from typing import Union, List
from pathlib import Path
import torch
from sentence_transformers import SentenceTransformer
from towhee.operator import NNOperator
# from towhee.dc2 import accelerate
import os
import warnings
# Process-wide: silence every Python warning once this module is imported.
warnings.filterwarnings('ignore')

# Logger for this operator; only ERROR and above are emitted.
log = logging.getLogger('op_s_transformers')
log.setLevel(logging.ERROR)

# Keep the sentence_transformers library logger quiet as well (errors only).
logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
class ConvertModel(torch.nn.Module):
    """Adapter that gives a feature-dict model a positional-tensor ``forward``.

    The wrapped model is called with a single dict of named features and must
    return a mapping containing 'sentence_embedding'. This wrapper accepts the
    features either as keyword arguments or positionally (in the order of
    ``self.input_names``), which is the form ``torch.onnx.export`` uses.

    Args:
        model: the wrapped module (e.g. a SentenceTransformer).
    """

    def __init__(self, model):
        super().__init__()
        self.net = model
        # Feature names, in the order positional inputs are mapped to them.
        try:
            self.input_names = self.net.tokenizer.model_input_names
        except AttributeError:
            # No `tokenizer` attribute: probe the model's own tokenize() output keys.
            self.input_names = list(self.net.tokenize(['test']).keys())

    def forward(self, *args, **kwargs):
        """Return the 'sentence_embedding' output for the given features.

        Raises:
            ValueError: both positional and keyword inputs were given, or the
                number of positional inputs differs from ``len(self.input_names)``.
        """
        if args:
            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if kwargs:
                raise ValueError('Only accept either args or kwargs as inputs, not both.')
            if len(args) != len(self.input_names):
                raise ValueError(
                    f'Expected {len(self.input_names)} positional inputs, got {len(args)}.'
                )
            kwargs = dict(zip(self.input_names, args))
        outs = self.net(kwargs)
        return outs['sentence_embedding']
# @accelerate  (disabled: the towhee.dc2 `accelerate` import is commented out at the top of this file)
class Model:
    """Callable wrapper around a ``SentenceTransformer``.

    Args:
        model_name: name or path passed as ``model_name_or_path``.
        device: device the SentenceTransformer is created on; tensor features
            passed to ``__call__`` are moved to it before the forward pass.
    """

    def __init__(self, model_name, device):
        self.device = device
        self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
        # Put the model in inference (eval) mode.
        self.model.eval()

    def __call__(self, **features):
        """Run the model on tokenized ``features``; return its 'sentence_embedding' output."""
        # Tokenized tensors are created without an explicit device (i.e. on the CPU),
        # while the model was placed on `self.device`; move them across so a CUDA
        # model does not fail with a device mismatch. Non-tensor entries pass through.
        features = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        outs = self.model(features)
        return outs['sentence_embedding']
class STransformers(NNOperator):
    """
    Operator using pretrained Sentence Transformers.

    Embeds text with a ``SentenceTransformer`` model (through the ``Model``
    wrapper) and can export that model with ``save_model``.

    Args:
        model_name: model name or path handed to ``SentenceTransformer``. If
            empty/None, no model is loaded and a warning is logged.
        device: torch device string; defaults to 'cuda' when
            ``torch.cuda.is_available()``, otherwise 'cpu'.
    """

    def __init__(self, model_name: str = None, device: str = None):
        self.model_name = model_name
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        if self.model_name:
            self.model = Model(model_name=self.model_name, device=self.device)
        else:
            # NOTE(review): no model is assigned here; calling the operator
            # presumably fails until one is provided.
            log.warning('The operator is initialized without specified model.')

    def __call__(self, txt: Union[List[str], str]):
        """Embed one sentence or a list of sentences.

        Returns:
            For a single ``str``: its embedding as a numpy array with the batch
            axis squeezed away. For a list: a list of numpy arrays, one per
            sentence.
        """
        single = isinstance(txt, str)
        sentences = [txt] if single else txt
        inputs = self.tokenize(sentences)
        # Inference only: don't record an autograd graph for the forward pass.
        with torch.no_grad():
            embs = self.model(**inputs).cpu().detach().numpy()
        return embs.squeeze(0) if single else list(embs)

    @property
    def supported_formats(self):
        """Export formats advertised by this operator."""
        return ['onnx']

    def tokenize(self, x):
        """Tokenize a list of sentences into PyTorch tensors.

        Prefers the loaded SentenceTransformer's own ``tokenize``. If that
        raises (for example because no model was loaded), falls back to a
        Hugging Face ``AutoTokenizer`` loaded from
        'sentence-transformers/<model_name>' or, failing that, ``model_name``,
        padding and truncating to ``self.max_seq_length``.
        """
        try:
            outs = self._model.tokenize(x)
        except Exception:
            # Deliberate best-effort fallback; the original error is not re-raised.
            from transformers import AutoTokenizer
            try:
                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
            except Exception:
                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            outs = tokenizer(
                x,
                padding=True, truncation='longest_first', max_length=self.max_seq_length,
                return_tensors='pt',
            )
        return outs

    @property
    def max_seq_length(self):
        """Maximum sequence length read from the locally cached model config.

        Looks under ``<torch home>/sentence_transformers`` for
        'sentence-transformers_<model_name>/sentence_bert_config.json' (key
        'max_seq_length'), then for '<model_name>/config.json' (key
        'max_position_embeddings'). Returns None when neither file exists or
        the key is absent, so the tokenizer falls back to its own default.
        """
        import json
        # NOTE: _get_torch_home is a private torch.hub helper.
        from torch.hub import _get_torch_home
        sbert_cache = os.path.join(_get_torch_home(), 'sentence_transformers')
        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
        key = 'max_seq_length'
        if not os.path.exists(cfg_path):
            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
            key = 'max_position_embeddings'
        if not os.path.exists(cfg_path):
            # Nothing cached at the expected location: don't crash tokenization.
            return None
        with open(cfg_path, encoding='utf-8') as f:
            cfg = json.load(f)
        return cfg.get(key)

    @property
    def _model(self):
        """The underlying SentenceTransformer held by the ``Model`` wrapper."""
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Export the loaded SentenceTransformer to disk.

        Args:
            format: 'pytorch' (``torch.save`` of the whole module),
                'torchscript' (``torch.jit.script``, falling back to
                ``torch.jit.trace``) or 'onnx'.
            path: output root; 'default' means this file's directory. The
                file is written to '<path>/saved/<format>/<model_name with "/"
                replaced by "-">' plus '.pt' or '.onnx'.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: ``format`` is not one of the supported values.
            RuntimeError: torchscript or onnx export failed.
        """
        suffixes = {'pytorch': '.pt', 'torchscript': '.pt', 'onnx': '.onnx'}
        if format not in suffixes:
            # Validate before touching the filesystem so no stray directory is created.
            raise AttributeError(f'Invalid format {format}.')
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self.model_name.replace('/', '-')
        path = os.path.join(path, name) + suffixes[format]

        dummy_input = self.tokenize(['[CLS]'])
        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'torchscript':
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    # Scripting failed; fall back to tracing with the dummy tokenized batch.
                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # 'onnx', the only remaining validated format
            new_model = ConvertModel(self._model)
            input_names = list(dummy_input.keys())
            # The export below passes inputs positionally in dummy_input's key order;
            # make the wrapper map them back to names in that same order rather than
            # tokenizer.model_input_names, which need not match it.
            new_model.input_names = input_names
            dynamic_axes = {}
            for i_n, i_v in dummy_input.items():
                if len(i_v.shape) == 1:
                    dynamic_axes[i_n] = {0: 'batch_size'}
                else:
                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
            try:
                torch.onnx.export(
                    new_model,
                    tuple(dummy_input.values()),
                    path,
                    input_names=input_names,
                    output_names=['output_0'],
                    opset_version=13,
                    dynamic_axes=dynamic_axes,
                    do_constant_folding=True,
                )
            except Exception as e:
                log.error(f'Fail to save as onnx: {e}.')
                raise RuntimeError(f'Fail to save as onnx: {e}.') from e
        # TODO: support format == 'tensorrt' (also add it to `suffixes`).
        return Path(path).resolve()

    @staticmethod
    def supported_model_names(format: str = None):
        """List model names published on the sbert.net pretrained-models page.

        Scrapes the page for lines of the form ``"name": "<model>",``.

        Args:
            format: None returns every scraped name (sorted). 'pytorch' or
                'onnx' return the sorted, de-duplicated names minus that
                format's exclusion list (currently empty). Any other value is
                logged as an error and yields an empty list.
        """
        import requests
        # A timeout keeps an unresponsive server from hanging the caller forever.
        req = requests.get(
            "https://www.sbert.net/_static/html/models_en_sentence_embeddings.html",
            timeout=30,
        )
        data = req.text
        full_list = []
        # splitlines() handles both '\r\n' and '\n' line endings.
        for line in data.splitlines():
            line = line.replace(' ', '')
            if line.startswith('"name":'):
                name = line.split(':')[-1].replace('"', '').replace(',', '')
                full_list.append(name)
        full_list.sort()
        if format is None:
            return full_list
        excluded = {
            'pytorch': [],
            'onnx': [],
        }
        if format not in excluded:
            log.error(f'Invalid or unsupported format "{format}".')
            return []
        to_remove = set(excluded[format])
        # Sanity check that the hard-coded exclusions still exist on the page.
        assert to_remove.issubset(full_list)
        # sorted() keeps the result ordered like the format=None case.
        return sorted(set(full_list) - to_remove)