
Support more flexible method to load local checkpoint

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu, 1 year ago
commit eac306324b
Changed files:

  1. README.md (8 lines changed)
  2. auto_transformers.py (18 lines changed)

README.md (8 lines changed)

@@ -50,13 +50,15 @@ The model name in string, defaults to None.
 If None, the operator will be initialized without specified model.
 
 Supported model names: NLP transformers models listed in [Huggingface Models](https://huggingface.co/models).
-Please note that only models listed in `supported_model_names` are tested.
-You can refer to [Towhee Pipeline]() for model performance.
+Please note that only models listed in `supported_model_names` are tested by us.
+You can refer to [Towhee Pipeline](https://towhee.io/tasks/detail/pipeline/sentence-similarity) for model performance.
 
 ***checkpoint_path***: *str*
 
 The path to local checkpoint, defaults to None.
-If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers.
+- If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers.
+- The checkpoint path could be a path to a directory containing model weights saved using [`save_pretrained()` by HuggingFace Transformers](https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained).
+- Or you can pass a path to a PyTorch `state_dict` save file.
 
 ***tokenizer***: *object*
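
The two accepted `checkpoint_path` formats correspond to the two standard ways of saving a Transformers model. A minimal sketch of how each is produced, assuming HuggingFace Transformers and PyTorch; the model name and paths are illustrative:

```python
import torch
from transformers import AutoModel

# Illustrative model name and output paths; substitute your own.
model = AutoModel.from_pretrained('bert-base-uncased')

# Option 1: a directory written by save_pretrained(), containing config.json
# and the weight files. Pass this directory as checkpoint_path.
model.save_pretrained('./my_checkpoint_dir')

# Option 2: a plain PyTorch state_dict file. Pass this file path as
# checkpoint_path; the operator rebuilds the architecture from model_name
# and loads these weights into it.
torch.save(model.state_dict(), './my_checkpoint.pt')
```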

auto_transformers.py (18 lines changed)

@@ -44,15 +44,27 @@ t_logging.set_verbosity_error()
 def create_model(model_name, checkpoint_path, device):
-    model = AutoModel.from_pretrained(model_name).to(device)
+    _torch_weights = False
+    if checkpoint_path:
+        if os.path.isdir(checkpoint_path) and \
+                os.path.exists(os.path.join(checkpoint_path, 'config.json')):
+            model = AutoModel.from_pretrained(checkpoint_path)
+        else:
+            config = AutoConfig.from_pretrained(model_name)
+            model = AutoModel.from_config(config)
+            _torch_weights = True
+    else:
+        model = AutoModel.from_pretrained(model_name)
+    model = model.to(device)
     if hasattr(model, 'pooler') and model.pooler:
         model.pooler = None
-    if checkpoint_path:
+    if _torch_weights:
         try:
             state_dict = torch.load(checkpoint_path, map_location=device)
             model.load_state_dict(state_dict)
         except Exception:
-            log.error(f'Fail to load weights from {checkpoint_path}')
+            log.error(f'Failed to load weights from {checkpoint_path}')
     model.eval()
     return model
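
For reference, a usage sketch of the updated `create_model`, assuming it can be imported from this operator's `auto_transformers.py` (the import path, model name, and checkpoint paths below are placeholders):

```python
from auto_transformers import create_model  # assumed import path

# Directory saved via save_pretrained(): loaded with AutoModel.from_pretrained().
model = create_model('bert-base-uncased', './my_checkpoint_dir', 'cpu')

# Plain state_dict file: the architecture is built from model_name via
# AutoConfig/AutoModel and the weights are loaded with load_state_dict().
model = create_model('bert-base-uncased', './my_checkpoint.pt', 'cpu')

# No checkpoint_path: fall back to the pretrained weights from the Hub.
model = create_model('bert-base-uncased', None, 'cpu')
```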
