diff --git a/blip.py b/blip.py
index ea6e062..29e6e6c 100644
--- a/blip.py
+++ b/blip.py
@@ -24,18 +24,54 @@
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee.types.image_utils import from_pil, to_pil
+from torch import nn
 
+#@accelerate
+class BLIPModelVision(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, image):
+        image_embeds = self.model.visual_encoder(image)
+        return image_embeds
+
+#@accelerate
+class BLIPModelText(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask):
+        text_output = self.model.text_encoder(input_ids, attention_mask=attention_mask,
+                                              return_dict=False, mode='text')
+        return text_output[0]  # last hidden state
+
+
 @register(output_schema=['vec'])
 class Blip(NNOperator):
     """
     BLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         super().__init__()
         sys.path.append(str(Path(__file__).parent))
         from models.blip import blip_feature_extractor
-        image_size = 224
-        model_url = self._configs()[model_name]['weights']
-        self.model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
+        self.model_name = model_name
+        self.device = device
+        cfg = self._configs()[model_name]
+
+        model_url = cfg['weights']
+        image_size = cfg['image_size']
+
+        model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
+        self.tokenizer = model.tokenizer
+
+        if modality == 'image':
+            self.model = BLIPModelVision(model)
+        elif modality == 'text':
+            self.model = BLIPModelText(model)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(modality))
         self._modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -58,14 +94,14 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        text_feature = self.model(None, text, mode='text', device=self.device)[0,0]
+        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
+        text_feature = self.model(tokens.input_ids, tokens.attention_mask)[0, 0]  # [CLS] embedding
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        caption = ''
-        image_feature = self.model(img, caption, mode='image', device=self.device)[0,0]
+        image_feature = self.model(img)[0, 0]  # [CLS] embedding
         return image_feature
 
     def _preprocess(self, img):
@@ -77,5 +113,6 @@ class Blip(NNOperator):
         config = {}
         config['blip_base'] = {}
         config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
+        config['blip_base']['image_size'] = 224
         return config
 
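
For reference, a minimal smoke-test sketch of how the patched operator might be exercised directly, outside a towhee pipeline. It assumes Blip.__call__ accepts the raw input and dispatches to _inference_from_text / _inference_from_image based on the modality chosen at construction (which is how the surrounding class reads); the sample sentence is a placeholder, and the weights are fetched from the URL in _configs on first use.

# Sketch only, not part of the diff above.
# Assumption: Blip.__call__ takes the raw input and routes it by modality.
from blip import Blip

# Text branch: the tokenizer attached in __init__ encodes the sentence,
# BLIPModelText returns the hidden states, and [0, 0] picks the [CLS] vector.
text_op = Blip(model_name='blip_base', modality='text', device='cpu')
vec = text_op('a dog chasing a ball on the beach')  # placeholder sentence
print(vec.shape)  # flattened 1-D embedding, e.g. (768,) for the base model

The image branch is used the same way with modality='image', but it expects a decoded RGB frame as produced earlier in a towhee pipeline rather than a file path.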