
Support Towhee TritonServer

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>

Branch: main
Author: Jael Gu (2 years ago)
Commit: e6c7f3d8e0
1 changed file: auto_transformers.py (49 lines changed)
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoMod
 from towhee.operator import NNOperator
 from towhee import register
+from towhee.dc2 import accelerate
 import warnings
 import logging
@@ -34,6 +35,26 @@ log = logging.getLogger('run_op')
 warnings.filterwarnings('ignore')


+@accelerate
+class Model:
+    def __init__(self, model_name, device, checkpoint_path):
+        try:
+            self.model = AutoModel.from_pretrained(model_name).to(device)
+        except Exception as e:
+            log.error(f"Fail to load model by name: {model_name}")
+            raise e
+        if checkpoint_path:
+            try:
+                state_dict = torch.load(checkpoint_path, map_location=device)
+                self.model.load_state_dict(state_dict)
+            except Exception as e:
+                log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
+        self.model.eval()
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+
 @register(output_schema=['vec'])
 class AutoTransformers(NNOperator):
     """
@@ -57,26 +78,14 @@ class AutoTransformers(NNOperator):
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = device
+        model_list = self.supported_model_names()
+        assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
         self.model_name = model_name
         if self.model_name:
-            model_list = self.supported_model_names()
-            assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
-            try:
-                self.model = AutoModel.from_pretrained(model_name).to(self.device)
-                self.configs = self.model.config
-            except Exception as e:
-                log.error(f"Fail to load model by name: {self.model_name}")
-                raise e
-            if checkpoint_path:
-                try:
-                    state_dict = torch.load(checkpoint_path, map_location=self.device)
-                    self.model.load_state_dict(state_dict)
-                except Exception as e:
-                    log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
-            self.model.eval()
+            self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
+            self.configs = self.model.model.config
         if tokenizer is None:
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -135,13 +144,13 @@ class AutoTransformers(NNOperator):
         from transformers.onnx.features import FeaturesManager
         from transformers.onnx import export
         model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-            self.model, feature='default')
-        onnx_config = model_onnx_config(self.model.config)
+            self.model.model, feature='default')
+        onnx_config = model_onnx_config(self.configs)
         if os.path.isdir(path):
             shutil.rmtree(path)
         onnx_inputs, onnx_outputs = export(
             self.tokenizer,
-            self.model,
+            self.model.model,
             config=onnx_config,
             opset=13,
             output=Path(path+'.onnx')
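From the caller's side the operator API is unchanged by this commit. As a hedged usage sketch, assuming the towhee.dc2 pipeline API of the same era (pipe, ops) and that auto_transformers.py backs the text_embedding.transformers hub operator (both assumptions, not confirmed by this diff):

    from towhee.dc2 import pipe, ops

    p = (
        pipe.input('text')
            .map('text', 'vec', ops.text_embedding.transformers(model_name='bert-base-uncased'))
            .output('vec')
    )

    res = p('Hello, Towhee.')  # the @accelerate wrapper decides whether the
                               # forward pass runs locally or on a Triton backend

Whether acceleration actually kicks in depends on the towhee.dc2 runtime configuration; nothing in the operator code needs to change either way.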
