
update blip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
Branch: main
jinlingxu06, 2 years ago
parent · commit 8fa9b81ca0
1 changed file: blip.py (161 changed lines)

blip.py
@@ -16,6 +16,7 @@ import sys
 import os
 from pathlib import Path
 from typing import Optional, Union, Tuple
+from types import MethodType
 import torch
 import logging
@@ -36,97 +37,72 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
-class BlipForImageTextRetrievalFinetune():
-    def __init__(self, blip_itm_model):
-        self.proxied = blip_itm_model
-        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
-    def __getattr__(self, attr):
-        if attr == 'proxied':
-            return self.proxied
-        elif attr != 'forward':
-            if getattr(self.proxied, attr) is True:
-                def wrapped_method(*args, **kwargs):
-                    result = getattr(self.proxied, attr)(*args, **kwargs)
-                    return result
-                return wrapped_method
-            else:
-                return getattr(self.proxied, attr)
-        else:
-            return self._forward
-    def forward(self, *args, **kwargs):
-        raise NotImplemented
-    def __call__(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
-    def _forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BlipOutput]:
+def _forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    return_loss: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BlipOutput]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        text_outputs = self.text_encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        image_embeds = vision_outputs[1]
-        image_embeds = self.vision_proj(image_embeds)
-        text_embeds = text_outputs[0]
-        text_embeds = self.text_proj(text_embeds[:,0,:])
-        # normalized features
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
-        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-        # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.t()
-        loss = None
-        if return_loss:
-            loss = blip_loss(logits_per_text)
-        if not return_dict:
-            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
-            return ((loss,) + output) if loss is not None else output
-        return BlipOutput(
-            loss=loss,
-            logits_per_image=logits_per_image,
-            logits_per_text=logits_per_text,
-            text_embeds=text_embeds,
-            image_embeds=image_embeds,
-            text_model_output=text_outputs,
-            vision_model_output=vision_outputs,
-        )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    vision_outputs = self.vision_model(
+        pixel_values=pixel_values,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+    text_outputs = self.text_encoder(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+    image_embeds = vision_outputs[1]
+    image_embeds = self.vision_proj(image_embeds)
+    text_embeds = text_outputs[0]
+    text_embeds = self.text_proj(text_embeds[:,0,:])
+    # normalized features
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+    # cosine similarity as logits
+    logit_scale = self.logit_scale.exp()
+    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+    logits_per_image = logits_per_text.t()
+    loss = None
+    if return_loss:
+        loss = blip_loss(logits_per_text)
+    if not return_dict:
+        output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+        return ((loss,) + output) if loss is not None else output
+    return BlipOutput(
+        loss=loss,
+        logits_per_image=logits_per_image,
+        logits_per_text=logits_per_text,
+        text_embeds=text_embeds,
+        image_embeds=image_embeds,
+        text_model_output=text_outputs,
+        vision_model_output=vision_outputs,
+    )
 def create_model(cfg, modality, checkpoint_path, device):
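Both the deleted wrapper class and the new module-level _forward implement the same CLIP-style contrastive head: projected and L2-normalized text/image embeddings, a learnable temperature logit_scale, and an optional blip_loss over the text-to-image similarity matrix. As a point of reference, here is a minimal sketch of what blip_loss is assumed to compute, mirroring the symmetric cross-entropy used by the BLIP/CLIP modeling code in transformers; the helpers below are illustrative and not part of this diff:

import torch
import torch.nn.functional as F

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against the diagonal: the i-th caption matches the i-th image.
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
    # Symmetric loss over the caption-to-image and image-to-caption directions.
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0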
@@ -258,10 +234,13 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_blip_with_hf_trainer import train_with_hf_trainer
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
+        model_finetune = self._model.backbone
+        model_finetune.forward = MethodType(_forward, model_finetune)
+        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
         train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
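The MethodType binding above is what replaces the deleted BlipForImageTextRetrievalFinetune proxy: instead of wrapping the backbone and re-routing attribute access through __getattr__, the contrastive _forward is bound directly onto the backbone instance, so only that instance's forward changes while the class and every other attribute stay untouched. A small self-contained illustration of the pattern, with toy names that are not from the repository:

from types import MethodType

class Backbone:
    def forward(self, x):
        return f'original forward: {x}'

def _forward(self, x):
    # 'self' is the instance the function was bound to via MethodType.
    return f'patched forward on {type(self).__name__}: {x}'

model = Backbone()
model.forward = MethodType(_forward, model)  # per-instance patch
print(model.forward('pixel batch'))          # patched forward on Backbone: pixel batch
print(Backbone().forward('pixel batch'))     # original forward: pixel batch (class unchanged)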
