From 28419844ff7595b09d1e1bfe6a3bef2964191d82 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 16 Feb 2023 18:55:03 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- README.md | 8 ++++++- auto_transformers.py | 56 ++++++++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 1f4b524..21e7e5d 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Create the operator via the following factory method: The model name in string, defaults to None. If None, the operator will be initialized without specified model. -Supported model names: +Please note that only the supported models listed below have been tested by us:
Albert @@ -307,6 +307,12 @@ Supported model names: The path to local checkpoint, defaults to None. If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers. +
+ +***device***: *str* + +The device in string, defaults to None. If None, it will use "cuda" automatically when CUDA is available. +
***tokenizer***: *object* diff --git a/auto_transformers.py b/auto_transformers.py index cc9f6d7..a5e741f 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -37,14 +37,33 @@ warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' t_logging.set_verbosity_error() +def create_model(model_name, checkpoint_path, device): + model = AutoModel.from_pretrained(model_name).to(device) + if hasattr(model, 'pooler') and model.pooler: + model.pooler = None + if checkpoint_path: + try: + state_dict = torch.load(checkpoint_path, map_location=device) + model.load_state_dict(state_dict) + except Exception: + log.error(f'Fail to load weights from {checkpoint_path}') + model.eval() + return model # @accelerate class Model: - def __init__(self, model): - self.model = model + def __init__(self, model_name, checkpoint_path, device): + self.device = device + self.model = create_model(model_name, checkpoint_path, device) def __call__(self, *args, **kwargs): - outs = self.model(*args, **kwargs, return_dict=True) + new_args = [] + for x in args: + new_args.append(x.to(self.device)) + new_kwargs = {} + for k, v in kwargs.items(): + new_kwargs[k] = v.to(self.device) + outs = self.model(*new_args, **new_kwargs, return_dict=True) return outs['last_hidden_state'] @@ -75,17 +94,13 @@ class AutoTransformers(NNOperator): self.checkpoint_path = checkpoint_path if self.model_name: - model_list = self.supported_model_names() - assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" - self.model = Model(self._model) - if tokenizer is None: - try: - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - except Exception as e: - log.error(f'Fail to load default tokenizer by name: {self.model_name}') - raise e - else: + # model_list = self.supported_model_names() + # assert model_name in model_list, f"Invalid model name: {model_name}. 
Supported model names: {model_list}" + self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device) + if tokenizer: self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer.pad_token: self.tokenizer.pad_token = '[PAD]' else: @@ -98,7 +113,7 @@ class AutoTransformers(NNOperator): else: txt = data try: - inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) + inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') except Exception as e: log.error(f'Fail to tokenize inputs: {e}') raise e @@ -116,17 +131,7 @@ class AutoTransformers(NNOperator): @property def _model(self): - model = AutoModel.from_pretrained(self.model_name).to(self.device) - if hasattr(model, 'pooler') and model.pooler: - model.pooler = None - if self.checkpoint_path: - try: - state_dict = torch.load(self.checkpoint_path, map_location=self.device) - model.load_state_dict(state_dict) - except Exception: - log.error(f'Fail to load weights from {self.checkpoint_path}') - model.eval() - return model + return self.model.model def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): if output_file == 'default': @@ -160,6 +165,7 @@ class AutoTransformers(NNOperator): elif model_type == 'onnx': from transformers.onnx.features import FeaturesManager from transformers.onnx import export + self._model = self._model.to('cpu') model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( self._model, feature='default') onnx_config = model_onnx_config(self._model.config)