logo
Browse Source

Deal with model & Model

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
41682a3f61
  1. 35
      auto_transformers.py

35
auto_transformers.py

@ -25,6 +25,7 @@ from transformers import AutoTokenizer, AutoConfig, AutoModel
from towhee.operator import NNOperator
from towhee import register
# from towhee.serve.triton.triton_client import TritonClient
# from towhee.dc2 import accelerate
import warnings
@ -37,10 +38,24 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
def create_model(model_name, checkpoint_path, device):
model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {checkpoint_path}')
model.eval()
return model
# @accelerate
class Model:
def __init__(self, model):
self.model = model
def __init__(self, model_name, checkpoint_path, device):
self.model = create_model(model_name, checkpoint_path, device)
def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs, return_dict=True)
@ -79,7 +94,7 @@ class AutoTransformers(NNOperator):
if self.model_name:
# model_list = self.supported_model_names()
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(self._model)
self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
if tokenizer:
self.tokenizer = tokenizer
else:
@ -117,16 +132,10 @@ class AutoTransformers(NNOperator):
@property
def _model(self):
model = AutoModel.from_pretrained(self.model_name).to(self.device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {self.checkpoint_path}')
model.eval()
# if isinstance(self.model, TritonClient):
# model = create_model(self.model_name, self.checkpoint_path, self.device)
# else:
model = self.model.model
return model
@property

Loading…
Cancel
Save