logo
Browse Source

update blip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
parent
commit
8fa9b81ca0
  1. 35
      blip.py

35
blip.py

@ -16,6 +16,7 @@ import sys
import os
from pathlib import Path
from typing import Optional, Union, Tuple
from types import MethodType
import torch
import logging
@ -36,32 +37,7 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
class BlipForImageTextRetrievalFinetune():
def __init__(self, blip_itm_model):
self.proxied = blip_itm_model
self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
def __getattr__(self, attr):
if attr == 'proxied':
return self.proxied
elif attr != 'forward':
if getattr(self.proxied, attr) is True:
def wrapped_method(*args, **kwargs):
result = getattr(self.proxied, attr)(*args, **kwargs)
return result
return wrapped_method
else:
return getattr(self.proxied, attr)
else:
return self._forward
def forward(self, *args, **kwargs):
raise NotImplemented
def __call__(self, *args, **kwargs):
return self._forward(*args, **kwargs)
def _forward(
def _forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
@ -258,10 +234,13 @@ class Blip(NNOperator):
import pathlib
path = str(pathlib.Path(__file__).parent)
sys.path.append(path)
from train_blip_with_hf_trainer import train_with_hf_trainer
from train_clip_with_hf_trainer import train_with_hf_trainer
data_args = kwargs.pop('data_args', None)
training_args = kwargs.pop('training_args', None)
model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
model_finetune = self._model.backbone
model_finetune.forward = MethodType(_forward, model_finetune)
model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)

Loading…
Cancel
Save