
update blip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
commit 8fa9b81ca0
1 changed file: blip.py (161 lines changed)

@@ -16,6 +16,7 @@ import sys
 import os
 from pathlib import Path
 from typing import Optional, Union, Tuple
+from types import MethodType
 import torch
 import logging
@@ -36,97 +37,72 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
-class BlipForImageTextRetrievalFinetune():
-    def __init__(self, blip_itm_model):
-        self.proxied = blip_itm_model
-        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
-
-    def __getattr__(self, attr):
-        if attr == 'proxied':
-            return self.proxied
-        elif attr != 'forward':
-            if callable(getattr(self.proxied, attr)):
-                def wrapped_method(*args, **kwargs):
-                    result = getattr(self.proxied, attr)(*args, **kwargs)
-                    return result
-                return wrapped_method
-            else:
-                return getattr(self.proxied, attr)
-        else:
-            return self._forward
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def __call__(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
-
-    def _forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BlipOutput]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        text_outputs = self.text_encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        image_embeds = vision_outputs[1]
-        image_embeds = self.vision_proj(image_embeds)
-
-        text_embeds = text_outputs[0]
-        text_embeds = self.text_proj(text_embeds[:, 0, :])
-
-        # normalized features
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
-        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-
-        # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.t()
-
-        loss = None
-        if return_loss:
-            loss = blip_loss(logits_per_text)
-
-        if not return_dict:
-            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
-            return ((loss,) + output) if loss is not None else output
-
-        return BlipOutput(
-            loss=loss,
-            logits_per_image=logits_per_image,
-            logits_per_text=logits_per_text,
-            text_embeds=text_embeds,
-            image_embeds=image_embeds,
-            text_model_output=text_outputs,
-            vision_model_output=vision_outputs,
-        )
-
+def _forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    return_loss: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BlipOutput]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    vision_outputs = self.vision_model(
+        pixel_values=pixel_values,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    text_outputs = self.text_encoder(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    image_embeds = vision_outputs[1]
+    image_embeds = self.vision_proj(image_embeds)
+
+    text_embeds = text_outputs[0]
+    text_embeds = self.text_proj(text_embeds[:, 0, :])
+
+    # normalized features
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+    # cosine similarity as logits
+    logit_scale = self.logit_scale.exp()
+    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+    logits_per_image = logits_per_text.t()
+
+    loss = None
+    if return_loss:
+        loss = blip_loss(logits_per_text)
+
+    if not return_dict:
+        output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+        return ((loss,) + output) if loss is not None else output
+
+    return BlipOutput(
+        loss=loss,
+        logits_per_image=logits_per_image,
+        logits_per_text=logits_per_text,
+        text_embeds=text_embeds,
+        image_embeds=image_embeds,
+        text_model_output=text_outputs,
+        vision_model_output=vision_outputs,
+    )
 def create_model(cfg, modality, checkpoint_path, device):
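
Note: blip_loss is called in the hunk above but not defined anywhere in this diff. In Hugging Face's BLIP modeling code it is the standard CLIP-style symmetric contrastive loss, so a reasonable reading of the call is the sketch below; treat it as an assumption about the imported helper, not code from this commit.

import torch
import torch.nn.functional as F

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # cross-entropy against the diagonal: the i-th text matches the i-th image
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
    # average the text-to-image and image-to-text directions
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0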
@@ -258,10 +234,13 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_blip_with_hf_trainer import train_with_hf_trainer
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
+        model_finetune = self._model.backbone
+        model_finetune.forward = MethodType(_forward, model_finetune)
+        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
         train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
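
Note: the net effect of this commit is to drop the proxy class and instead bind the module-level _forward directly onto the backbone instance with types.MethodType, so the trainer sees a model whose forward returns CLIP-style contrastive outputs. A minimal, self-contained sketch of that binding pattern follows; DummyModel and the 2.6592 value are illustrative assumptions (2.6592 is the CLIP-style default for logit_scale_init_value), not code from this repo.

from types import MethodType

import torch

class DummyModel(torch.nn.Module):
    # stand-in for the BLIP backbone; only the patching pattern matters here
    def forward(self, x):
        return x

def _forward(self, x):
    # a module-level function: once bound, 'self' is the model instance
    return x * self.logit_scale.exp()

model = DummyModel()
model.logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6592)
model.forward = MethodType(_forward, model)   # bind per instance, as the diff does
print(model(torch.ones(2)))                   # __call__ dispatches to the patched forward

Binding per instance (rather than assigning to the class) keeps the patch local to the object handed to train_with_hf_trainer, which is why the wrapper class is no longer needed.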
