diff --git a/blip.py b/blip.py
index 29e6e6c..b7a98f9 100644
--- a/blip.py
+++ b/blip.py
@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, BlipForImageTextRetrieval
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -32,6 +33,7 @@ class BLIPModelVision(nn.Module):
 
     def forward(self, image):
-        image_embeds = self.model.visual_encoder(image)
+        image_embeds = self.model.vision_model(image)[0]
+        image_embeds = self.model.vision_proj(image_embeds[:,0,:])
         return image_embeds
 
 #@accelerate
@@ -41,10 +43,10 @@ class BLIPModelText(nn.Module):
         self.model = model
 
     def forward(self, input_ids, attention_mask):
-        text_output = self.text_encoder(input_ids, attention_mask = attention_mask,
-                                        return_dict = False, mode = 'text')
-        return text_output
-
+        text_features = self.model.text_encoder(input_ids, attention_mask = attention_mask,
+                                                return_dict = False)[0]
+        text_features = self.model.text_proj(text_features[:,0,:])
+        return text_features
 
 @register(output_schema=['vec'])
 class Blip(NNOperator):
@@ -53,8 +55,6 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
-        sys.path.append(str(Path(__file__).parent))
-        from models.blip import blip_feature_extractor
         self.model_name = model_name
         self.device = device
         cfg = self._configs()[model_name]
@@ -62,8 +62,8 @@ class Blip(NNOperator):
         model_url = cfg['weights']
         image_size = cfg['image_size']
 
-        model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
-        self.tokenizer = model.tokenizer
+        model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
+        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
 
         if self.modality == 'image':
             self.model = BLIPModelVision(model)
@@ -77,11 +77,11 @@ class Blip(NNOperator):
         self.model.to(self.device)
         self.model.eval()
 
-        self.tfms = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-            ])
+        #self.tfms = transforms.Compose([
+        #    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+        #    transforms.ToTensor(),
+        #    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        #    ])
 
     def __call__(self, data):
         if self._modality == 'image':
@@ -93,21 +93,18 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
-        text_feature = self.model(tokens.input_ids, tokens.attention_mask)
+        inputs = self.processor(text=text, padding=True, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        text_feature = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask)[0]
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        img = self._preprocess(img)
-        image_feature = self.model(img)
+        inputs = self.processor(images=img, return_tensors="pt")
+        inputs = inputs.to(self.device)
+        image_feature = self.model(inputs['pixel_values'])
         return image_feature
 
-    def _preprocess(self, img):
-        img = to_pil(img)
-        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
-        return processed_img
-
     def _configs(self):
         config = {}
         config['blip_base'] = {}
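
For reference, below is a minimal standalone sketch of the feature-extraction path the patched operator now follows, using only the transformers checkpoint named in the patch. The test image URL and caption are illustrative placeholders, not part of the patch; both branches take the encoder's [CLS] state and project it into the shared image-text matching space, mirroring BLIPModelVision and BLIPModelText after this change.

# Sanity-check sketch (not part of the patch): reproduce the operator's new
# feature-extraction path directly with HuggingFace transformers.
# The image URL and caption below are placeholders for illustration only.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForImageTextRetrieval

model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model.eval()

image = Image.open(requests.get(
    "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

with torch.no_grad():
    # Image branch: vision encoder -> [CLS] token -> vision_proj (shared ITM space).
    pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
    image_embeds = model.vision_model(pixel_values)[0]
    image_feat = model.vision_proj(image_embeds[:, 0, :])

    # Text branch: text encoder -> [CLS] token -> text_proj (same space).
    text_inputs = processor(text="two cats sleeping on a couch", return_tensors="pt")
    text_embeds = model.text_encoder(input_ids=text_inputs.input_ids,
                                     attention_mask=text_inputs.attention_mask,
                                     return_dict=False)[0]
    text_feat = model.text_proj(text_embeds[:, 0, :])

print(image_feat.shape, text_feat.shape)  # both land in the shared projection space (256-d for this checkpoint)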