diff --git a/clip.py b/clip.py
index c64f453..aa8cea1 100644
--- a/clip.py
+++ b/clip.py
@@ -38,13 +38,11 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
 
 def create_model(model_name, modality, checkpoint_path, device):
-    hf_clip_model = CLIPModel.from_pretrained(model_name)
-    if checkpoint_path:
-        try:
-            state_dict = torch.load(checkpoint_path, map_location=device)
-            hf_clip_model.load_state_dict(state_dict)
-        except Exception as e:
-            log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
+    if checkpoint_path is None:
+        hf_clip_model = CLIPModel.from_pretrained(model_name)
+    else:
+        hf_clip_config = CLIPConfig.from_pretrained(model_name)
+        hf_clip_model = CLIPModel.from_pretrained(checkpoint_path, config=hf_clip_config)
     hf_clip_model.to(device)
     hf_clip_model.eval()
 
@@ -101,7 +99,10 @@ class Clip(NNOperator):
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        real_name = self._configs()[model_name]
+        config = self._configs()
+        real_name = model_name
+        if model_name in config:
+            real_name = config[model_name]
         self.model = Model(real_name, modality, checkpoint_path, device)
         self.tokenizer = CLIPTokenizer.from_pretrained(real_name)
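
For context, a minimal standalone sketch of the config-then-weights loading pattern the patched create_model uses. It assumes CLIPConfig is imported alongside CLIPModel in clip.py and that checkpoint_path points to a directory saved via save_pretrained(); the model name and checkpoint path below are placeholders, not values from the patch.

# Standalone sketch (not part of the patch); names are illustrative.
from transformers import CLIPConfig, CLIPModel

base_model = 'openai/clip-vit-base-patch32'       # placeholder Hub model name
checkpoint_path = './checkpoints/clip-finetuned'  # placeholder checkpoint dir

# Build the architecture config from the base model name, then load the
# fine-tuned weights from the checkpoint into that architecture.
config = CLIPConfig.from_pretrained(base_model)
model = CLIPModel.from_pretrained(checkpoint_path, config=config)
model.eval()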