Browse Source
Support loading trained weights path.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
1 changed files with
9 additions and
8 deletions
-
clip.py
|
|
@ -38,13 +38,11 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
|
t_logging.set_verbosity_error() |
|
|
|
|
|
|
|
def create_model(model_name, modality, checkpoint_path, device): |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(model_name) |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
hf_clip_model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
if checkpoint_path is None: |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(model_name) |
|
|
|
else: |
|
|
|
hf_clip_config = CLIPModel.from_config(model_name) |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(checkpoint_path, config=hf_clip_config) |
|
|
|
hf_clip_model.to(device) |
|
|
|
hf_clip_model.eval() |
|
|
|
|
|
|
@ -101,7 +99,10 @@ class Clip(NNOperator): |
|
|
|
self.modality = modality |
|
|
|
self.device = device |
|
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
real_name = self._configs()[model_name] |
|
|
|
config = self._configs() |
|
|
|
real_name = model_name |
|
|
|
if model_name in config: |
|
|
|
real_name = config[model_name] |
|
|
|
|
|
|
|
self.model = Model(real_name, modality, checkpoint_path, device) |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(real_name) |
|
|
|