
update blip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
jinlingxu06 committed 2 years ago · branch main · commit 8fa9b81ca0
1 changed file with 33 changed lines

blip.py
@@ -16,6 +16,7 @@ import sys
 import os
 from pathlib import Path
 from typing import Optional, Union, Tuple
+from types import MethodType

 import torch
 import logging
@@ -36,31 +37,6 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()

-class BlipForImageTextRetrievalFinetune():
-    def __init__(self, blip_itm_model):
-        self.proxied = blip_itm_model
-        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
-
-    def __getattr__(self, attr):
-        if attr == 'proxied':
-            return self.proxied
-        elif attr != 'forward':
-            if getattr(self.proxied, attr) is True:
-                def wrapped_method(*args, **kwargs):
-                    result = getattr(self.proxied, attr)(*args, **kwargs)
-                    return result
-                return wrapped_method
-            else:
-                return getattr(self.proxied, attr)
-        else:
-            return self._forward
-
-    def forward(self, *args, **kwargs):
-        raise NotImplemented
-
-    def __call__(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
-
 def _forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
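Review note: the deleted proxy class was effectively dead weight. Its `getattr(self.proxied, attr) is True` test compares the attribute's value against the literal `True`, so the method-wrapping branch could only fire for an attribute whose value is the boolean `True`; a bound method never matches, and the wrapper added nothing even when it did. (`raise NotImplemented` should also have been `raise NotImplementedError`.) A quick stand-alone illustration of the `is True` pitfall, using a hypothetical stub class:

# Stand-in object to show why `getattr(obj, attr) is True` is not a
# callable check: it only matches attributes whose value is the bool True.
class Stub:
    flag = True

    def encode(self):
        return 'enc'


s = Stub()
print(getattr(s, 'encode') is True)    # False: a bound method, not True
print(getattr(s, 'flag') is True)      # True: literal boolean attribute
print(callable(getattr(s, 'encode')))  # True: what the check likely intended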
@@ -258,10 +234,13 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_blip_with_hf_trainer import train_with_hf_trainer
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
+        model_finetune = self._model.backbone
+        model_finetune.forward = MethodType(_forward, model_finetune)
+        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
         train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
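The replacement pattern is plain instance-level monkey-patching: bind the module-level `_forward` to the backbone with `types.MethodType` so it becomes a bound method of that one instance, and attach a learnable `logit_scale` directly on the `nn.Module` (assigning an `nn.Parameter` as an attribute registers it with the module, so the trainer optimizes it). A minimal, self-contained sketch of the same pattern; `ToyBackbone` and its `_forward` are illustrative stand-ins, not this repo's model (2.6592 ≈ ln(1/0.07), the usual CLIP-style `logit_scale_init_value`):

from types import MethodType

import torch


class ToyBackbone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        raise NotImplementedError('patched out before fine-tuning')


def _forward(self, x):
    # `self` is the patched instance, so submodules and the injected
    # logit_scale parameter are both reachable here.
    return self.proj(x) * self.logit_scale.exp()


model = ToyBackbone()
# Bind the free function as a method of this one instance; other
# ToyBackbone instances keep their original forward.
model.forward = MethodType(_forward, model)
# Attribute assignment of an nn.Parameter registers it on the module,
# so it appears in model.parameters() and is updated during training.
model.logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6592)

out = model(torch.randn(2, 4))  # dispatches to the patched _forward
print(out.shape)                # torch.Size([2, 4])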

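For reference, a hypothetical caller-side invocation of the patched `train()` path, written in the style of the towhee fine-tuning examples; the model name and the `data_args`/`training_args` fields are assumptions for illustration, not values taken from this commit:

import towhee

# Construct the operator and fetch the underlying NNOperator instance
# (towhee's usual fine-tuning entry point); names below are assumed.
blip_op = towhee.ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image').get_op()

blip_op.train(
    data_args={'dataset_name': 'ydshieh/coco_dataset_script'},  # assumed dataset
    training_args={'output_dir': './blip_finetuned', 'num_train_epochs': 3},
)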