diff --git a/blip.py b/blip.py
index fb503ce..8005669 100644
--- a/blip.py
+++ b/blip.py
@@ -16,6 +16,7 @@ import sys
 import os
 from pathlib import Path
 from typing import Optional, Union, Tuple
+from types import MethodType
 import torch
 import logging
@@ -36,97 +37,72 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
 
 
-class BlipForImageTextRetrievalFinetune():
-    def __init__(self, blip_itm_model):
-        self.proxied = blip_itm_model
-        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
-
-    def __getattr__(self, attr):
-        if attr == 'proxied':
-            return self.proxied
-        elif attr != 'forward':
-            if getattr(self.proxied, attr) is True:
-                def wrapped_method(*args, **kwargs):
-                    result = getattr(self.proxied, attr)(*args, **kwargs)
-                    return result
-                return wrapped_method
-            else:
-                return getattr(self.proxied, attr)
-        else:
-            return self._forward
-
-    def forward(self, *args, **kwargs):
-        raise NotImplemented
-
-    def __call__(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
-
-    def _forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+def _forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    return_loss: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BlipOutput]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        text_outputs = self.text_encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        image_embeds = vision_outputs[1]
-        image_embeds = self.vision_proj(image_embeds)
-
-        text_embeds = text_outputs[0]
-        text_embeds = self.text_proj(text_embeds[:,0,:])
-
-        # normalized features
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
-        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-
-        # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.t()
-
-        loss = None
-        if return_loss:
-            loss = blip_loss(logits_per_text)
-
-        if not return_dict:
-            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
-            return ((loss,) + output) if loss is not None else output
-
-        return BlipOutput(
-            loss=loss,
-            logits_per_image=logits_per_image,
-            logits_per_text=logits_per_text,
-            text_embeds=text_embeds,
-            image_embeds=image_embeds,
-            text_model_output=text_outputs,
-            vision_model_output=vision_outputs,
-        )
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    vision_outputs = self.vision_model(
+        pixel_values=pixel_values,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    text_outputs = self.text_encoder(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    image_embeds = vision_outputs[1]
+    image_embeds = self.vision_proj(image_embeds)
+
+    text_embeds = text_outputs[0]
+    text_embeds = self.text_proj(text_embeds[:,0,:])
+
+    # normalized features
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+    # cosine similarity as logits
+    logit_scale = self.logit_scale.exp()
+    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+    logits_per_image = logits_per_text.t()
+
+    loss = None
+    if return_loss:
+        loss = blip_loss(logits_per_text)
+
+    if not return_dict:
+        output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+        return ((loss,) + output) if loss is not None else output
+
+    return BlipOutput(
+        loss=loss,
+        logits_per_image=logits_per_image,
+        logits_per_text=logits_per_text,
+        text_embeds=text_embeds,
+        image_embeds=image_embeds,
+        text_model_output=text_outputs,
+        vision_model_output=vision_outputs,
+    )
 
 
 def create_model(cfg, modality, checkpoint_path, device):
@@ -258,10 +234,13 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_blip_with_hf_trainer import train_with_hf_trainer
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
+        model_finetune = self._model.backbone
+        model_finetune.forward = MethodType(_forward, model_finetune)
+        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
+
         train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
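Note on the approach: instead of the removed BlipForImageTextRetrievalFinetune proxy, the patch binds the module-level _forward onto the backbone with types.MethodType and registers a fresh logit_scale parameter, so train_with_hf_trainer receives an ordinary nn.Module whose forward computes the CLIP-style contrastive logits. A minimal, self-contained sketch of that binding pattern follows; the Backbone class and the 2.6592 init value are illustrative stand-ins, not part of the operator:

    from types import MethodType

    import torch


    def _contrastive_forward(self, x):
        # `self` is the instance the function gets bound to below,
        # mirroring how the patch attaches `_forward` to the BLIP backbone.
        return self.proj(x) * self.logit_scale.exp()


    class Backbone(torch.nn.Module):  # stand-in for self._model.backbone
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)


    model = Backbone()
    # Same pattern as the patch: register a scalar logit_scale parameter ...
    model.logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6592)
    # ... and rebind forward on the instance; nn.Module.__call__ looks up
    # self.forward, so model(x) now dispatches to _contrastive_forward.
    model.forward = MethodType(_contrastive_forward, model)

    print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])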