Support triton server

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
Branch: main · junjie.jiang, 8 months ago · commit dd2f6db677
1 changed file: s_bert.py (81)
@@ -16,6 +16,7 @@ import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()

-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()

     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-        # for k, v in inputs.items():
-        #     inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']

-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
             try:
                 tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
             except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-        if k in cfg:
-            max_seq_len = cfg[k]
-        else:
-            max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #     if k in cfg:
+    #         max_seq_len = cfg[k]
+    #     else:
+    #         max_seq_len = None
+    #     return max_seq_len

     @property
     def _model(self):
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
         else:
             raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
             dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
         dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
         try:
-            torch.onnx.export(new_model,
+            torch.onnx.export(new_model.to('cpu'),
                               tuple(dummy_input.values()),
                               path,
                               input_names=input_names,
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
             'sentence-t5-xxl',
             'sentence-t5-xl',
             'sentence-t5-large',
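
For context, a minimal sketch of how the refactored tokenizer and device handling fit together. The model name below is only an example and is not taken from this commit; the padding/truncation settings mirror the get_tokenizer() change above.

    from functools import partial
    from transformers import AutoTokenizer, AutoConfig

    # Example checkpoint (assumption): any sentence-transformers model with a HF tokenizer.
    model_name = 'all-MiniLM-L6-v2'
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + model_name)
    conf = AutoConfig.from_pretrained('sentence-transformers/' + model_name)

    # get_tokenizer() binds padding/truncation once, so the same callable serves
    # both normal inference and the dummy input used for ONNX export.
    tokenize = partial(tokenizer,
                       padding=True,
                       truncation='longest_first',
                       max_length=conf.max_position_embeddings,
                       return_tensors='pt')

    inputs = tokenize(['hello triton'])  # dict of CPU tensors: input_ids, attention_mask, ...
    # Model.__call__ now moves each incoming tensor to self.device before the
    # forward pass, so tensors prepared on CPU (e.g. by a Triton backend) still
    # run on the operator's configured device.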
