logo
Browse Source

init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
a4ce3d8caf
  1. 97
      README.md
  2. 18
      __init__.py
  3. 146
      lightningdot.py
  4. 4
      requirements.txt
  5. BIN
      tabular1.png
  6. BIN
      tabular2.png
  7. 41
      utils.py
  8. BIN
      vec1.png
  9. BIN
      vec2.png

97
README.md

@ -1,2 +1,97 @@
# lightningdot
# Image-Text Retrieval Embdding with LightningDOT
*author: David Wang*
<br />
## Description
This operator extracts features for image or text with [LightningDOT](https://arxiv.org/abs/2103.08784) which can generate embeddings for text and image by jointly training an image encoder and text encoder to maximize the cosine similarity.
<br />
## Code Example
Load an image from path './teddy.jpg' to generate an image embedding.
Read the text 'A teddybear on a skateboard in Times Square.' to generate an text embedding.
*Write the pipeline in simplified style*:
```python
import towhee
towhee.glob('./teddy.jpg') \
.image_decode() \
.image_text_embedding.lightningdot(modality='image') \
.show()
towhee.dc(["A teddybear on a skateboard in Times Square."]) \
.image_text_embedding.lightningdot(modality='text') \
.show()
```
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/>
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee
towhee.glob['path']('./teddy.jpg') \
.image_decode['path', 'img']() \
.image_text_embedding.lightningdot['img', 'vec'](modality='image') \
.select['img', 'vec']() \
.show()
towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
.image_text_embedding.lightningdot['text','vec'](modality='text') \
.select['text', 'vec']() \
.show()
```
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/tabular1.png" alt="result1" style="height:60px;"/>
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/tabular2.png" alt="result2" style="height:60px;"/>
<br />
## Factory Constructor
Create the operator via the following factory method
***lightningdot(modality)***
**Parameters:**
​ ***modality:*** *str*
​ Which modality(*image* or *text*) is used to generate the embedding.
<br />
## Interface
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
**Parameters:**
​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
​ The data (image or text based on specified modality) to generate embedding.
**Returns:** *numpy.ndarray*
​ The data embedding extracted by model.

18
__init__.py

@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .lightningdot import LightningDOT
def lightningdot(modality: str):
return LightningDOT(modality)

146
lightningdot.py

@ -0,0 +1,146 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import json
import torch
from pathlib import Path
import numpy as np
from transformers.tokenization_bert import BertTokenizer
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from .utils import Configs, get_gather_index
def arg_process(args):
dirname = os.path.dirname(__file__)
args.img_checkpoint = dirname + '/' + args.img_checkpoint
args.img_model_config = dirname + '/' + args.img_model_config
return args
@register(output_schema=['vec'])
class LightningDOT(NNOperator):
"""
CLIP multi-modal embedding operator
"""
def __init__(self, modality: str):
sys.path.append(str(Path(__file__).parent))
from dvl.models.bi_encoder import BiEncoder
from detector.faster_rcnn import Net, process_img
full_path = os.path.dirname(__file__) + '/config/flickr30k_ft_config.json'
with open(full_path) as fw:
content = fw.read()
args = json.loads(content)
args = Configs(args)
args = arg_process(args)
self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model
img_model.eval()
txt_model.eval()
self.faster_rcnn_preprocess = process_img
self.faster_rcnn = Net()
self.faster_rcnn.load_state_dict(torch.load(os.path.dirname(__file__) + '/data/model/resnet101_faster_rcnn_final.pth'))
self.faster_rcnn.eval()
self.modality = modality
def img_detfeat_extract(self, img):
orig_im_scale = [img.shape[1], img.shape[0]]
img, im_scale = self.faster_rcnn_preprocess(img)
img = np.expand_dims(img.transpose((2,0,1)), 0)
img = torch.FloatTensor(img)
bboxes, feat, confidence = self.faster_rcnn(img, im_scale)
bboxes = self.bbox_feat_process(bboxes, orig_im_scale)
img_bb = torch.cat([bboxes, bboxes[:, 4:5]*bboxes[:, 5:]], dim=-1)
return img_bb, feat, confidence
def bbox_feat_process(self, bboxes, im_scale):
image_w, image_h = im_scale
box_width = bboxes[:, 2] - bboxes[:, 0]
box_height = bboxes[:, 3] - bboxes[:, 1]
scaled_width = box_width / image_w
scaled_height = box_height / image_h
scaled_x = bboxes[:, 0] / image_w
scaled_y = bboxes[:, 1] / image_h
box_width = box_width.unsqueeze(1)
box_height = box_height.unsqueeze(1)
scaled_width = scaled_width.unsqueeze(1)
scaled_height = scaled_height.unsqueeze(1)
scaled_x = scaled_x.unsqueeze(1)
scaled_y = scaled_y .unsqueeze(1)
normalized_bbox = torch.hstack((scaled_x, scaled_y,
scaled_x + scaled_width,
scaled_y + scaled_height,
scaled_width, scaled_height))
return normalized_bbox
def get_img_feat(self, data):
img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
num_bb = img_pos_feat.shape[1]
img_input_ids = torch.Tensor([101]).long()
return img_feat, img_pos_feat, img_input_ids
def __call__(self, data):
if self.modality == 'image':
vec = self._inference_from_image(data)
elif self.modality == 'text':
vec = self._inference_from_text(data)
else:
raise ValueError("modality[{}] not implemented.".format(self._modality))
return vec.detach().cpu().numpy()
def _inference_from_text(self, data):
ids = self.tokenizer.encode(data)
ids = torch.LongTensor(ids).unsqueeze(0)
attn_mask = torch.ones(len(ids), dtype=torch.long).unsqueeze(0)
pos_ids = torch.arange(len(ids), dtype=torch.long).unsqueeze(0)
_, query_vector, _ = self.bi_encoder.txt_model(ids, None, attn_mask, pos_ids)
return query_vector
def _inference_from_image(self, data):
img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
num_bb = img_pos_feat.shape[0]
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
bs = 1
num_bbs = [num_bb]
out_size = attn_masks_img.size(0)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
img_feat, img_pos_feat, img_input_ids = self.get_img_feat(data)
fix_txt_encoder = False
position_ids = torch.arange(0, img_input_ids.size(0), dtype=torch.long).unsqueeze(0)
img_input_ids = img_input_ids.unsqueeze(0)
attn_masks_img = attn_masks_img.unsqueeze(0)
img_feat = img_feat.unsqueeze(0)
img_pos_feat = img_pos_feat.unsqueeze(0)
img_seq, img_pooled, img_hidden = self.bi_encoder.get_representation(self.bi_encoder.img_model, img_input_ids,
attn_masks_img, position_ids,
img_feat, img_pos_feat,
None,
gather_index, fix_txt_encoder)
return img_pooled

4
requirements.txt

@ -0,0 +1,4 @@
torch>=1.9.0
torchvision>=0.10.0
transformers==2.3.0
Pillow

BIN
tabular1.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

BIN
tabular2.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

41
utils.py

@ -0,0 +1,41 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from types import SimpleNamespace
class Configs(SimpleNamespace):
def __init__(self, dictionary, **kwargs):
super().__init__(**kwargs)
for key, value in dictionary.items():
if isinstance(value, dict):
self.__setattr__(key, Configs(value))
else:
self.__setattr__(key, value)
def __getattribute__(self, value):
try:
return super().__getattribute__(value)
except AttributeError:
return None
def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
# assert len(txt_lens) == len(num_bbs) == batch_size
gather_index = torch.arange(0, out_size, dtype=torch.long,
).unsqueeze(0).repeat(len(num_bbs), 1)
# for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
# gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
# dtype=torch.long).data
return gather_index

BIN
vec1.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

BIN
vec2.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Loading…
Cancel
Save