From 42bf1528f0e2fbd027c0df745d671f3f25ef328e Mon Sep 17 00:00:00 2001
From: wxywb
Date: Mon, 27 Mar 2023 12:49:42 +0000
Subject: [PATCH] add weights loading for taiyi.

Signed-off-by: wxywb
---
 README.md   | 14 +++++++++++++-
 __init__.py |  4 ++--
 taiyi.py    | 28 ++++++++++++++++++++++------
 3 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index cd6c9de..2e93760 100644
--- a/README.md
+++ b/README.md
@@ -69,9 +69,21 @@ Create the operator via the following factory method
 
 
 ​ Which modality(*image* or *text*) is used to generate the embedding.
 
-<br />
+​ ***clip_checkpoint_path:*** *str*
+
+​ The path of the checkpoint weights to load for the CLIP branch.
+
+​ ***text_checkpoint_path:*** *str*
+
+​ The path of the checkpoint weights to load for the text branch.
+
+​ ***device:*** *str*
+
+​ The device as a string; defaults to None. If None, "cuda" is used automatically when CUDA is available, otherwise "cpu".
+
+<br />
 
 ## Interface
 
 
diff --git a/__init__.py b/__init__.py
index d1acbac..0b0484f 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,5 +15,5 @@
 from .taiyi import Taiyi
 
 
-def taiyi(model_name: str, modality: str):
-    return Taiyi(model_name, modality)
+def taiyi(model_name: str, modality: str, clip_checkpoint_path: str = None, text_checkpoint_path: str = None, device: str = None):
+    return Taiyi(model_name, modality, clip_checkpoint_path, text_checkpoint_path, device)

diff --git a/taiyi.py b/taiyi.py
index 2f901ab..adcbceb 100644
--- a/taiyi.py
+++ b/taiyi.py
@@ -25,19 +25,35 @@ class Taiyi(NNOperator):
     """
     Taiyi multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, clip_checkpoint_path: str = None, text_checkpoint_path: str = None, device: str = None):
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         config = self._configs()[model_name]
 
         self.text_tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
-        self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder']).eval()
-
+        self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder'])
+        if text_checkpoint_path:
+            try:
+                text_state_dict = torch.load(text_checkpoint_path, map_location=self.device)
+                self.text_encoder.load_state_dict(text_state_dict)
+            except Exception:
+                log.error(f'Failed to load weights from {text_checkpoint_path}')
+
         self.clip_model = CLIPModel.from_pretrained(config['clip_model'])
         self.processor = CLIPProcessor.from_pretrained(config['processor'])
 
-        self.text_encoder.to(self.device)
-        self.clip_model.to(self.device)
+        if clip_checkpoint_path:
+            try:
+                clip_state_dict = torch.load(clip_checkpoint_path, map_location=self.device)
+                self.clip_model.load_state_dict(clip_state_dict)
+            except Exception:
+                log.error(f'Failed to load weights from {clip_checkpoint_path}')
+
+        self.text_encoder.to(self.device).eval()
+        self.clip_model.to(self.device).eval()
 
     def inference_single_data(self, data):
         if self.modality == 'image':
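
Usage note (not part of the patch): below is a minimal sketch of how the new parameters are meant to be used through Towhee's `ops` interface, matching the README changes above. The model name is assumed to be one of the operator's existing config keys, and the checkpoint paths are hypothetical placeholders; they are expected to point at `state_dict` files compatible with the corresponding branch, since the operator loads them via `torch.load(...)` followed by `load_state_dict(...)`.

```python
# A reviewer's sketch, assuming the operator is published on the Towhee hub
# as image-text-embedding/taiyi. The checkpoint paths below are hypothetical.
from towhee import ops

# Default behavior: pretrained weights, device picked automatically
# ("cuda" when available, otherwise "cpu").
text_op = ops.image_text_embedding.taiyi(
    model_name='taiyi-clip-roberta-102m-chinese',
    modality='text')

# With this patch: load fine-tuned state_dicts for both branches
# and pin the operator to a specific device.
image_op = ops.image_text_embedding.taiyi(
    model_name='taiyi-clip-roberta-102m-chinese',
    modality='image',
    clip_checkpoint_path='./checkpoints/clip_finetuned.pt',  # hypothetical path
    text_checkpoint_path='./checkpoints/text_finetuned.pt',  # hypothetical path
    device='cpu')
```

One behavioral detail worth noting: if a checkpoint fails to load, the `__init__` swallows the exception, logs an error, and falls back to the pretrained weights rather than raising, so a bad path will not stop construction. Keep that in mind when validating fine-tuned checkpoints.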