|
@@ -25,19 +25,35 @@ class Taiyi(NNOperator):
     """
     Taiyi multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, clip_checkpoint_path: str=None, text_checkpoint_path: str=None, device: str=None):
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device == None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         config = self._configs()[model_name]
 
         self.text_tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
-        self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder']).eval()
+        self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder'])
+        if text_checkpoint_path:
+            try:
+                text_state_dict = torch.load(text_checkpoint_path, map_location=self.device)
+                self.text_encoder.load_state_dict(text_state_dict)
+            except Exception:
+                log.error(f'Fail to load weights from {text_checkpoint_path}')
 
         self.clip_model = CLIPModel.from_pretrained(config['clip_model'])
         self.processor = CLIPProcessor.from_pretrained(config['processor'])
 
-        self.text_encoder.to(self.device)
-        self.clip_model.to(self.device)
+        if clip_checkpoint_path:
+            try:
+                clip_state_dict = torch.load(clip_checkpoint_path, map_location=self.device)
+                self.clip_model.load_state_dict(clip_state_dict)
+            except Exception:
+                log.error(f'Fail to load weights from {clip_checkpoint_path}')
+
+        self.text_encoder.to(self.device).eval()
+        self.clip_model.to(self.device).eval()
 
     def inference_single_data(self, data):
         if self.modality == 'image':
|
|
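Usage sketch for the new constructor arguments (illustration only, not part of the diff): the model name below is a stand-in for one of the keys actually returned by _configs(), and the checkpoint path is a hypothetical local file. A compatible checkpoint is a plain state_dict saved with torch.save(model.state_dict(), path); because the diff passes map_location=self.device to torch.load, a checkpoint trained on GPU can still be loaded on a CPU-only machine.

# Illustrative only: 'taiyi-clip' stands in for a real _configs() key,
# and the checkpoint path is a hypothetical placeholder.
op = Taiyi(
    model_name='taiyi-clip',
    modality='text',
    text_checkpoint_path='./taiyi_text_finetuned.pt',
    device='cpu',  # explicit device overrides the CUDA auto-detect
)
vec = op.inference_single_data('a caption to embed')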