@@ -16,6 +16,7 @@ import sys
import os
from pathlib import Path
from typing import Optional, Union, Tuple
from types import MethodType

import torch
import logging
@@ -36,97 +37,72 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()


class BlipForImageTextRetrievalFinetune():
    def __init__(self, blip_itm_model):
        self.proxied = blip_itm_model
        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)

    def __getattr__(self, attr):
        if attr == 'proxied':
            return self.proxied
        elif attr != 'forward':
            if getattr(self.proxied, attr) is True:
                def wrapped_method(*args, **kwargs):
                    result = getattr(self.proxied, attr)(*args, **kwargs)
                    return result
                return wrapped_method
            else:
                return getattr(self.proxied, attr)
        else:
            return self._forward

    def forward(self, *args, **kwargs):
        raise NotImplemented

    def __call__(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

def _forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    return_loss: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BlipOutput]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    text_outputs = self.text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    image_embeds = vision_outputs[1]
    image_embeds = self.vision_proj(image_embeds)

    text_embeds = text_outputs[0]
    text_embeds = self.text_proj(text_embeds[:,0,:])

    # normalized features
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity as logits
    logit_scale = self.logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
    logits_per_image = logits_per_text.t()

    loss = None
    if return_loss:
        loss = blip_loss(logits_per_text)

    if not return_dict:
        output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
        return ((loss,) + output) if loss is not None else output

    return BlipOutput(
        loss=loss,
        logits_per_image=logits_per_image,
        logits_per_text=logits_per_text,
        text_embeds=text_embeds,
        image_embeds=image_embeds,
        text_model_output=text_outputs,
        vision_model_output=vision_outputs,
    )

def create_model(cfg, modality, checkpoint_path, device):
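
Note: `blip_loss` is referenced by the `_forward` above but is not part of this hunk; it comes from `transformers.models.blip.modeling_blip`, where it is a CLIP-style symmetric contrastive loss (cross-entropy against the diagonal of the similarity matrix in both directions). A minimal sketch of that behaviour, assuming the Hugging Face implementation (`blip_loss_sketch` is an illustrative name, not the library symbol):

    import torch
    import torch.nn.functional as F

    def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
        # The i-th caption matches the i-th image, so the target is the diagonal.
        return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

    def blip_loss_sketch(similarity: torch.Tensor) -> torch.Tensor:
        # Average of the text-to-image and image-to-text contrastive losses.
        caption_loss = contrastive_loss(similarity)
        image_loss = contrastive_loss(similarity.t())
        return (caption_loss + image_loss) / 2.0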
@@ -258,10 +234,13 @@ class Blip(NNOperator):
        import pathlib
        path = str(pathlib.Path(__file__).parent)
        sys.path.append(path)
        from train_blip_with_hf_trainer import train_with_hf_trainer
        from train_clip_with_hf_trainer import train_with_hf_trainer
        data_args = kwargs.pop('data_args', None)
        training_args = kwargs.pop('training_args', None)
        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
        model_finetune = self._model.backbone
        model_finetune.forward = MethodType(_forward, model_finetune)
        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)

        train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
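
Note: the `MethodType(_forward, model_finetune)` call is what turns the module-level `_forward` into a bound method of the backbone, so that `self` inside `_forward` refers to the BLIP model instance. A minimal, self-contained sketch of the same binding pattern (the `Backbone` class and `_scaled` function below are hypothetical, for illustration only):

    from types import MethodType

    class Backbone:
        def __init__(self, scale):
            self.scale = scale

    def _scaled(self, x):
        # Plain function written like a method; `self` is supplied by MethodType below.
        return self.scale * x

    model = Backbone(scale=2.0)
    model.forward = MethodType(_scaled, model)  # binds to this instance only
    print(model.forward(3.0))                   # prints 6.0

Binding on the instance rather than wrapping it in a proxy class keeps the backbone's parameters, config, and state_dict untouched, which is presumably why the wrapper class above could be dropped.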