logo
Browse Source

Support loading trained weights path.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 1 year ago
parent
commit
a704105378
  1. 15
      clip.py

15
clip.py

@ -38,13 +38,11 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error() t_logging.set_verbosity_error()
def create_model(model_name, modality, checkpoint_path, device): def create_model(model_name, modality, checkpoint_path, device):
if checkpoint_path is None:
hf_clip_model = CLIPModel.from_pretrained(model_name) hf_clip_model = CLIPModel.from_pretrained(model_name)
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
hf_clip_model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
else:
hf_clip_config = CLIPModel.from_config(model_name)
hf_clip_model = CLIPModel.from_pretrained(checkpoint_path, config=hf_clip_config)
hf_clip_model.to(device) hf_clip_model.to(device)
hf_clip_model.eval() hf_clip_model.eval()
@ -101,7 +99,10 @@ class Clip(NNOperator):
self.modality = modality self.modality = modality
self.device = device self.device = device
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
real_name = self._configs()[model_name]
config = self._configs()
real_name = model_name
if model_name in config:
real_name = config[model_name]
self.model = Model(real_name, modality, checkpoint_path, device) self.model = Model(real_name, modality, checkpoint_path, device)
self.tokenizer = CLIPTokenizer.from_pretrained(real_name) self.tokenizer = CLIPTokenizer.from_pretrained(real_name)

Loading…
Cancel
Save