
update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
Author: wxywb, 2 years ago
Commit: b0dd7f82cb

  1. README.md (43 lines changed)
  2. blip.py (192 lines changed)
  3. requirements.txt (2 lines changed)

README.md (43 lines changed)

@@ -21,38 +21,27 @@ Load an image from path './teddy.jpg' to generate an image embedding.
 Read the text 'A teddybear on a skateboard in Times Square.' to generate a text embedding.
 
-*Write the pipeline in simplified style*:
+*Write a pipeline with explicit inputs/outputs name specifications:*
 
 ```python
-import towhee
-
-towhee.glob('./teddy.jpg') \
-      .image_decode() \
-      .image_text_embedding.blip(model_name='blip_base', modality='image') \
-      .show()
-
-towhee.dc(["A teddybear on a skateboard in Times Square."]) \
-      .image_text_embedding.blip(model_name='blip_base', modality='text') \
-      .show()
-```
-<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
-<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/>
-
-*Write a same pipeline with explicit inputs/outputs name specifications:*
-
-```python
-import towhee
-
-towhee.glob['path']('./teddy.jpg') \
-      .image_decode['path', 'img']() \
-      .image_text_embedding.blip['img', 'vec'](model_name='blip_base', modality='image') \
-      .select['img', 'vec']() \
-      .show()
-
-towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
-      .image_text_embedding.blip['text','vec'](model_name='blip_base', modality='text') \
-      .select['text', 'vec']() \
-      .show()
+from towhee.dc2 import pipe, ops, DataCollection
+
+img_pipe = (
+    pipe.input('url')
+        .map('url', 'img', ops.image_decode.cv2_rgb())
+        .map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image'))
+        .output('img', 'vec')
+)
+
+text_pipe = (
+    pipe.input('text')
+        .map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='text'))
+        .output('text', 'vec')
+)
+
+DataCollection(img_pipe('./teddy.jpg')).show()
+DataCollection(text_pipe('A teddybear on a skateboard in Times Square.')).show()
 ```
 <img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/tabular1.png" alt="result1" style="height:60px;"/>
 <img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/tabular2.png" alt="result2" style="height:60px;"/>
@@ -73,7 +62,7 @@ Create the operator via the following factory method
 ***model_name:*** *str*
 
 	The model name of BLIP. Supported model names:
-- blip_base
+- blip_itm_base
 
 ***modality:*** *str*
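
For reference, a minimal usage sketch that is not part of this diff: it reuses the `img_pipe` and `text_pipe` defined in the new README and compares the two embeddings with cosine similarity. It assumes that calling a pipe returns a result queue whose `.get()` yields the declared outputs in order (`'img'`/`'text'` first, `'vec'` second).

```python
import numpy as np

# 'vec' is the second declared output of both pipelines above.
img_vec = img_pipe('./teddy.jpg').get()[1]
text_vec = text_pipe('A teddybear on a skateboard in Times Square.').get()[1]

score = np.dot(img_vec, text_vec) / (np.linalg.norm(img_vec) * np.linalg.norm(text_vec))
print(f'image-text cosine similarity: {score:.4f}')
```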

blip.py (192 lines changed)

@@ -16,66 +16,202 @@ import sys
 from pathlib import Path
 
 import torch
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+from torch import nn
+from transformers import AutoProcessor, BlipForImageTextRetrieval
+from transformers import logging as t_logging
 
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
-from towhee.types.image_utils import from_pil, to_pil
+
+log = logging.getLogger('run_op')
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
+
+#@accelerate
+class BLIPModelVision(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
+        # Use the HF vision tower and project the [CLS] token into the shared embedding space.
+        image_embeds = self.model.vision_model(pixel_values)[0]
+        image_embeds = self.model.vision_proj(image_embeds[:, 0, :])
+        return image_embeds
+
+#@accelerate
+class BLIPModelText(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask):
+        text_features = self.model.text_encoder(input_ids, attention_mask=attention_mask,
+                                                return_dict=False)[0]
+        text_features = self.model.text_proj(text_features[:, 0, :])
+        return text_features
+
 @register(output_schema=['vec'])
 class Blip(NNOperator):
     """
     BLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         super().__init__()
-        sys.path.append(str(Path(__file__).parent))
-        from models.blip import blip_feature_extractor
-        image_size = 224
-        model_url = self._configs()[model_name]['weights']
-        self.model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
+        self.modality = modality
+        self.model_name = model_name
+        self.device = device
+        cfg = self._configs()[model_name]
+        try:
+            blip_model = BlipForImageTextRetrieval.from_pretrained(cfg['weights'])
+        except Exception as e:
+            log.error(f'Fail to load model by name: {self.model_name}')
+            raise e
+        if checkpoint_path:
+            try:
+                state_dict = torch.load(checkpoint_path, map_location=self.device)
+                blip_model.load_state_dict(state_dict)
+            except Exception as e:
+                log.error(f'Fail to load state dict from {checkpoint_path}: {e}')
+        self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
+        if self.modality == 'image':
+            self.model = BLIPModelVision(blip_model)
+        elif self.modality == 'text':
+            self.model = BLIPModelText(blip_model)
+        else:
+            raise ValueError('modality[{}] not implemented.'.format(self.modality))
         self._modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.model.to(self.device)
         self.model.eval()
-        self.tfms = transforms.Compose([
-            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
 
     def __call__(self, data):
         if self._modality == 'image':
             vec = self._inference_from_image(data)
         elif self._modality == 'text':
             vec = self._inference_from_text(data)
         else:
-            raise ValueError("modality[{}] not implemented.".format(self._modality))
+            raise ValueError('modality[{}] not implemented.'.format(self._modality))
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        text_feature = self.model(None, text, mode='text', device=self.device)[0,0]
+        inputs = self.processor(text=text, padding=True, return_tensors='pt')
+        inputs = inputs.to(self.device)
+        text_feature = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)[0]
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        img = self._preprocess(img)
-        caption = ''
-        image_feature = self.model(img, caption, mode='image', device=self.device)[0,0]
+        inputs = self.processor(images=img, return_tensors='pt')
+        inputs = inputs.to(self.device)
+        image_feature = self.model(inputs['pixel_values'])
         return image_feature
 
-    def _preprocess(self, img):
-        img = to_pil(img)
-        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
-        return processed_img
-
     def _configs(self):
         config = {}
-        config['blip_base'] = {}
-        config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
+        config['blip_itm_base'] = {}
+        config['blip_itm_base']['weights'] = 'Salesforce/blip-itm-base-coco'
+        config['blip_itm_base']['image_size'] = 224
         return config
+
+    @property
+    def _model(self):
+        return self.model
+
+    def train(self, **kwargs):
+        raise NotImplementedError
+
+    @property
+    def supported_formats(self):
+        onnxes = self.supported_model_names(format='onnx')
+        if self.model_name in onnxes:
+            return ['onnx']
+        else:
+            return ['pytorch']
+
+    @staticmethod
+    def supported_model_names(format: str = None):
+        if format is None or format in ('pytorch', 'torchscript', 'onnx'):
+            model_list = [
+                'blip_itm_base',
+            ]
+        else:
+            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".')
+            model_list = []
+        return model_list
+
+    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
+        import os
+        from PIL import Image
+        from torch.onnx import export as onnx_export
+
+        if output_file == 'default':
+            output_file = str(Path(__file__).parent)
+            output_file = os.path.join(output_file, 'saved', model_type)
+            os.makedirs(output_file, exist_ok=True)
+            name = self.model_name.replace('/', '-')
+            output_file = os.path.join(output_file, name)
+            if model_type in ['pytorch', 'torchscript']:
+                output_file = output_file + '.pt'
+            elif model_type == 'onnx':
+                output_file = output_file + '.onnx'
+            else:
+                raise AttributeError('Unsupported model_type.')
+
+        if self.modality == 'image':
+            sz = self.processor.feature_extractor.crop_size
+            if isinstance(sz, int):
+                h = sz
+                w = sz
+            elif isinstance(sz, dict):
+                h = sz['height']
+                w = sz['width']
+            dummy_input = Image.new('RGB', (w, h), color='red')
+            inputs = self.processor(images=dummy_input, return_tensors='pt')  # a dictionary
+        elif self.modality == 'text':
+            dummy_input = 'dummy'
+            inputs = self.processor(text=dummy_input, padding=True, return_tensors='pt')
+        else:
+            raise ValueError('modality[{}] not implemented.'.format(self.modality))
+
+        if model_type == 'pytorch':
+            torch.save(self._model, output_file)
+        elif model_type == 'torchscript':
+            inputs = list(inputs.values())
+            try:
+                try:
+                    jit_model = torch.jit.script(self._model)
+                except Exception:
+                    jit_model = torch.jit.trace(self._model, inputs, strict=False)
+                torch.jit.save(jit_model, output_file)
+            except Exception as e:
+                log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.')
+        elif model_type == 'onnx':
+            if self.modality == 'image':
+                input_names = ['pixel_values']
+                output_names = ['image_embeds']
+                dynamic_axes = {'pixel_values': {0: 'batch'}, 'image_embeds': {0: 'batch'}}
+            elif self.modality == 'text':
+                input_names = ['input_ids', 'attention_mask']
+                output_names = ['text_embeds']
+                dynamic_axes = {'input_ids': {0: 'batch', 1: 'sequence'},
+                                'attention_mask': {0: 'batch', 1: 'sequence'},
+                                'text_embeds': {0: 'batch'}}
+            else:
+                raise ValueError('modality[{}] not implemented.'.format(self.modality))
+            onnx_export(self.model,
+                        (dict(inputs),),
+                        f=Path(output_file),
+                        input_names=input_names,
+                        output_names=output_names,
+                        dynamic_axes=dynamic_axes,
+                        do_constant_folding=True,
+                        opset_version=14,
+                        )
+        else:
+            raise NotImplementedError
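
For reference, a minimal sketch of driving the new `save_model()` method, not part of this diff. It assumes the `Blip` class above can be instantiated directly once the operator file is importable; with the defaults the ONNX file is written to `saved/onnx/blip_itm_base.onnx` next to `blip.py`.

```python
# Hypothetical direct use of the operator class defined above.
op = Blip(model_name='blip_itm_base', modality='text')
op.save_model(model_type='onnx')  # exports the text branch with inputs (input_ids, attention_mask)
```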

requirements.txt (2 lines changed)

@@ -3,4 +3,4 @@ torchvision>=0.10.0
 Pillow
 towhee
 timm
-transformers>=4.15.0
+transformers>=4.26.0
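
For reference, a minimal sketch of sanity-checking the exported text model with onnxruntime, which is an assumed extra dependency and not part of this diff or of requirements.txt. The input and output names match the ones passed to `onnx_export` in `save_model()` above.

```python
import onnxruntime as ort
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
inputs = processor(text='A teddybear on a skateboard in Times Square.', return_tensors='np')

sess = ort.InferenceSession('saved/onnx/blip_itm_base.onnx')
(text_embeds,) = sess.run(['text_embeds'], {
    'input_ids': inputs['input_ids'],
    'attention_mask': inputs['attention_mask'],
})
print(text_embeds.shape)  # expected (1, 256) for the base ITM projection head
```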
