diff --git a/blip.py b/blip.py
index 1c04ca6..83e3ef9 100644
--- a/blip.py
+++ b/blip.py
@@ -32,29 +32,63 @@ warnings.filterwarnings('ignore')
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 t_logging.set_verbosity_error()
 
+def create_model(cfg, modality, checkpoint_path, device):
+    hf_blip_model = BlipForImageTextRetrieval.from_pretrained(cfg)
+    if checkpoint_path:
+        try:
+            state_dict = torch.load(checkpoint_path, map_location=device)
+            hf_blip_model.load_state_dict(state_dict)
+        except Exception as e:
+            log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
+    hf_blip_model.to(device)
+    hf_blip_model.eval()
+
+    if modality == 'image':
+        blip = BLIPModelVision(hf_blip_model)
+    elif modality == 'text':
+        blip = BLIPModelText(hf_blip_model)
+    else:
+        raise ValueError("modality[{}] not implemented.".format(modality))
+    return blip
+
 #@accelerate
 class BLIPModelVision(nn.Module):
     def __init__(self, model):
         super().__init__()
-        self.model = model
+        self.backbone = model
 
     def forward(self, image):
-        image_embeds = self.model.visual_encoder(image)
-        image_embeds = self.model.vision_proj(image_embeds[:,0,:])
+        image_embeds = self.backbone.vision_model(image)[0]
+        image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
         return image_embeds
 
 #@accelerate
 class BLIPModelText(nn.Module):
     def __init__(self, model):
         super().__init__()
-        self.model = model
+        self.backbone = model
 
     def forward(self, input_ids, attention_mask):
-        text_features = self.model.text_encoder(input_ids, attention_mask = attention_mask,
+        text_features = self.backbone.text_encoder(input_ids, attention_mask = attention_mask,
                                                 return_dict = False)[0]
-        text_features = self.model.text_proj(text_features[:,0,:])
+        text_features = self.backbone.text_proj(text_features[:,0,:])
         return text_features
 
+class Model:
+    def __init__(self, model_name, modality, checkpoint_path, device):
+        self.model = create_model(model_name, modality, checkpoint_path, device)
+        self.device = device
+
+    def __call__(self, *args, **kwargs):
+        new_args = []
+        for item in args:
+            new_args.append(item.to(self.device))
+        new_kwargs = {}
+        for k, value in kwargs.items():
+            new_kwargs[k] = value.to(self.device)
+        outs = self.model(*new_args, **new_kwargs)
+        return outs
+
 @register(output_schema=['vec'])
 class Blip(NNOperator):
     """
@@ -62,44 +96,26 @@ class Blip(NNOperator):
     """
    def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
+        real_name = self._configs()[model_name]['name']
+        self.model = Model(real_name, modality, checkpoint_path, device)
         self.modality = modality
-        self.model_name = model_name
         self.device = device
-        cfg = self._configs()[model_name]
-
-        try:
-            blip_model = BlipForImageTextRetrieval.from_pretrained(cfg)
-        except Exception as e:
-            log.error(f'Fail to load model by name: {self.model_name}')
-            raise e
-        if checkpoint_path:
-            try:
-                state_dict = torch.load(checkpoint_path, map_location=self.device)
-                self.model.load_state_dict(state_dict)
-            except Exception as e:
-                log.error(f'Fail to load state dict from {checkpoint_path}: {e}')
+        self.checkpoint_path = checkpoint_path
         self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
-        if self.modality == 'image':
-            self.model = BLIPModelVision(blip_model)
-        elif self.modality == 'text':
-            self.model = BLIPModelText(blip_model)
-        else:
-            raise ValueError('modality[{}] not implemented.'.format(self.modality))
-
-        self._modality = modality
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.model.to(self.device)
-        self.model.eval()
-
 
     def __call__(self, data):
-        if self._modality == 'image':
-            vec = self._inference_from_image(data)
-        elif self._modality == 'text':
-            vec = self._inference_from_text(data)
+        if not isinstance(data, list):
+            data = [data]
         else:
-            raise ValueError('modality[{}] not implemented.'.format(self._modality))
-        return vec.detach().cpu().numpy().flatten()
+            data = data
+        results = []
+        for single_data in data:
+            result = self.inference_single_data(single_data)
+            results.append(result)
+        if len(data) == 1:
+            return results[0]
+        else:
+            return results
 
     def _inference_from_text(self, text):
         inputs = self.processor(text=text, padding=True, return_tensors='pt')
@@ -109,15 +125,24 @@ class Blip(NNOperator):
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        inputs = self.processor(images=img, return_tensors='pt')
+        inputs = self.processor(images=img, return_tensors='pt')['pixel_values']
         inputs = inputs.to(self.device)
         image_feature = self.model(inputs)
         return image_feature
 
+    def inference_single_data(self, data):
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        return vec.detach().cpu().numpy().flatten()
+
     def _configs(self):
         config = {}
         config['blip_itm_base'] = {}
-        config['blip_itm_base']['weights'] = 'Salesforce/blip-itm-base-coco'
+        config['blip_itm_base']['name'] = 'Salesforce/blip-itm-base-coco'
         config['blip_itm_base']['image_size'] = 224
         return config
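
A minimal usage sketch of the refactored operator follows (illustrative only, not part of the diff). It assumes this file is importable as `blip` and that the towhee and transformers dependencies are installed; the example prompts are made up.

# Sketch: exercising the refactored Blip operator directly.
from blip import Blip

# Text embedding: the config key 'blip_itm_base' now resolves the Hugging Face
# checkpoint name via _configs()[...]['name'], and the Model wrapper moves
# inputs to the configured device before the forward pass.
text_op = Blip(model_name='blip_itm_base', modality='text', device='cpu')
vec = text_op('a photo of a dog riding a skateboard')
print(vec.shape)  # a single input returns one flattened numpy vector

# __call__ now also accepts a list and returns a list of vectors
# (the batch path added in this change).
vecs = text_op(['a red car', 'two cats on a sofa'])
print(len(vecs), vecs[0].shape)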