From e458a883c0e608aae144e27774025098db9bc5b9 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 3 Feb 2023 09:33:09 +0000 Subject: [PATCH] fix the wrapper of the model. Signed-off-by: wxywb --- clip.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/clip.py b/clip.py index 6e78f36..14e23ea 100644 --- a/clip.py +++ b/clip.py @@ -54,25 +54,20 @@ class Clip(NNOperator): self.model_name = model_name self.modality = modality self.device = device + self.checkpoint_path = checkpoint_path cfg = self._configs()[model_name] try: clip_model = CLIPModel.from_pretrained(cfg) except Exception as e: log.error(f"Fail to load model by name: {self.model_name}") raise e - if checkpoint_path: - try: - state_dict = torch.load(checkpoint_path, map_location=self.device) - self.model.load_state_dict(state_dict) - except Exception as e: - log.error(f"Fail to load state dict from {checkpoint_path}: {e}") + if self.modality == 'image': - self.model = CLIPModelVision(clip_model) + self.model = CLIPModelVision(self._model) elif self.modality == 'text': - self.model = CLIPModelText(clip_model) + self.model = CLIPModelText(self._model) else: raise ValueError("modality[{}] not implemented.".format(self.modality)) - self.model.to(self.device) self.tokenizer = CLIPTokenizer.from_pretrained(cfg) self.processor = CLIPProcessor.from_pretrained(cfg) @@ -153,7 +148,21 @@ class Clip(NNOperator): @property def _model(self): - return self.model + cfg = self._configs()[self.model_name] + try: + hf_clip_model = CLIPModel.from_pretrained(cfg) + except Exception as e: + log.error(f"Fail to load model by name: {self.model_name}") + raise e + if self.checkpoint_path: + try: + state_dict = torch.load(self.checkpoint_path, map_location=self.device) + hf_clip_model.load_state_dict(state_dict) + except Exception as e: + log.error(f"Fail to load state dict from {checkpoint_path}: {e}") + hf_clip_model.to(self.device) + hf_clip_model.eval() + return hf_clip_model def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): import os