|
|
@ -54,25 +54,20 @@ class Clip(NNOperator): |
|
|
|
self.model_name = model_name |
|
|
|
self.modality = modality |
|
|
|
self.device = device |
|
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
cfg = self._configs()[model_name] |
|
|
|
try: |
|
|
|
clip_model = CLIPModel.from_pretrained(cfg) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=self.device) |
|
|
|
self.model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
|
|
|
|
if self.modality == 'image': |
|
|
|
self.model = CLIPModelVision(clip_model) |
|
|
|
self.model = CLIPModelVision(self._model) |
|
|
|
elif self.modality == 'text': |
|
|
|
self.model = CLIPModelText(clip_model) |
|
|
|
self.model = CLIPModelText(self._model) |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
|
self.model.to(self.device) |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(cfg) |
|
|
|
self.processor = CLIPProcessor.from_pretrained(cfg) |
|
|
|
|
|
|
@ -153,7 +148,21 @@ class Clip(NNOperator): |
|
|
|
|
|
|
|
@property |
|
|
|
def _model(self): |
|
|
|
return self.model |
|
|
|
cfg = self._configs()[self.model_name] |
|
|
|
try: |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(cfg) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
if self.checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(self.checkpoint_path, map_location=self.device) |
|
|
|
hf_clip_model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
hf_clip_model.to(self.device) |
|
|
|
hf_clip_model.eval() |
|
|
|
return hf_clip_model |
|
|
|
|
|
|
|
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |
|
|
|
import os |
|
|
|