diff --git a/README.md b/README.md
index d345587..54b2ac4 100644
--- a/README.md
+++ b/README.md
@@ -50,13 +50,15 @@
 The model name in string, defaults to None.
 If None, the operator will be initialized without specified model.

 Supported model names: NLP transformers models listed in [Huggingface Models](https://huggingface.co/models).
-Please note that only models listed in `supported_model_names` are tested.
-You can refer to [Towhee Pipeline]() for model performance.
+Please note that only models listed in `supported_model_names` are tested by us.
+You can refer to [Towhee Pipeline](https://towhee.io/tasks/detail/pipeline/sentence-similarity) for model performance.

 ***checkpoint_path***: *str*

 The path to local checkpoint, defaults to None.
-If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers.
+- If None, the operator will download and load the pretrained model specified by `model_name` from Huggingface transformers.
+- The checkpoint path can be a path to a directory containing model weights saved using [`save_pretrained()` by HuggingFace Transformers](https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained).
+- Or you can pass a path to a PyTorch `state_dict` save file.

 ***tokenizer***: *object*
diff --git a/auto_transformers.py b/auto_transformers.py
index 236a841..580e222 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -44,15 +44,27 @@ t_logging.set_verbosity_error()


 def create_model(model_name, checkpoint_path, device):
-    model = AutoModel.from_pretrained(model_name).to(device)
+    _torch_weights = False
+    if checkpoint_path:
+        if os.path.isdir(checkpoint_path) and \
+                os.path.exists(os.path.join(checkpoint_path, 'config.json')):
+            model = AutoModel.from_pretrained(checkpoint_path)
+        else:
+            model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
+            _torch_weights = True
+    else:
+        model = AutoModel.from_pretrained(model_name)
+
+    model = model.to(device)
     if hasattr(model, 'pooler') and model.pooler:
         model.pooler = None
-    if checkpoint_path:
+
+    if _torch_weights:
         try:
             state_dict = torch.load(checkpoint_path, map_location=device)
             model.load_state_dict(state_dict)
         except Exception:
-            log.error(f'Fail to load weights from {checkpoint_path}')
+            log.error(f'Failed to load weights from {checkpoint_path}')
     model.eval()
     return model

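For reference, a minimal sketch of the two `checkpoint_path` forms the updated `create_model()` accepts. The model name and file paths below are hypothetical examples, not part of this change.

```python
import torch
from transformers import AutoModel

# Any supported Huggingface model (example name only).
model = AutoModel.from_pretrained('bert-base-uncased')

# Form 1: a directory written by save_pretrained(); it contains config.json,
# so create_model() loads it directly with AutoModel.from_pretrained(path).
model.save_pretrained('./my_checkpoint_dir')

# Form 2: a plain PyTorch state_dict file; create_model() rebuilds the model
# from the config of `model_name` and loads these weights with load_state_dict().
torch.save(model.state_dict(), './my_checkpoint.pt')
```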