@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

+from transformers import AutoProcessor, BlipForImageTextRetrieval
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -32,6 +33,7 @@ class BLIPModelVision(nn.Module):
     def forward(self, image):
         image_embeds = self.model.visual_encoder(image)
+        image_embeds = self.model.vision_proj(image_embeds[:,0,:])
         return image_embeds


 #@accelerate
@@ -41,10 +43,10 @@ class BLIPModelText(nn.Module):
         self.model = model

     def forward(self, input_ids, attention_mask):
-        text_output = self.text_encoder(input_ids, attention_mask = attention_mask,
-                                        return_dict = False, mode = 'text')
-        return text_output
-
+        text_features = self.model.text_encoder(input_ids, attention_mask = attention_mask,
+                                                return_dict = False)[0]
+        text_features = self.model.text_proj(text_features[:,0,:])
+        return text_features

 @register(output_schema=['vec'])
 class Blip(NNOperator):
@@ -53,8 +55,6 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         super().__init__()
-        sys.path.append(str(Path(__file__).parent))
-        from models.blip import blip_feature_extractor
         self.model_name = model_name
         self.device = device
         cfg = self._configs()[model_name]
@@ -62,8 +62,8 @@ class Blip(NNOperator):
         model_url = cfg['weights']
         image_size = cfg['image_size']

-        model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
-        self.tokenizer = model.tokenizer
+        model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
+        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

         if self.modality == 'image':
             self.model = BLIPModelVision(model)
@@ -77,11 +77,11 @@ class Blip(NNOperator):
         self.model.to(self.device)
         self.model.eval()

-        self.tfms = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
+        #self.tfms = transforms.Compose([
+        #    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+        #    transforms.ToTensor(),
+        #    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        #])

     def __call__(self, data):
         if self._modality == 'image':
@@ -93,21 +93,18 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()

     def _inference_from_text(self, text):
-        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
-        text_feature = self.model(tokens.input_ids, tokens.attention_mask)
+        inputs = self.processor(text=text, padding=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        text_feature = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
         return text_feature

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        img = self._preprocess(img)
-        image_feature = self.model(img)
+        inputs = self.processor(images=img, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        image_feature = self.model(inputs['pixel_values'])
         return image_feature

-    def _preprocess(self, img):
-        img = to_pil(img)
-        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
-        return processed_img
-
     def _configs(self):
         config = {}
         config['blip_base'] = {}
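
Below is a minimal standalone sketch of the feature-extraction flow the operator wraps after this change, using only the transformers API. The checkpoint name is the one hard-coded above; the sample text, placeholder image, and variable names are illustrative assumptions, not part of the operator.

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForImageTextRetrieval

model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model.eval()

# Text side: tokenize, run the text encoder, project the [CLS] hidden state,
# mirroring BLIPModelText.forward above.
inputs = processor(text="a photo of a cat", padding=True, return_tensors="pt")
with torch.no_grad():
    hidden = model.text_encoder(inputs.input_ids,
                                attention_mask=inputs.attention_mask,
                                return_dict=False)[0]
    text_vec = model.text_proj(hidden[:, 0, :])       # shape (1, 256)

# Image side: preprocess, run the vision encoder, project the [CLS] hidden
# state. The HF retrieval model exposes vision_model/vision_proj and
# text_encoder/text_proj as submodules.
image = Image.new('RGB', (384, 384))                  # placeholder image (assumption)
pixel_values = processor(images=image, return_tensors="pt")['pixel_values']
with torch.no_grad():
    embeds = model.vision_model(pixel_values)[0]
    image_vec = model.vision_proj(embeds[:, 0, :])    # shape (1, 256)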