@@ -16,6 +16,7 @@ import sys
 import os
 from pathlib import Path
 from typing import Optional, Union, Tuple
+from types import MethodType
 
 import torch
 import logging
@@ -36,32 +37,7 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
 
-class BlipForImageTextRetrievalFinetune():
-    def __init__(self, blip_itm_model):
-        self.proxied = blip_itm_model
-        self.proxied.logit_scale = torch.nn.Parameter(torch.ones([]) * blip_itm_model.config.logit_scale_init_value)
-
-    def __getattr__(self, attr):
-        if attr == 'proxied':
-            return self.proxied
-        elif attr != 'forward':
-            if getattr(self.proxied, attr) is True:
-                def wrapped_method(*args, **kwargs):
-                    result = getattr(self.proxied, attr)(*args, **kwargs)
-                    return result
-                return wrapped_method
-            else:
-                return getattr(self.proxied, attr)
-        else:
-            return self._forward
-
-    def forward(self, *args, **kwargs):
-        raise NotImplemented
-
-    def __call__(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
-
-    def _forward(
+def _forward(
         self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
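Context for the hunk above: the removed `BlipForImageTextRetrievalFinetune` wrapper delegated attribute access to the wrapped model via `__getattr__` and intercepted only `forward`; promoting `_forward` to a module-level function lets it be bound directly onto the backbone instead (see the next hunk). A minimal sketch of that binding pattern, assuming a generic `torch.nn.Module` with hypothetical names (`Backbone`, `inner` are illustrative, not from the patch):

```python
from types import MethodType

import torch


def _forward(self, x):
    # Standalone function: `self` is supplied when the function is bound
    # below, so it can be attached to an already-constructed module.
    return self.inner(x) * self.logit_scale.exp()


class Backbone(torch.nn.Module):  # hypothetical stand-in for the BLIP backbone
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)


model = Backbone()
# Bind the function as an instance method; the instance attribute shadows
# the class-level Module.forward, so model(x) now routes through _forward.
model.forward = MethodType(_forward, model)
model.logit_scale = torch.nn.Parameter(torch.ones([]))  # registered by Module.__setattr__
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```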
@@ -258,10 +234,13 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_blip_with_hf_trainer import train_with_hf_trainer
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        model_finetune = BlipForImageTextRetrievalFinetune(self._model.backbone)
+        model_finetune = self._model.backbone
+        model_finetune.forward = MethodType(_forward, model_finetune)
+        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
 
         train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
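The `logit_scale` attached here is what a CLIP-style trainer expects on the model: the learnable temperature that scales the image/text similarity matrix before the contrastive loss. A rough sketch of that computation with made-up shapes (the actual loss lives in the imported trainer script, which is not shown in this diff):

```python
import torch
import torch.nn.functional as F

batch, dim = 8, 256
# Stand-ins for the pooled, L2-normalized image and text embeddings.
image_embeds = F.normalize(torch.randn(batch, dim), dim=-1)
text_embeds = F.normalize(torch.randn(batch, dim), dim=-1)

# Learnable temperature, initialized from config.logit_scale_init_value
# in the patch above (2.6592 is the usual CLIP-style default).
logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6592)

# Scaled similarity matrix; matched pairs sit on the diagonal.
logits_per_text = logit_scale.exp() * text_embeds @ image_embeds.t()
labels = torch.arange(batch)
loss = (F.cross_entropy(logits_per_text, labels)
        + F.cross_entropy(logits_per_text.t(), labels)) / 2
```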