
Support more flexible method to load local checkpoint

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Jael Gu, 1 year ago
parent commit eac306324b
  1. README.md (8 changes)
  2. auto_transformers.py (18 changes)

README.md
```diff
@@ -50,13 +50,15 @@ The model name in string, defaults to None.
 If None, the operator will be initialized without specified model.
 Supported model names: NLP transformers models listed in [Huggingface Models](https://huggingface.co/models).
-Please note that only models listed in `supported_model_names` are tested.
-You can refer to [Towhee Pipeline]() for model performance.
+Please note that only models listed in `supported_model_names` are tested by us.
+You can refer to [Towhee Pipeline](https://towhee.io/tasks/detail/pipeline/sentence-similarity) for model performance.
 
 ***checkpoint_path***: *str*
 The path to local checkpoint, defaults to None.
-If None, the operator will download and load the pretrained model by `model_name` from Huggingface transformers.
+- If None, the operator will download and load the pretrained model by `model_name` from Huggingface transformers.
+- The checkpoint path can be a path to a directory containing model weights saved using [`save_pretrained()` by HuggingFace Transformers](https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained).
+- Or you can pass a path to a PyTorch `state_dict` save file.
 
 ***tokenizer***: *object*
```
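How the operator tells the two checkpoint forms apart can be sketched with a pure-stdlib stand-in; `detect_checkpoint_kind` below is a hypothetical helper for illustration only, not part of the operator:

```python
import os
import tempfile

def detect_checkpoint_kind(checkpoint_path):
    """Mirror the operator's branching: a directory containing config.json
    is treated as a save_pretrained() checkpoint; anything else is assumed
    to be a bare PyTorch state_dict file."""
    if os.path.isdir(checkpoint_path) and \
            os.path.exists(os.path.join(checkpoint_path, 'config.json')):
        return 'pretrained_dir'
    return 'state_dict'

with tempfile.TemporaryDirectory() as d:
    # A directory with config.json looks like a save_pretrained() checkpoint.
    open(os.path.join(d, 'config.json'), 'w').close()
    print(detect_checkpoint_kind(d))                        # pretrained_dir
    # A plain file path is treated as a state_dict save file.
    print(detect_checkpoint_kind(os.path.join(d, 'w.pt')))  # state_dict
```

In the `state_dict` case the operator still needs `model_name` to rebuild the architecture before the weights can be loaded into it.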

auto_transformers.py
```diff
@@ -44,15 +44,27 @@ t_logging.set_verbosity_error()
 def create_model(model_name, checkpoint_path, device):
-    model = AutoModel.from_pretrained(model_name).to(device)
+    _torch_weights = False
+    if checkpoint_path:
+        if os.path.isdir(checkpoint_path) and \
+                os.path.exists(os.path.join(checkpoint_path, 'config.json')):
+            # A save_pretrained() directory: load config and weights from it.
+            model = AutoModel.from_pretrained(checkpoint_path)
+        else:
+            # A bare state_dict file: build the model from its config first;
+            # the weights are loaded below.
+            model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
+            _torch_weights = True
+    else:
+        model = AutoModel.from_pretrained(model_name)
+    model = model.to(device)
     if hasattr(model, 'pooler') and model.pooler:
         model.pooler = None
-    if checkpoint_path:
+    if _torch_weights:
         try:
             state_dict = torch.load(checkpoint_path, map_location=device)
             model.load_state_dict(state_dict)
         except Exception:
-            log.error(f'Fail to load weights from {checkpoint_path}')
+            log.error(f'Failed to load weights from {checkpoint_path}')
     model.eval()
     return model
```
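The state_dict branch follows a try/load/log-on-failure shape: a key mismatch is logged rather than raised. A stdlib-only sketch of that pattern (the `Model` class and `load_weights` helper here are hypothetical stand-ins, not the transformers model or the operator's API):

```python
import logging

log = logging.getLogger(__name__)

class Model:
    """Hypothetical stand-in with a state_dict-style interface."""
    def __init__(self):
        self.weights = {'w': 0.0}

    def load_state_dict(self, state_dict):
        # Mimic the strict key checking torch modules perform by default.
        if set(state_dict) != set(self.weights):
            raise RuntimeError('unexpected keys in state_dict')
        self.weights.update(state_dict)

def load_weights(model, state_dict, path='ckpt.pt'):
    try:
        model.load_state_dict(state_dict)
        return True
    except Exception:
        # As in create_model: report the failure but keep the model usable.
        log.error(f'Failed to load weights from {path}')
        return False

m = Model()
print(load_weights(m, {'w': 1.5}))   # True  -- keys match, weights applied
print(load_weights(m, {'bad': 0}))   # False -- mismatch is logged, not raised
```

Note that on failure the model silently keeps its randomly initialized (here, default) weights; callers who need a hard failure should check the log output.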
