
Support Triton server

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
Branch: main
junjie.jiang, 8 months ago
commit dd2f6db677
1 changed file: s_bert.py (81 changed lines)
@@ -16,6 +16,7 @@ import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()

-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
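
Note: the reworked Model.__call__ copies every incoming tensor onto the model's device before the forward pass, so a caller (for instance a Triton backend handing over CPU tensors) no longer has to place inputs itself. A minimal sketch of the same pattern in plain torch, with hypothetical names:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def to_device(features: dict, device: str) -> dict:
        # Mirror of the new_kwargs loop above: copy each tensor to the target device.
        return {k: v.to(device) for k, v in features.items()}

    batch = {'input_ids': torch.ones(1, 8, dtype=torch.long),
             'attention_mask': torch.ones(1, 8, dtype=torch.long)}
    batch = to_device(batch, device)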
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()

     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-        # for k, v in inputs.items():
-        #     inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
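
Note: tokenization is now resolved once in the constructor (self._tokenize = self.get_tokenizer()) instead of re-dispatching on every call, and the commented-out device-move loop is dropped because Model.__call__ above now handles it. The str-vs-list contract is unchanged; a runnable sketch of just that contract, with numpy standing in for the model:

    import numpy as np

    def embed(txt):
        # Normalize a single string to a batch of one, as in __call__.
        sentences = [txt] if isinstance(txt, str) else txt
        embs = np.random.rand(len(sentences), 384)  # stand-in for the model
        if isinstance(txt, str):
            embs = embs.squeeze(0)  # single string -> (384,), not (1, 384)
        return embs

    assert embed('hi').shape == (384,)
    assert embed(['hi', 'there']).shape == (2, 384)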
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']

-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
             try:
                 tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
             except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-        if k in cfg:
-            max_seq_len = cfg[k]
-        else:
-            max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #     if k in cfg:
+    #         max_seq_len = cfg[k]
+    #     else:
+    #         max_seq_len = None
+    #     return max_seq_len

     @property
     def _model(self):
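
Note: get_tokenizer prefers the model's own tokenize method and otherwise binds a HuggingFace tokenizer with functools.partial, taking max_length from AutoConfig instead of the filesystem probing the now-commented max_seq_length property did. A small sketch of the partial pattern with a stand-in callable, since fetching a real tokenizer needs a network call:

    from functools import partial

    def toy_tokenizer(texts, padding=False, truncation=None,
                      max_length=None, return_tensors=None):
        # Stand-in for a HuggingFace tokenizer's __call__ (hypothetical).
        return {'input_ids': [[0] * max_length for _ in texts]}

    tokenize = partial(toy_tokenizer,
                       padding=True,
                       truncation='longest_first',
                       max_length=8,
                       return_tensors='pt')

    print(tokenize(['hello'])['input_ids'])  # settings are pre-bound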
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
         else:
             raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
             dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
         dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
         try:
-            torch.onnx.export(new_model,
+            torch.onnx.export(new_model.to('cpu'),
                               tuple(dummy_input.values()),
                               path,
                               input_names=input_names,
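
Note: new_model.to('cpu') keeps the model on the same device as the dummy inputs, which the tokenizer returns on CPU; exporting a CUDA-resident model against CPU tensors would raise a device-mismatch error. A self-contained sketch of the same CPU-export pattern with a toy module:

    import torch

    model = torch.nn.Linear(4, 2)
    dummy = torch.randn(1, 4)  # like tokenizer output, this lives on CPU

    torch.onnx.export(model.to('cpu'),  # model and inputs on one device
                      (dummy,),
                      'toy.onnx',
                      input_names=['input_0'],
                      output_names=['output_0'],
                      dynamic_axes={'input_0': {0: 'batch_size'}})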
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
             'sentence-t5-xxl',
             'sentence-t5-xl',
             'sentence-t5-large',
