
Update the BLIP operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
v2
wxywb 2 years ago
commit 08a4d10c78
1 changed file with 41 changes: blip.py

@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, BlipForImageTextRetrieval
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -32,6 +33,7 @@ class BLIPModelVision(nn.Module):
     def forward(self, image):
         image_embeds = self.model.visual_encoder(image)
+        image_embeds = self.model.vision_proj(image_embeds[:,0,:])
         return image_embeds

 #@accelerate
@ -41,10 +43,10 @@ class BLIPModelText(nn.Module):
self.model = model self.model = model
def forward(self, input_ids, attention_mask): def forward(self, input_ids, attention_mask):
text_output = self.text_encoder(input_ids, attention_mask = attention_mask,
return_dict = False, mode = 'text')
return text_output
text_features = self.model.text_encoder(input_ids, attention_mask = attention_mask,
return_dict = False, mode = 'text')[0]
text_features = self.model.text_proj(text_features[:,0,:])j
return text_features
 @register(output_schema=['vec'])
 class Blip(NNOperator):
@@ -53,8 +55,6 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
-        sys.path.append(str(Path(__file__).parent))
-        from models.blip import blip_feature_extractor
         self.model_name = model_name
         self.device = device
         cfg = self._configs()[model_name]
@ -62,8 +62,8 @@ class Blip(NNOperator):
model_url = cfg['weights'] model_url = cfg['weights']
image_size = cfg['image_size'] image_size = cfg['image_size']
model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
self.tokenizer = model.tokenizer
model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
self.processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
if self.modality == 'image': if self.modality == 'image':
self.model = BLIPModelVision(model) self.model = BLIPModelVision(model)
@@ -77,11 +77,11 @@ class Blip(NNOperator):
         self.model.to(self.device)
         self.model.eval()
-        self.tfms = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
+        #self.tfms = transforms.Compose([
+        #    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+        #    transforms.ToTensor(),
+        #    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        #    ])

     def __call__(self, data):
         if self._modality == 'image':
@@ -93,21 +93,18 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()

     def _inference_from_text(self, text):
-        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
-        text_feature = self.model(tokens.input_ids, tokens.attention_mask)
+        inputs = self.processor(text=text, padding=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        text_feature = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
         return text_feature

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        img = self._preprocess(img)
-        image_feature = self.model(img)
+        inputs = self.processor(images=img, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        image_feature = self.model(inputs.pixel_values)
         return image_feature

-    def _preprocess(self, img):
-        img = to_pil(img)
-        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
-        return processed_img

     def _configs(self):
         config = {}
         config['blip_base'] = {}
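
For reference, here is a minimal standalone sketch of the transformers API this commit migrates to. It is not part of the diff: the checkpoint name matches the one the diff pins, but the sample text and dummy image are illustrative, and it mirrors (rather than reuses) the BLIPModelText/BLIPModelVision wrappers above, projecting each encoder's [CLS] hidden state into the shared retrieval space.

# Minimal sketch (not part of this commit) of the transformers BLIP
# retrieval API the operator now wraps.
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForImageTextRetrieval

ckpt = "Salesforce/blip-itm-base-coco"
processor = AutoProcessor.from_pretrained(ckpt)
model = BlipForImageTextRetrieval.from_pretrained(ckpt).eval()

# Text path: encode, then project the [CLS] state, as BLIPModelText does.
inputs = processor(text="a photo of a cat", padding=True, return_tensors="pt")
with torch.no_grad():
    hidden = model.text_encoder(input_ids=inputs.input_ids,
                                attention_mask=inputs.attention_mask,
                                return_dict=False)[0]
    text_vec = model.text_proj(hidden[:, 0, :])

# Image path: note that transformers names the vision tower `vision_model`,
# unlike the `visual_encoder` attribute of the original BLIP repo's model.
img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))  # dummy image
pixels = processor(images=img, return_tensors="pt").pixel_values
with torch.no_grad():
    image_embeds = model.vision_model(pixel_values=pixels)[0]
    image_vec = model.vision_proj(image_embeds[:, 0, :])

print(text_vec.shape, image_vec.shape)  # e.g. (1, 256) each

Both projections land in the same joint embedding space, which is what lets the operator emit comparable vectors for either modality.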
