logo
Browse Source

Update parameters

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
a00786d445
  1. 20
      README.md
  2. 38
      auto_transformers.py

20
README.md

@ -48,14 +48,14 @@ import towhee
Create the operator via the following factory method: Create the operator via the following factory method:
***text_embedding.transformers(model_name="bert-base-uncased")***
***text_embedding.transformers(model_name=None)***
**Parameters:** **Parameters:**
***model_name***: *str* ***model_name***: *str*
The model name in string.
The default model name is "bert-base-uncased".
The model name in string, defaults to None.
If None, the operator will be initialized without specified model.
Supported model names: Supported model names:
@ -310,6 +310,20 @@ Supported model names:
<br /> <br />
***checkpoint_path***: *str*
The path to local checkpoint, defaults to None.
If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers.
<br />
***tokenizer***: *object*
The method to tokenize input text, defaults to None.
If None, the operator will use default tokenizer by `model_name` from Huggingface transformers.
<br />
## Interface ## Interface
The operator takes a piece of text in string as input. The operator takes a piece of text in string as input.

38
auto_transformers.py

@ -40,11 +40,19 @@ class AutoTransformers(NNOperator):
NLP embedding operator that uses the pretrained transformers model gathered by huggingface. NLP embedding operator that uses the pretrained transformers model gathered by huggingface.
Args: Args:
model_name (`str`): model_name (`str`):
Which model to use for the embeddings.
The model name to load a pretrained model from transformers.
checkpoint_path (`str`):
The local checkpoint path.
tokenizer (`object`):
The tokenizer to tokenize input text as model inputs.
""" """
def __init__(self, model_name: str = None, device: str = None, pretrain_weights_path=None,
load_pretrain_f=None, tokenizer=None) -> None:
def __init__(self,
model_name: str = None,
checkpoint_path: str = None,
tokenizer: object = None,
device: str = None,
):
super().__init__() super().__init__()
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -52,29 +60,31 @@ class AutoTransformers(NNOperator):
self.model_name = model_name self.model_name = model_name
if self.model_name: if self.model_name:
model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
try: try:
self.model = AutoModel.from_pretrained(model_name).to(self.device) self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.configs = self.model.config self.configs = self.model.config
except Exception as e: except Exception as e:
model_list = self.supported_model_names()
if model_name not in model_list:
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
else:
log.error(f"Fail to load model by name: {self.model_name}")
log.error(f"Fail to load model by name: {self.model_name}")
raise e raise e
if pretrain_weights_path is not None:
if load_pretrain_f is None:
state_dict = torch.load(pretrain_weights_path, map_location='cpu')
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
else:
self.model = load_pretrain_f(self.model, pretrain_weights_path)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
self.model.eval() self.model.eval()
if tokenizer is None: if tokenizer is None:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e: except Exception as e:
log.error(f'Fail to load tokenizer by name: {self.model_name}')
log.error(f'Fail to load default tokenizer by name: {self.model_name}')
raise e raise e
else:
self.tokenizer = tokenizer
else: else:
log.warning('The operator is initialized without specified model.') log.warning('The operator is initialized without specified model.')
pass pass

Loading…
Cancel
Save