logo
Browse Source

add weights loading for taiyi.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
42bf1528f0
  1. 12
      README.md
  2. 4
      __init__.py
  3. 24
      taiyi.py

12
README.md

@ -69,9 +69,19 @@ Create the operator via the following factory method
​ Which modality(*image* or *text*) is used to generate the embedding. ​ Which modality(*image* or *text*) is used to generate the embedding.
<br />
***clip_checkpoint_path:*** *str*
​ The weight path to load for the clip branch.
***text_checkpoint_path:*** *str*
​ The weight path to load for the text branch.
***devcice:*** *str*
​ The device in string, defaults to None. If None, it will enable "cuda" automatically when cuda is available.
<br />
## Interface ## Interface

4
__init__.py

@ -15,5 +15,5 @@
from .taiyi import Taiyi from .taiyi import Taiyi
def taiyi(model_name: str, modality: str):
return Taiyi(model_name, modality)
def taiyi(model_name: str, modality: str, clip_checkpoint_path: str=None, text_checkpoint_path: str=None, device=None):
return Taiyi(model_name, modality, clip_checkpoint_path, text_checkpoint_path, device)

24
taiyi.py

@ -25,19 +25,35 @@ class Taiyi(NNOperator):
""" """
Taiyi multi-modal embedding operator Taiyi multi-modal embedding operator
""" """
def __init__(self, model_name: str, modality: str):
def __init__(self, model_name: str, modality: str, clip_checkpoint_path: str=None, text_checkpoint_path: str=None, device: str=None):
self.modality = modality self.modality = modality
if device == None:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
config = self._configs()[model_name] config = self._configs()[model_name]
self.text_tokenizer = BertTokenizer.from_pretrained(config['tokenizer']) self.text_tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder']).eval()
self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder'])
if text_checkpoint_path:
try:
text_state_dict = torch.load(text_checkpoint_path, map_location=self.device)
self.text_encoder.load_state_dict(text_state_dict)
except Exception:
log.error(f'Fail to load weights from {text_checkpoint_path}')
self.clip_model = CLIPModel.from_pretrained(config['clip_model']) self.clip_model = CLIPModel.from_pretrained(config['clip_model'])
self.processor = CLIPProcessor.from_pretrained(config['processor']) self.processor = CLIPProcessor.from_pretrained(config['processor'])
self.text_encoder.to(self.device)
self.clip_model.to(self.device)
if clip_checkpoint_path:
try:
clip_state_dict = torch.load(clip_checkpoint_path, map_location=self.device)
self.clip_model.load_state_dict(clip_state_dict)
except Exception:
log.error(f'Fail to load weights from {clip_checkpoint_path}')
self.text_encoder.to(self.device).eval()
self.clip_model.to(self.device).eval()
def inference_single_data(self, data): def inference_single_data(self, data):
if self.modality == 'image': if self.modality == 'image':

Loading…
Cancel
Save