|
|
@ -25,6 +25,7 @@ from transformers import AutoTokenizer, AutoConfig, AutoModel |
|
|
|
|
|
|
|
from towhee.operator import NNOperator |
|
|
|
from towhee import register |
|
|
|
# from towhee.serve.triton.triton_client import TritonClient |
|
|
|
# from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
import warnings |
|
|
@ -37,10 +38,24 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
|
t_logging.set_verbosity_error() |
|
|
|
|
|
|
|
|
|
|
|
def create_model(model_name, checkpoint_path, device): |
|
|
|
model = AutoModel.from_pretrained(model_name).to(device) |
|
|
|
if hasattr(model, 'pooler') and model.pooler: |
|
|
|
model.pooler = None |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
model.load_state_dict(state_dict) |
|
|
|
except Exception: |
|
|
|
log.error(f'Fail to load weights from {checkpoint_path}') |
|
|
|
model.eval() |
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, model): |
|
|
|
self.model = model |
|
|
|
def __init__(self, model_name, checkpoint_path, device): |
|
|
|
self.model = create_model(model_name, checkpoint_path, device) |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
outs = self.model(*args, **kwargs, return_dict=True) |
|
|
@ -79,7 +94,7 @@ class AutoTransformers(NNOperator): |
|
|
|
if self.model_name: |
|
|
|
# model_list = self.supported_model_names() |
|
|
|
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
self.model = Model(self._model) |
|
|
|
self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device) |
|
|
|
if tokenizer: |
|
|
|
self.tokenizer = tokenizer |
|
|
|
else: |
|
|
@ -117,16 +132,10 @@ class AutoTransformers(NNOperator): |
|
|
|
|
|
|
|
@property |
|
|
|
def _model(self): |
|
|
|
model = AutoModel.from_pretrained(self.model_name).to(self.device) |
|
|
|
if hasattr(model, 'pooler') and model.pooler: |
|
|
|
model.pooler = None |
|
|
|
if self.checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(self.checkpoint_path, map_location=self.device) |
|
|
|
model.load_state_dict(state_dict) |
|
|
|
except Exception: |
|
|
|
log.error(f'Fail to load weights from {self.checkpoint_path}') |
|
|
|
model.eval() |
|
|
|
# if isinstance(self.model, TritonClient): |
|
|
|
# model = create_model(self.model_name, self.checkpoint_path, self.device) |
|
|
|
# else: |
|
|
|
model = self.model.model |
|
|
|
return model |
|
|
|
|
|
|
|
@property |
|
|
|