diff --git a/README.md b/README.md
index 3838742..d695696 100644
--- a/README.md
+++ b/README.md
@@ -48,14 +48,14 @@ import towhee
Create the operator via the following factory method:
-***text_embedding.transformers(model_name="bert-base-uncased")***
+***text_embedding.transformers(model_name=None)***
**Parameters:**
***model_name***: *str*
-The model name in string.
-The default model name is "bert-base-uncased".
+The model name as a string; defaults to None.
+If None, the operator will be initialized without a specified model.
Supported model names:
@@ -307,6 +307,20 @@ Supported model names:
- uw-madison/yoso-4096
+
+
+
+***checkpoint_path***: *str*
+
+The path to local checkpoint, defaults to None.
+If None, the operator will download and load the pretrained model specified by `model_name` from Huggingface transformers.
+
+
+
+***tokenizer***: *object*
+
+The method to tokenize input text, defaults to None.
+If None, the operator will use the default tokenizer for `model_name` from Huggingface transformers.
diff --git a/auto_transformers.py b/auto_transformers.py
index 01b92c8..074d1d7 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -40,11 +40,19 @@ class AutoTransformers(NNOperator):
NLP embedding operator that uses the pretrained transformers model gathered by huggingface.
Args:
model_name (`str`):
- Which model to use for the embeddings.
+ The model name to load a pretrained model from transformers.
+ checkpoint_path (`str`):
+ The local checkpoint path.
+ tokenizer (`object`):
+ The tokenizer to tokenize input text as model inputs.
"""
- def __init__(self, model_name: str = None, device: str = None, pretrain_weights_path=None,
- load_pretrain_f=None, tokenizer=None) -> None:
+ def __init__(self,
+ model_name: str = None,
+ checkpoint_path: str = None,
+ tokenizer: object = None,
+ device: str = None,
+ ):
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -52,29 +60,31 @@ class AutoTransformers(NNOperator):
self.model_name = model_name
if self.model_name:
+ model_list = self.supported_model_names()
+ assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
+
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.configs = self.model.config
except Exception as e:
- model_list = self.supported_model_names()
- if model_name not in model_list:
- log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
- else:
- log.error(f"Fail to load model by name: {self.model_name}")
+ log.error(f"Fail to load model by name: {self.model_name}")
raise e
- if pretrain_weights_path is not None:
- if load_pretrain_f is None:
- state_dict = torch.load(pretrain_weights_path, map_location='cpu')
+ if checkpoint_path:
+ try:
+ state_dict = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(state_dict)
- else:
- self.model = load_pretrain_f(self.model, pretrain_weights_path)
+ except Exception as e:
+ log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
self.model.eval()
+
if tokenizer is None:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
- log.error(f'Fail to load tokenizer by name: {self.model_name}')
+ log.error(f'Fail to load default tokenizer by name: {self.model_name}')
raise e
+ else:
+ self.tokenizer = tokenizer
else:
log.warning('The operator is initialized without specified model.')
pass