From a00786d4452b39e92bfd2ba4c739ae5e2849a2b7 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 14 Dec 2022 11:40:20 +0800 Subject: [PATCH] Update parameters Signed-off-by: Jael Gu --- README.md | 20 +++++++++++++++++--- auto_transformers.py | 38 ++++++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 3838742..d695696 100644 --- a/README.md +++ b/README.md @@ -48,14 +48,14 @@ import towhee Create the operator via the following factory method: -***text_embedding.transformers(model_name="bert-base-uncased")*** +***text_embedding.transformers(model_name=None)*** **Parameters:** ***model_name***: *str* -The model name in string. -The default model name is "bert-base-uncased". +The model name in string, defaults to None. +If None, the operator will be initialized without specified model. Supported model names: @@ -307,6 +307,20 @@ Supported model names: - uw-madison/yoso-4096 + +
+ +***checkpoint_path***: *str* + +The path to a local checkpoint, defaults to None. +If None, the operator will download and load a pretrained model by `model_name` from Huggingface transformers. + +
+ +***tokenizer***: *object* + +The method to tokenize input text, defaults to None. +If None, the operator will use the default tokenizer by `model_name` from Huggingface transformers.
diff --git a/auto_transformers.py b/auto_transformers.py index 01b92c8..074d1d7 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -40,11 +40,19 @@ class AutoTransformers(NNOperator): NLP embedding operator that uses the pretrained transformers model gathered by huggingface. Args: model_name (`str`): - Which model to use for the embeddings. + The model name to load a pretrained model from transformers. + checkpoint_path (`str`): + The local checkpoint path. + tokenizer (`object`): + The tokenizer to tokenize input text as model inputs. """ - def __init__(self, model_name: str = None, device: str = None, pretrain_weights_path=None, - load_pretrain_f=None, tokenizer=None) -> None: + def __init__(self, + model_name: str = None, + checkpoint_path: str = None, + tokenizer: object = None, + device: str = None, + ): super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -52,29 +60,31 @@ class AutoTransformers(NNOperator): self.model_name = model_name if self.model_name: + model_list = self.supported_model_names() + assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" + try: self.model = AutoModel.from_pretrained(model_name).to(self.device) self.configs = self.model.config except Exception as e: - model_list = self.supported_model_names() - if model_name not in model_list: - log.error(f"Invalid model name: {model_name}. 
Supported model names: {model_list}") - else: - log.error(f"Fail to load model by name: {self.model_name}") + log.error(f"Fail to load model by name: {self.model_name}") raise e - if pretrain_weights_path is not None: - if load_pretrain_f is None: - state_dict = torch.load(pretrain_weights_path, map_location='cpu') + if checkpoint_path: + try: + state_dict = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(state_dict) - else: - self.model = load_pretrain_f(self.model, pretrain_weights_path) + except Exception as e: + log.error(f"Fail to load state dict from {checkpoint_path}: {e}") self.model.eval() + if tokenizer is None: try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) except Exception as e: - log.error(f'Fail to load tokenizer by name: {self.model_name}') + log.error(f'Fail to load default tokenizer by name: {self.model_name}') raise e + else: + self.tokenizer = tokenizer else: log.warning('The operator is initialized without specified model.') pass