
update the blip operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
v2
wxywb committed 2 years ago
commit 08a4d10c78
1 changed file: blip.py (41 lines changed)
@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, BlipForImageTextRetrieval
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -32,6 +33,7 @@ class BLIPModelVision(nn.Module):
     def forward(self, image):
         image_embeds = self.model.visual_encoder(image)
+        image_embeds = self.model.vision_proj(image_embeds[:,0,:])
         return image_embeds

 #@accelerate
@@ -41,10 +43,10 @@ class BLIPModelText(nn.Module):
         self.model = model

     def forward(self, input_ids, attention_mask):
-        text_output = self.text_encoder(input_ids, attention_mask = attention_mask,
-                                        return_dict = False, mode = 'text')
-        return text_output
+        text_features = self.model.text_encoder(input_ids, attention_mask = attention_mask,
+                                                return_dict = False, mode = 'text')[0]
+        text_features = self.model.text_proj(text_features[:,0,:])
+        return text_features

 @register(output_schema=['vec'])
 class Blip(NNOperator):
@@ -53,8 +55,6 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
-        sys.path.append(str(Path(__file__).parent))
-        from models.blip import blip_feature_extractor
         self.model_name = model_name
         self.device = device
         cfg = self._configs()[model_name]
@@ -62,8 +62,8 @@ class Blip(NNOperator):
         model_url = cfg['weights']
         image_size = cfg['image_size']
-        model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
-        self.tokenizer = model.tokenizer
+        model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
+        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
         if self.modality == 'image':
             self.model = BLIPModelVision(model)
@@ -77,11 +77,11 @@ class Blip(NNOperator):
         self.model.to(self.device)
         self.model.eval()

-        self.tfms = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
+        #self.tfms = transforms.Compose([
+        #    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+        #    transforms.ToTensor(),
+        #    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        #    ])

     def __call__(self, data):
         if self._modality == 'image':
@@ -93,21 +93,18 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()

     def _inference_from_text(self, text):
-        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
-        text_feature = self.model(tokens.input_ids, tokens.attention_mask)
+        inputs = self.processor(text=text, padding=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        text_feature = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)[0]
         return text_feature

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        img = self._preprocess(img)
-        image_feature = self.model(img)
+        inputs = self.processor(images=img, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        image_feature = self.model(inputs.pixel_values)
         return image_feature

-    def _preprocess(self, img):
-        img = to_pil(img)
-        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
-        return processed_img
-
     def _configs(self):
         config = {}
         config['blip_base'] = {}
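
For reference, below is a minimal standalone sketch of the inference path this commit switches to, written directly against the Hugging Face transformers API. It is an illustration under stated assumptions, not the operator itself: the checkpoint name is taken from the diff, the example text and image file are made up, and the vision_model call and the absence of a mode='text' argument follow the HF BLIP classes (the Salesforce codebase being replaced exposed visual_encoder and a mode flag, which the HF implementation does not).

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForImageTextRetrieval

model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model.eval()

# Text branch: tokenize, encode, take the [CLS] hidden state, project with text_proj.
inputs = processor(text="a photo of a cat", padding=True, return_tensors="pt")
with torch.no_grad():
    hidden = model.text_encoder(input_ids=inputs.input_ids,
                                attention_mask=inputs.attention_mask,
                                return_dict=False)[0]
    text_vec = model.text_proj(hidden[:, 0, :])

# Image branch: preprocess, encode, take the [CLS] hidden state, project with vision_proj.
img = Image.open("cat.jpg").convert("RGB")  # hypothetical input image
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
    embeds = model.vision_model(inputs.pixel_values)[0]
    image_vec = model.vision_proj(embeds[:, 0, :])

print(text_vec.shape, image_vec.shape)  # both land in the shared space (256-d for this checkpoint)

Both branches mirror BLIPModelText and BLIPModelVision in the diff: take the first token's hidden state and project it into the shared image-text space, so the two vectors are directly comparable with cosine similarity, and the operator can return either one flattened to numpy depending on its modality.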
