diff --git a/clip.py b/clip.py index aa8cea1..8bedb00 100644 --- a/clip.py +++ b/clip.py @@ -41,8 +41,7 @@ def create_model(model_name, modality, checkpoint_path, device): if checkpoint_path is None: hf_clip_model = CLIPModel.from_pretrained(model_name) else: - hf_clip_config = CLIPModel.from_config(model_name) - hf_clip_model = CLIPModel.from_pretrained(checkpoint_path, config=hf_clip_config) + hf_clip_model = CLIPModel.from_pretrained(checkpoint_path) hf_clip_model.to(device) hf_clip_model.eval()