diff --git a/README.md b/README.md
index 7a3deaf..0811f1e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,97 @@
-# lightningdot
+# Image-Text Retrieval Embedding with LightningDOT
+
+*author: David Wang*
+
+
+<br />
+
+
+
+## Description
+
+This operator extracts features for images or text with [LightningDOT](https://arxiv.org/abs/2103.08784), which generates embeddings for text and images by jointly training an image encoder and a text encoder to maximize the cosine similarity of matched image-text pairs.
+
+
+<br />
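+Because both encoders are trained toward the same objective, cross-modal retrieval reduces to ranking candidates by cosine similarity between embeddings. Below is a minimal sketch of that scoring step (the 768-dim size and the `image_vec`/`text_vecs` names are illustrative stand-ins, not part of this operator):
+
+```python
+import numpy as np
+
+def cosine_scores(query, candidates):
+    # Normalize both sides; a dot product then gives cosine similarity per candidate.
+    query = query / np.linalg.norm(query)
+    candidates = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
+    return candidates @ query
+
+# Stand-ins for embeddings produced by the image and text encoders.
+image_vec = np.random.rand(768).astype(np.float32)
+text_vecs = np.random.rand(3, 768).astype(np.float32)
+
+scores = cosine_scores(image_vec, text_vecs)
+best = int(np.argmax(scores))  # index of the best-matching caption
+```
+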
+
+
+## Code Example
+
+Load an image from path './teddy.jpg' to generate an image embedding.
+
+Read the text 'A teddybear on a skateboard in Times Square.' to generate a text embedding.
+
+*Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./teddy.jpg') \
+    .image_decode() \
+    .image_text_embedding.lightningdot(modality='image') \
+    .show()
+
+towhee.dc(["A teddybear on a skateboard in Times Square."]) \
+    .image_text_embedding.lightningdot(modality='text') \
+    .show()
+```
+<img src="./vec1.png" alt="result1" />
+<img src="./vec2.png" alt="result2" />
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./teddy.jpg') \
+    .image_decode['path', 'img']() \
+    .image_text_embedding.lightningdot['img', 'vec'](modality='image') \
+    .select['img', 'vec']() \
+    .show()
+
+towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
+    .image_text_embedding.lightningdot['text', 'vec'](modality='text') \
+    .select['text', 'vec']() \
+    .show()
+```
+<img src="./tabular1.png" alt="result1" />
+<img src="./tabular2.png" alt="result2" />
+
+
+<br />
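+To consume the embeddings programmatically instead of rendering a table, collect the pipeline output into a list (a sketch; it assumes `to_list` from towhee's DataCollection API):
+
+```python
+import towhee
+
+# Each element carries the selected columns; vec is a numpy.ndarray.
+res = towhee.glob['path']('./teddy.jpg') \
+    .image_decode['path', 'img']() \
+    .image_text_embedding.lightningdot['img', 'vec'](modality='image') \
+    .select['vec']() \
+    .to_list()
+```
+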
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***lightningdot(modality)***
+
+**Parameters:**
+
+​ ***modality:*** *str*
+
+​ Which modality (*image* or *text*) is used to generate the embedding.
+
+<br />
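+The operator can also be constructed and called directly (a sketch; it assumes the operator resolves through towhee's `ops` namespace):
+
+```python
+from towhee import ops
+
+# Build a text-side encoder and embed one sentence.
+op = ops.image_text_embedding.lightningdot(modality='text')
+vec = op('A teddybear on a skateboard in Times Square.')
+```
+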
+
+
+
+## Interface
+
+An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generates an embedding in ndarray.
+
+
+**Parameters:**
+
+​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+​ The data (image or text, based on the specified modality) to generate an embedding for.
+
+
+
+**Returns:** *numpy.ndarray*
+
+​ The data embedding extracted by the model.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..4e32104
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .lightningdot import LightningDOT
+
+def lightningdot(modality: str):
+    return LightningDOT(modality)
diff --git a/lightningdot.py b/lightningdot.py
new file mode 100644
index 0000000..b314edc
--- /dev/null
+++ b/lightningdot.py
@@ -0,0 +1,146 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
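+
+# Overview: images are embedded by running a Faster R-CNN detector for region
+# features and feeding them to the bi-encoder's image model; text is tokenized
+# with BERT and fed to the text model. Both vectors share one retrieval space.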
+import sys
+import os
+import json
+from pathlib import Path
+
+import torch
+import numpy as np
+from transformers.tokenization_bert import BertTokenizer
+
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee.types.arg import arg, to_image_color
+from towhee import register
+
+from .utils import Configs, get_gather_index
+
+
+def arg_process(args):
+    # Resolve the checkpoint and model-config paths relative to this file.
+    dirname = os.path.dirname(__file__)
+    args.img_checkpoint = dirname + '/' + args.img_checkpoint
+    args.img_model_config = dirname + '/' + args.img_model_config
+    return args
+
+
+@register(output_schema=['vec'])
+class LightningDOT(NNOperator):
+    """
+    LightningDOT multi-modal embedding operator.
+    """
+    def __init__(self, modality: str):
+        sys.path.append(str(Path(__file__).parent))
+        from dvl.models.bi_encoder import BiEncoder
+        from detector.faster_rcnn import Net, process_img
+
+        full_path = os.path.dirname(__file__) + '/config/flickr30k_ft_config.json'
+        with open(full_path) as f:
+            args = Configs(json.loads(f.read()))
+        args = arg_process(args)
+        self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+        img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model
+        img_model.eval()
+        txt_model.eval()
+        # Faster R-CNN detector used to extract region features from images.
+        self.faster_rcnn_preprocess = process_img
+        self.faster_rcnn = Net()
+        self.faster_rcnn.load_state_dict(torch.load(os.path.dirname(__file__) + '/data/model/resnet101_faster_rcnn_final.pth'))
+        self.faster_rcnn.eval()
+        self.modality = modality
+
+    def img_detfeat_extract(self, img):
+        # Run the detector and return normalized bounding boxes, region
+        # features, and detection confidences.
+        orig_im_scale = [img.shape[1], img.shape[0]]  # (width, height)
+        img, im_scale = self.faster_rcnn_preprocess(img)
+        img = np.expand_dims(img.transpose((2, 0, 1)), 0)  # HWC -> NCHW
+        img = torch.FloatTensor(img)
+        bboxes, feat, confidence = self.faster_rcnn(img, im_scale)
+        bboxes = self.bbox_feat_process(bboxes, orig_im_scale)
+        # Append area (w * h) to build the 7-dim box position feature.
+        img_bb = torch.cat([bboxes, bboxes[:, 4:5] * bboxes[:, 5:]], dim=-1)
+        return img_bb, feat, confidence
+
+    def bbox_feat_process(self, bboxes, im_scale):
+        # Scale absolute (x1, y1, x2, y2) boxes to [0, 1] and append the
+        # normalized box width and height as extra position features.
+        image_w, image_h = im_scale
+        box_width = bboxes[:, 2] - bboxes[:, 0]
+        box_height = bboxes[:, 3] - bboxes[:, 1]
+        scaled_width = box_width / image_w
+        scaled_height = box_height / image_h
+        scaled_x = bboxes[:, 0] / image_w
+        scaled_y = bboxes[:, 1] / image_h
+
+        scaled_width = scaled_width.unsqueeze(1)
+        scaled_height = scaled_height.unsqueeze(1)
+        scaled_x = scaled_x.unsqueeze(1)
+        scaled_y = scaled_y.unsqueeze(1)
+
+        normalized_bbox = torch.hstack((scaled_x, scaled_y,
+                                        scaled_x + scaled_width,
+                                        scaled_y + scaled_height,
+                                        scaled_width, scaled_height))
+
+        return normalized_bbox
+
+    def get_img_feat(self, data):
+        img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
+        # The image stream is prompted with a single [CLS] token (BERT id 101).
+        img_input_ids = torch.Tensor([101]).long()
+        return img_feat, img_pos_feat, img_input_ids
+
+    def __call__(self, data):
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError('modality[{}] not implemented.'.format(self.modality))
+        return vec.detach().cpu().numpy()
+
+    def _inference_from_text(self, data):
+        ids = self.tokenizer.encode(data)
+        # Compute the sequence length before adding the batch dimension, so the
+        # attention mask and position ids cover every token.
+        seq_len = len(ids)
+        ids = torch.LongTensor(ids).unsqueeze(0)
+        attn_mask = torch.ones(seq_len, dtype=torch.long).unsqueeze(0)
+        pos_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+        # txt_model returns a triple; the middle element is the pooled query embedding.
+        _, query_vector, _ = self.bi_encoder.txt_model(ids, None, attn_mask, pos_ids)
+        return query_vector
+
+    def _inference_from_image(self, data):
+        # Extract region features once; get_img_feat also prepends the [CLS] id.
+        img_feat, img_pos_feat, img_input_ids = self.get_img_feat(data)
+        num_bb = img_pos_feat.shape[0]
+
+        # One extra attention slot for the [CLS] token in front of the regions.
+        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)
+        bs = 1
+        num_bbs = [num_bb]
+        out_size = attn_masks_img.size(0)
+        gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, out_size)
+        fix_txt_encoder = False
+
+        position_ids = torch.arange(0, img_input_ids.size(0), dtype=torch.long).unsqueeze(0)
+        img_input_ids = img_input_ids.unsqueeze(0)
+        attn_masks_img = attn_masks_img.unsqueeze(0)
+        img_feat = img_feat.unsqueeze(0)
+        img_pos_feat = img_pos_feat.unsqueeze(0)
+
+        img_seq, img_pooled, img_hidden = self.bi_encoder.get_representation(self.bi_encoder.img_model, img_input_ids,
+                                                                             attn_masks_img, position_ids,
+                                                                             img_feat, img_pos_feat,
+                                                                             None,
+                                                                             gather_index, fix_txt_encoder)
+
+        return img_pooled
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2f7e696
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch>=1.9.0
+torchvision>=0.10.0
+transformers==2.3.0
+Pillow
diff --git a/tabular1.png b/tabular1.png
new file mode 100644
index 0000000..f0a8844
Binary files /dev/null and b/tabular1.png differ
diff --git a/tabular2.png b/tabular2.png
new file mode 100644
index 0000000..9a90fe9
Binary files /dev/null and b/tabular2.png differ
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..8350c97
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,41 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from types import SimpleNamespace
+
+
+class Configs(SimpleNamespace):
+    """Nested attribute access over a JSON config dict; missing keys read as None."""
+
+    def __init__(self, dictionary, **kwargs):
+        super().__init__(**kwargs)
+        for key, value in dictionary.items():
+            if isinstance(value, dict):
+                self.__setattr__(key, Configs(value))
+            else:
+                self.__setattr__(key, value)
+
+    def __getattribute__(self, value):
+        try:
+            return super().__getattribute__(value)
+        except AttributeError:
+            # Unset config entries resolve to None instead of raising.
+            return None
+
+
+def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
+    # Each sample here carries a single [CLS] text token, so the identity
+    # gather index is sufficient; no text/image token reordering is needed.
+    gather_index = torch.arange(0, out_size, dtype=torch.long).unsqueeze(0).repeat(len(num_bbs), 1)
+    return gather_index
diff --git a/vec1.png b/vec1.png
new file mode 100644
index 0000000..809949f
Binary files /dev/null and b/vec1.png differ
diff --git a/vec2.png b/vec2.png
new file mode 100644
index 0000000..a2f4169
Binary files /dev/null and b/vec2.png differ