logo
Browse Source

fix the wrapper of the model.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
e458a883c0
  1. 29
      clip.py

29
clip.py

@ -54,25 +54,20 @@ class Clip(NNOperator):
self.model_name = model_name
self.modality = modality
self.device = device
self.checkpoint_path = checkpoint_path
cfg = self._configs()[model_name]
try:
clip_model = CLIPModel.from_pretrained(cfg)
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
if self.modality == 'image':
self.model = CLIPModelVision(clip_model)
self.model = CLIPModelVision(self._model)
elif self.modality == 'text':
self.model = CLIPModelText(clip_model)
self.model = CLIPModelText(self._model)
else:
raise ValueError("modality[{}] not implemented.".format(self.modality))
self.model.to(self.device)
self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
self.processor = CLIPProcessor.from_pretrained(cfg)
@ -153,7 +148,21 @@ class Clip(NNOperator):
@property
def _model(self):
return self.model
cfg = self._configs()[self.model_name]
try:
hf_clip_model = CLIPModel.from_pretrained(cfg)
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
hf_clip_model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
hf_clip_model.to(self.device)
hf_clip_model.eval()
return hf_clip_model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
import os

Loading…
Cancel
Save