# Copyright 2021 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os from pathlib import Path import torch import logging import warnings from torch import nn from transformers import AutoProcessor, BlipForImageTextRetrieval from transformers import logging as t_logging from towhee import register from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color #from towhee.dc2 import accelerate log = logging.getLogger('run_op') warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' t_logging.set_verbosity_error() def create_model(cfg, modality, checkpoint_path, device): hf_blip_model = BlipForImageTextRetrieval.from_pretrained(cfg) if checkpoint_path: try: state_dict = torch.load(checkpoint_path, map_location=device) hf_blip_model.load_state_dict(state_dict) except Exception as e: log.error(f"Fail to load state dict from {checkpoint_path}: {e}") hf_blip_model.to(device) hf_blip_model.eval() if modality == 'image': blip = BLIPModelVision(hf_blip_model) elif modality == 'text': blip = BLIPModelText(hf_blip_model) else: raise ValueError("modality[{}] not implemented.".format(modality)) return blip #@accelerate class BLIPModelVision(nn.Module): def __init__(self, model): super().__init__() self.backbone = model def forward(self, pixel_values): image_embeds = self.backbone.vision_model(pixel_values)[0] image_embeds = self.backbone.vision_proj(image_embeds[:,0,:]) return image_embeds #@accelerate class BLIPModelText(nn.Module): def __init__(self, model): super().__init__() self.backbone = model def forward(self, input_ids, attention_mask): text_features = self.backbone.text_encoder(input_ids, attention_mask = attention_mask, return_dict = False)[0] text_features = self.backbone.text_proj(text_features[:,0,:]) return text_features class Model: def __init__(self, model_name, modality, checkpoint_path, device): self.model = create_model(model_name, modality, checkpoint_path, device) self.device = device def __call__(self, *args, **kwargs): new_args = [] for item in args: new_args.append(item.to(self.device)) new_kwargs = {} for k, value in kwargs.items(): new_kwargs[k] = value.to(self.device) outs = self.model(*new_args, **new_kwargs) return outs @register(output_schema=['vec']) class Blip(NNOperator): """ BLIP multi-modal embedding operator """ def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None): super().__init__() self.model_name = model_name real_name = self._configs()[model_name]['name'] self.model = Model(real_name, modality, checkpoint_path, device) self.modality = modality self.device = device self.checkpoint_path = checkpoint_path self.processor = AutoProcessor.from_pretrained(real_name) def __call__(self, data): if not isinstance(data, list): data = [data] else: data = data results = [] for single_data in data: result = self.inference_single_data(single_data) results.append(result) if len(data) == 1: return results[0] else: return results def _inference_from_text(self, text): inputs = self.processor(text=text, padding=True, return_tensors='pt') inputs = inputs.to(self.device) text_feature = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask)[0] return text_feature @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): inputs = self.processor(images=img, return_tensors='pt') inputs = inputs.to(self.device) image_feature = self.model(inputs['pixel_values']) return image_feature def inference_single_data(self, data): if self.modality == 'image': vec = self._inference_from_image(data) elif self.modality == 'text': vec = self._inference_from_text(data) else: raise ValueError("modality[{}] not implemented.".format(self.modality)) return vec.detach().cpu().numpy().flatten() def _configs(self): config = {} config['blip_itm_base_coco'] = {} config['blip_itm_base_coco']['name'] = 'Salesforce/blip-itm-base-coco' config['blip_itm_base_flickr'] = {} config['blip_itm_base_flickr']['name'] = 'Salesforce/blip-itm-base-flickr' config['blip_itm_large_coco'] = {} config['blip_itm_large_coco']['name'] = 'Salesforce/blip-itm-large-coco' config['blip_itm_large_flickr'] = {} config['blip_itm_large_flickr']['name'] = 'Salesforce/blip-itm-large-flickr' return config @property def _model(self): return self.model.model def train(self, **kwargs): raise NotImplementedError @property def supported_formats(self): onnxes = self.supported_model_names(format='onnx') if self.model_name in onnxes: return ['onnx'] else: return ['pytorch'] @staticmethod def supported_model_names(format: str = None): full_list = ['blip_itm_base'] if format == None: model_list = full_list elif format == 'pytorch' or format == 'torchscript' or format == 'onnx': model_list = full_list else: log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') return model_list def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): import os from PIL import Image from torch.onnx import export as onnx_export if output_file == 'default': output_file = str(Path(__file__).parent) output_file = os.path.join(output_file, 'saved', model_type) os.makedirs(output_file, exist_ok=True) name = self.model_name.replace('/', '-') output_file = os.path.join(output_file, name) if model_type in ['pytorch', 'torchscript']: output_file = output_file + '.pt' elif model_type == 'onnx': output_file = output_file + '.onnx' else: raise AttributeError('Unsupported model_type.') if self.modality == 'image': sz = self.processor.image_processor.size if isinstance(sz, int): h = sz w = sz elif isinstance(sz, dict): h = sz['height'] w = sz['width'] dummy_input = Image.new('RGB', (w, h), color = 'red') inputs = self.processor(images=dummy_input, return_tensors='pt') # a dictionary elif self.modality == 'text': dummy_input = 'dummy' inputs = self.processor(text=dummy_input, padding=True, return_tensors='pt') else: raise ValueError('modality[{}] not implemented.'.format(self.modality)) if model_type == 'pytorch': torch.save(self._model, output_file) elif model_type == 'torchscript': inputs = list(inputs.values()) try: try: jit_model = torch.jit.script(self._model) except Exception: jit_model = torch.jit.trace(self._model, inputs, strict=False) torch.jit.save(jit_model, output_file) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif model_type == 'onnx': if self.modality == 'image': input_names= ['pixel_values'] output_names=['image_embeds'] dynamic_axes={'pixel_values': {0: 'batch'}, 'image_embeds': {0: 'batch'}} elif self.modality == 'text': input_names= ['input_ids', 'attention_mask'] output_names=['text_embeds'] dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}, 'text_embeds': {0: 'batch'}} else: raise ValueError('modality[{}] not implemented.'.format(self.modality)) onnx_export(self._model, (dict(inputs),), f=Path(output_file), input_names= input_names, output_names=output_names, dynamic_axes=dynamic_axes, do_constant_folding=True, opset_version=14, ) else: pass raise NotImplementedError