diff --git a/auto_transformers.py b/auto_transformers.py
index b91d264..02e0f2a 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -25,6 +25,7 @@
 from transformers import AutoTokenizer, AutoConfig, AutoModel
 from towhee.operator import NNOperator
 from towhee import register
+# from towhee.serve.triton.triton_client import TritonClient
 # from towhee.dc2 import accelerate
 
 import warnings
@@ -37,10 +38,24 @@
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
 
+def create_model(model_name, checkpoint_path, device):
+    model = AutoModel.from_pretrained(model_name).to(device)
+    if hasattr(model, 'pooler') and model.pooler:
+        model.pooler = None
+    if checkpoint_path:
+        try:
+            state_dict = torch.load(checkpoint_path, map_location=device)
+            model.load_state_dict(state_dict)
+        except Exception:
+            log.error(f'Fail to load weights from {checkpoint_path}')
+    model.eval()
+    return model
+
+
 # @accelerate
 class Model:
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, model_name, checkpoint_path, device):
+        self.model = create_model(model_name, checkpoint_path, device)
 
     def __call__(self, *args, **kwargs):
         outs = self.model(*args, **kwargs, return_dict=True)
@@ -79,7 +94,7 @@ class AutoTransformers(NNOperator):
         if self.model_name:
             # model_list = self.supported_model_names()
             # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
-            self.model = Model(self._model)
+            self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
             if tokenizer:
                 self.tokenizer = tokenizer
             else:
@@ -117,16 +132,10 @@ class AutoTransformers(NNOperator):
 
     @property
     def _model(self):
-        model = AutoModel.from_pretrained(self.model_name).to(self.device)
-        if hasattr(model, 'pooler') and model.pooler:
-            model.pooler = None
-        if self.checkpoint_path:
-            try:
-                state_dict = torch.load(self.checkpoint_path, map_location=self.device)
-                model.load_state_dict(state_dict)
-            except Exception:
-                log.error(f'Fail to load weights from {self.checkpoint_path}')
-        model.eval()
+        # if isinstance(self.model, TritonClient):
+        #     model = create_model(self.model_name, self.checkpoint_path, self.device)
+        # else:
+        model = self.model.model
         return model
 
     @property
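For reviewers, a minimal usage sketch of the new Model constructor follows. It is illustrative only: the model name 'bert-base-uncased', the standalone script, and the direct import of Model from auto_transformers are assumptions made for the example, not part of this patch.

# Usage sketch (assumption, not part of the patch): Model now takes
# model_name/checkpoint_path/device and builds the backbone itself via
# create_model, instead of receiving a ready-made nn.Module.
import torch
from transformers import AutoTokenizer

from auto_transformers import Model  # assumes the patched module is importable

device = 'cpu'
model_name = 'bert-base-uncased'  # hypothetical example model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# checkpoint_path=None skips the optional fine-tuned-weights branch in create_model.
model = Model(model_name=model_name, checkpoint_path=None, device=device)

inputs = tokenizer('Hello, Towhee!', return_tensors='pt').to(device)
with torch.no_grad():
    outs = model(**inputs)  # forwards to AutoModel with return_dict=True

# last_hidden_state holds the per-token embeddings (the pooler is removed).
print(outs.last_hidden_state.shape)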