logo
Browse Source

Support Towhee TritonServer

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
e6c7f3d8e0
  1. 49
      auto_transformers.py

49
auto_transformers.py

@ -22,6 +22,7 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoMod
from towhee.operator import NNOperator
from towhee import register
from towhee.dc2 import accelerate
import warnings
import logging
@ -34,6 +35,26 @@ log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@accelerate
class Model:
def __init__(self, model_name, device, checkpoint_path):
try:
self.model = AutoModel.from_pretrained(model_name).to(device)
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
self.model.eval()
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
"""
@ -57,26 +78,14 @@ class AutoTransformers(NNOperator):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model_name = model_name
if self.model_name:
model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.configs = self.model.config
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
self.model.eval()
self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
self.configs = self.model.model.config
if tokenizer is None:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -135,13 +144,13 @@ class AutoTransformers(NNOperator):
from transformers.onnx.features import FeaturesManager
from transformers.onnx import export
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self.model, feature='default')
onnx_config = model_onnx_config(self.model.config)
self.model.model, feature='default')
onnx_config = model_onnx_config(self.configs)
if os.path.isdir(path):
shutil.rmtree(path)
onnx_inputs, onnx_outputs = export(
self.tokenizer,
self.model,
self.model.model,
config=onnx_config,
opset=13,
output=Path(path+'.onnx')

Loading…
Cancel
Save