diff --git a/auto_transformers.py b/auto_transformers.py
index 9f22bce..225b3e2 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoMod
 
 from towhee.operator import NNOperator
 from towhee import register
+from towhee.dc2 import accelerate
 
 import warnings
 import logging
@@ -34,6 +35,26 @@ log = logging.getLogger('run_op')
 warnings.filterwarnings('ignore')
 
 
+@accelerate
+class Model:
+    def __init__(self, model_name, device, checkpoint_path):
+        try:
+            self.model = AutoModel.from_pretrained(model_name).to(device)
+        except Exception as e:
+            # NOTE: log the local parameter, not self.model_name — the
+            # attribute is never assigned on this class, so referencing it
+            # here would raise AttributeError and mask the real load failure.
+            log.error(f"Fail to load model by name: {model_name}")
+            raise e
+        if checkpoint_path:
+            try:
+                state_dict = torch.load(checkpoint_path, map_location=device)
+                self.model.load_state_dict(state_dict)
+            except Exception as e:
+                log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
+        self.model.eval()
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+
 @register(output_schema=['vec'])
 class AutoTransformers(NNOperator):
     """
@@ -57,26 +78,14 @@ class AutoTransformers(NNOperator):
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = device
+
+        model_list = self.supported_model_names()
+        assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
         self.model_name = model_name
         if self.model_name:
-            model_list = self.supported_model_names()
-            assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
-
-            try:
-                self.model = AutoModel.from_pretrained(model_name).to(self.device)
-                self.configs = self.model.config
-            except Exception as e:
-                log.error(f"Fail to load model by name: {self.model_name}")
-                raise e
-            if checkpoint_path:
-                try:
-                    state_dict = torch.load(checkpoint_path, map_location=self.device)
-                    self.model.load_state_dict(state_dict)
-                except Exception as e:
-                    log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
-            self.model.eval()
-
+            self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
+            self.configs = self.model.model.config
         if tokenizer is None:
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -135,13 +144,13 @@ class AutoTransformers(NNOperator):
         from transformers.onnx.features import FeaturesManager
         from transformers.onnx import export
         model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-            self.model, feature='default')
-        onnx_config = model_onnx_config(self.model.config)
+            self.model.model, feature='default')
+        onnx_config = model_onnx_config(self.configs)
         if os.path.isdir(path):
             shutil.rmtree(path)
         onnx_inputs, onnx_outputs = export(
             self.tokenizer,
-            self.model,
+            self.model.model,
             config=onnx_config,
             opset=13,
             output=Path(path+'.onnx')