
update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
wxywb committed 2 years ago
Commit: b0dd7f82cb
  1. README.md (43 changed lines)
  2. blip.py (192 changed lines)
  3. requirements.txt (2 changed lines)

README.md (43 changed lines)

@@ -21,38 +21,27 @@ Load an image from path './teddy.jpg' to generate an image embedding.
Read the text 'A teddybear on a skateboard in Times Square.' to generate a text embedding.
*Write the pipeline in simplified style*:
*Write a pipeline with explicit input/output name specifications:*
```python
import towhee
from towhee.dc2 import pipe, ops, DataCollection
towhee.glob('./teddy.jpg') \
      .image_decode() \
      .image_text_embedding.blip(model_name='blip_base', modality='image') \
      .show()
img_pipe = (
    pipe.input('url')
    .map('url', 'img', ops.image_decode.cv2_rgb())
    .map('img', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='image'))
    .output('img', 'vec')
)
towhee.dc(["A teddybear on a skateboard in Times Square."]) \
.image_text_embedding.blip(model_name='blip_base', modality='text') \
.show()
```
<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/>
text_pipe = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_itm_base', modality='text'))
    .output('text', 'vec')
)
*Write the same pipeline with explicit input/output name specifications:*
DataCollection(img_pipe('./teddy.jpg')).show()
DataCollection(text_pipe('A teddybear on a skateboard in Times Square.')).show()
```python
import towhee
towhee.glob['path']('./teddy.jpg') \
      .image_decode['path', 'img']() \
      .image_text_embedding.blip['img', 'vec'](model_name='blip_base', modality='image') \
      .select['img', 'vec']() \
      .show()

towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
      .image_text_embedding.blip['text', 'vec'](model_name='blip_base', modality='text') \
      .select['text', 'vec']() \
      .show()
```
<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/tabular1.png" alt="result1" style="height:60px;"/>
<img src="https://towhee.io/image-text-embedding/blip/raw/branch/main/tabular2.png" alt="result2" style="height:60px;"/>
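(Not part of the commit.) With the new `pipe`-style API above, the embedding can also be read back as a plain array instead of being rendered with `DataCollection(...).show()`. This is only a sketch: it assumes the `img_pipe` definition from the updated README and towhee's `DataQueue.get()` accessor, and the dimension shown in the comment is illustrative.

```python
# Sketch (not from the commit): read the raw embedding out of the pipeline
# result instead of rendering it as a table with DataCollection.show().
res = img_pipe('./teddy.jpg')      # img_pipe as defined in the updated README
img, vec = res.get()               # outputs follow the .output('img', 'vec') order
print(vec.shape)                   # e.g. (256,) for the blip_itm_base projection head
```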
@@ -73,7 +62,7 @@ Create the operator via the following factory method
***model_name:*** *str*
​ The model name of BLIP. Supported model names:
- blip_base
- blip_itm_base
***modality:*** *str*
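(Not part of the commit.) The updated constructor in blip.py below also adds `device` and `checkpoint_path` parameters that this README hunk does not mention. Since towhee forwards factory keyword arguments to the operator's `__init__`, using them would look roughly like the sketch below; the checkpoint path is a hypothetical local file.

```python
from towhee.dc2 import pipe, ops

# Sketch (hypothetical values): the factory forwards the new keyword
# arguments added in this commit to Blip.__init__.
img_pipe = (
    pipe.input('url')
    .map('url', 'img', ops.image_decode.cv2_rgb())
    .map('img', 'vec', ops.image_text_embedding.blip(
        model_name='blip_itm_base',
        modality='image',
        device='cpu',                               # new; defaults to 'cpu'
        checkpoint_path='./blip_itm_finetuned.pt',  # new; hypothetical fine-tuned weights
    ))
    .output('vec')
)
```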

blip.py (192 changed lines)

@@ -16,66 +16,202 @@ import sys
from pathlib import Path
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torch import nn
from transformers import AutoProcessor, BlipForImageTextRetrieval
from transformers import logging as t_logging
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import from_pil, to_pil
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
#@accelerate
class BLIPModelVision(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        image_embeds = self.model.visual_encoder(image)
        image_embeds = self.model.vision_proj(image_embeds[:, 0, :])
        return image_embeds

#@accelerate
class BLIPModelText(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        text_features = self.model.text_encoder(input_ids, attention_mask=attention_mask,
                                                return_dict=False)[0]
        text_features = self.model.text_proj(text_features[:, 0, :])
        return text_features
@register(output_schema=['vec'])
class Blip(NNOperator):
    """
    BLIP multi-modal embedding operator
    """
    def __init__(self, model_name: str, modality: str):
    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
        super().__init__()
        sys.path.append(str(Path(__file__).parent))
        from models.blip import blip_feature_extractor
        image_size = 224
        model_url = self._configs()[model_name]['weights']
        self.model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
        self.modality = modality
        self.model_name = model_name
        self.device = device
        cfg = self._configs()[model_name]
        try:
            blip_model = BlipForImageTextRetrieval.from_pretrained(cfg)
        except Exception as e:
            log.error(f'Fail to load model by name: {self.model_name}')
            raise e
        if checkpoint_path:
            try:
                state_dict = torch.load(checkpoint_path, map_location=self.device)
                self.model.load_state_dict(state_dict)
            except Exception as e:
                log.error(f'Fail to load state dict from {checkpoint_path}: {e}')
        self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
        if self.modality == 'image':
            self.model = BLIPModelVision(blip_model)
        elif self.modality == 'text':
            self.model = BLIPModelText(blip_model)
        else:
            raise ValueError('modality[{}] not implemented.'.format(self.modality))
        self._modality = modality
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()
        self.tfms = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
    def __call__(self, data):
        if self._modality == 'image':
            vec = self._inference_from_image(data)
        elif self._modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self._modality))
            raise ValueError('modality[{}] not implemented.'.format(self._modality))
        return vec.detach().cpu().numpy().flatten()

    def _inference_from_text(self, text):
        text_feature = self.model(None, text, mode='text', device=self.device)[0, 0]
        inputs = self.processor(text=text, padding=True, return_tensors='pt')
        inputs = inputs.to(self.device)
        text_feature = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)[0]
        return text_feature

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        img = self._preprocess(img)
        caption = ''
        image_feature = self.model(img, caption, mode='image', device=self.device)[0, 0]
        inputs = self.processor(images=img, return_tensors='pt')
        inputs = inputs.to(self.device)
        image_feature = self.model(inputs)
        return image_feature

    def _preprocess(self, img):
        img = to_pil(img)
        processed_img = self.tfms(img).unsqueeze(0).to(self.device)
        return processed_img
    def _configs(self):
        config = {}
        config['blip_base'] = {}
        config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
        config['blip_itm_base'] = {}
        config['blip_itm_base']['weights'] = 'Salesforce/blip-itm-base-coco'
        config['blip_itm_base']['image_size'] = 224
        return config

    @property
    def _model(self):
        return self.model

    def train(self, **kwargs):
        raise NotImplementedError

    @property
    def supported_formats(self):
        onnxes = self.supported_model_names(format='onnx')
        if self.model_name in onnxes:
            return ['onnx']
        else:
            return ['pytorch']

    @staticmethod
    def supported_model_names(format: str = None):
        if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
            model_list = [
                'blip_itm_base',
            ]
        else:
            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
        return model_list
    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        import os
        from PIL import Image
        from torch.onnx import export as onnx_export
        if output_file == 'default':
            output_file = str(Path(__file__).parent)
            output_file = os.path.join(output_file, 'saved', model_type)
            os.makedirs(output_file, exist_ok=True)
            name = self.model_name.replace('/', '-')
            output_file = os.path.join(output_file, name)
        if model_type in ['pytorch', 'torchscript']:
            output_file = output_file + '.pt'
        elif model_type == 'onnx':
            output_file = output_file + '.onnx'
        else:
            raise AttributeError('Unsupported model_type.')
        if self.modality == 'image':
            sz = self.processor.feature_extractor.crop_size
            if isinstance(sz, int):
                h = sz
                w = sz
            elif isinstance(sz, dict):
                h = sz['height']
                w = sz['width']
            dummy_input = Image.new('RGB', (w, h), color='red')
            inputs = self.processor(images=dummy_input, return_tensors='pt')  # a dictionary
        elif self.modality == 'text':
            dummy_input = 'dummy'
            inputs = self.processor(text=dummy_input, padding=True, return_tensors='pt')
        else:
            raise ValueError('modality[{}] not implemented.'.format(self.modality))
        if model_type == 'pytorch':
            torch.save(self._model, output_file)
        elif model_type == 'torchscript':
            inputs = list(inputs.values())
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    jit_model = torch.jit.trace(self._model, inputs, strict=False)
                torch.jit.save(jit_model, output_file)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.')
        elif model_type == 'onnx':
            if self.modality == 'image':
                input_names = ['pixel_values']
                output_names = ['image_embeds']
                dynamic_axes = {'pixel_values': {0: 'batch'}, 'image_embeds': {0: 'batch'}}
            elif self.modality == 'text':
                input_names = ['input_ids', 'attention_mask']
                output_names = ['text_embeds']
                dynamic_axes = {'input_ids': {0: 'batch', 1: 'sequence'},
                                'attention_mask': {0: 'batch', 1: 'sequence'},
                                'text_embeds': {0: 'batch'}}
            else:
                raise ValueError('modality[{}] not implemented.'.format(self.modality))
            onnx_export(self.model,
                        (dict(inputs),),
                        f=Path(output_file),
                        input_names=input_names,
                        output_names=output_names,
                        dynamic_axes=dynamic_axes,
                        do_constant_folding=True,
                        opset_version=14,
                        )
        else:
            pass
            raise NotImplementedError
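(Not part of the commit.) Based on the `save_model` method added above, exporting the image branch to ONNX would look roughly like the sketch below. The default output location follows the `'saved/<model_type>'` convention in the code, and the direct `Blip(...)` instantiation assumes blip.py from this repository is importable.

```python
# Sketch: export the image tower added by this commit to ONNX.
from blip import Blip   # assumes blip.py from this repo is on the path

op = Blip(model_name='blip_itm_base', modality='image')
op.save_model(model_type='onnx')   # with output_file='default', this should write
                                   # <repo>/saved/onnx/blip_itm_base.onnx
```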

requirements.txt (2 changed lines)

@@ -3,4 +3,4 @@ torchvision>=0.10.0
Pillow
towhee
timm
transformers>=4.15.0
transformers>=4.26.0
