diff --git a/README.md b/README.md
index 7a3deaf..0811f1e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,97 @@
-# lightningdot
+# Image-Text Retrieval Embdding with LightningDOT
+
+*author: David Wang*
+
+
+
+
+
+
+## Description
+
+This operator extracts features for image or text with [LightningDOT](https://arxiv.org/abs/2103.08784) which can generate embeddings for text and image by jointly training an image encoder and text encoder to maximize the cosine similarity.
+
+
+
+
+
+## Code Example
+
+Load an image from path './teddy.jpg' to generate an image embedding.
+
+Read the text 'A teddybear on a skateboard in Times Square.' to generate an text embedding.
+
+ *Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./teddy.jpg') \
+ .image_decode() \
+ .image_text_embedding.lightningdot(modality='image') \
+ .show()
+
+towhee.dc(["A teddybear on a skateboard in Times Square."]) \
+ .image_text_embedding.lightningdot(modality='text') \
+ .show()
+```
+
+
+
+*Write a same pipeline with explicit inputs/outputs name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./teddy.jpg') \
+ .image_decode['path', 'img']() \
+ .image_text_embedding.lightningdot['img', 'vec'](modality='image') \
+ .select['img', 'vec']() \
+ .show()
+
+towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
+ .image_text_embedding.lightningdot['text','vec'](modality='text') \
+ .select['text', 'vec']() \
+ .show()
+```
+
+
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method
+
+***lightningdot(modality)***
+
+**Parameters:**
+
+ ***modality:*** *str*
+
+ Which modality(*image* or *text*) is used to generate the embedding.
+
+
+
+
+
+## Interface
+
+An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+ The data (image or text based on specified modality) to generate embedding.
+
+
+
+**Returns:** *numpy.ndarray*
+
+ The data embedding extracted by model.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..4e32104
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .lightningdot import LightningDOT
+
+def lightningdot(modality: str):
+ return LightningDOT(modality)
diff --git a/lightningdot.py b/lightningdot.py
new file mode 100644
index 0000000..b314edc
--- /dev/null
+++ b/lightningdot.py
@@ -0,0 +1,146 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import json
+import torch
+from pathlib import Path
+
+import numpy as np
+from transformers.tokenization_bert import BertTokenizer
+
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee.types.arg import arg, to_image_color
+from towhee import register
+
+from .utils import Configs, get_gather_index
+
+
+def arg_process(args):
+ dirname = os.path.dirname(__file__)
+ args.img_checkpoint = dirname + '/' + args.img_checkpoint
+ args.img_model_config = dirname + '/' + args.img_model_config
+ return args
+
+@register(output_schema=['vec'])
+class LightningDOT(NNOperator):
+ """
+ CLIP multi-modal embedding operator
+ """
+ def __init__(self, modality: str):
+ sys.path.append(str(Path(__file__).parent))
+ from dvl.models.bi_encoder import BiEncoder
+ from detector.faster_rcnn import Net, process_img
+
+ full_path = os.path.dirname(__file__) + '/config/flickr30k_ft_config.json'
+ with open(full_path) as fw:
+ content = fw.read()
+ args = json.loads(content)
+ args = Configs(args)
+ args = arg_process(args)
+ self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model
+ img_model.eval()
+ txt_model.eval()
+ self.faster_rcnn_preprocess = process_img
+ self.faster_rcnn = Net()
+ self.faster_rcnn.load_state_dict(torch.load(os.path.dirname(__file__) + '/data/model/resnet101_faster_rcnn_final.pth'))
+ self.faster_rcnn.eval()
+ self.modality = modality
+
+ def img_detfeat_extract(self, img):
+ orig_im_scale = [img.shape[1], img.shape[0]]
+ img, im_scale = self.faster_rcnn_preprocess(img)
+ img = np.expand_dims(img.transpose((2,0,1)), 0)
+ img = torch.FloatTensor(img)
+ bboxes, feat, confidence = self.faster_rcnn(img, im_scale)
+ bboxes = self.bbox_feat_process(bboxes, orig_im_scale)
+ img_bb = torch.cat([bboxes, bboxes[:, 4:5]*bboxes[:, 5:]], dim=-1)
+ return img_bb, feat, confidence
+
+ def bbox_feat_process(self, bboxes, im_scale):
+ image_w, image_h = im_scale
+ box_width = bboxes[:, 2] - bboxes[:, 0]
+ box_height = bboxes[:, 3] - bboxes[:, 1]
+ scaled_width = box_width / image_w
+ scaled_height = box_height / image_h
+ scaled_x = bboxes[:, 0] / image_w
+ scaled_y = bboxes[:, 1] / image_h
+
+ box_width = box_width.unsqueeze(1)
+ box_height = box_height.unsqueeze(1)
+ scaled_width = scaled_width.unsqueeze(1)
+ scaled_height = scaled_height.unsqueeze(1)
+ scaled_x = scaled_x.unsqueeze(1)
+ scaled_y = scaled_y .unsqueeze(1)
+
+ normalized_bbox = torch.hstack((scaled_x, scaled_y,
+ scaled_x + scaled_width,
+ scaled_y + scaled_height,
+ scaled_width, scaled_height))
+
+ return normalized_bbox
+
+ def get_img_feat(self, data):
+ img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
+ num_bb = img_pos_feat.shape[1]
+
+ img_input_ids = torch.Tensor([101]).long()
+
+ return img_feat, img_pos_feat, img_input_ids
+
+ def __call__(self, data):
+ if self.modality == 'image':
+ vec = self._inference_from_image(data)
+ elif self.modality == 'text':
+ vec = self._inference_from_text(data)
+ else:
+ raise ValueError("modality[{}] not implemented.".format(self._modality))
+ return vec.detach().cpu().numpy()
+
+ def _inference_from_text(self, data):
+ ids = self.tokenizer.encode(data)
+ ids = torch.LongTensor(ids).unsqueeze(0)
+ attn_mask = torch.ones(len(ids), dtype=torch.long).unsqueeze(0)
+ pos_ids = torch.arange(len(ids), dtype=torch.long).unsqueeze(0)
+ _, query_vector, _ = self.bi_encoder.txt_model(ids, None, attn_mask, pos_ids)
+ return query_vector
+
+ def _inference_from_image(self, data):
+ img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
+ num_bb = img_pos_feat.shape[0]
+
+ attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
+ bs = 1
+ num_bbs = [num_bb]
+ out_size = attn_masks_img.size(0)
+ gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
+ img_feat, img_pos_feat, img_input_ids = self.get_img_feat(data)
+ fix_txt_encoder = False
+
+ position_ids = torch.arange(0, img_input_ids.size(0), dtype=torch.long).unsqueeze(0)
+ img_input_ids = img_input_ids.unsqueeze(0)
+ attn_masks_img = attn_masks_img.unsqueeze(0)
+ img_feat = img_feat.unsqueeze(0)
+ img_pos_feat = img_pos_feat.unsqueeze(0)
+
+ img_seq, img_pooled, img_hidden = self.bi_encoder.get_representation(self.bi_encoder.img_model, img_input_ids,
+ attn_masks_img, position_ids,
+ img_feat, img_pos_feat,
+ None,
+ gather_index, fix_txt_encoder)
+
+ return img_pooled
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2f7e696
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch>=1.9.0
+torchvision>=0.10.0
+transformers==2.3.0
+Pillow
diff --git a/tabular1.png b/tabular1.png
new file mode 100644
index 0000000..f0a8844
Binary files /dev/null and b/tabular1.png differ
diff --git a/tabular2.png b/tabular2.png
new file mode 100644
index 0000000..9a90fe9
Binary files /dev/null and b/tabular2.png differ
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..8350c97
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,41 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from types import SimpleNamespace
+
+class Configs(SimpleNamespace):
+
+ def __init__(self, dictionary, **kwargs):
+ super().__init__(**kwargs)
+ for key, value in dictionary.items():
+ if isinstance(value, dict):
+ self.__setattr__(key, Configs(value))
+ else:
+ self.__setattr__(key, value)
+
+ def __getattribute__(self, value):
+ try:
+ return super().__getattribute__(value)
+ except AttributeError:
+ return None
+
+def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
+ # assert len(txt_lens) == len(num_bbs) == batch_size
+ gather_index = torch.arange(0, out_size, dtype=torch.long,
+ ).unsqueeze(0).repeat(len(num_bbs), 1)
+
+ # for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
+ # gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
+ # dtype=torch.long).data
+ return gather_index
diff --git a/vec1.png b/vec1.png
new file mode 100644
index 0000000..809949f
Binary files /dev/null and b/vec1.png differ
diff --git a/vec2.png b/vec2.png
new file mode 100644
index 0000000..a2f4169
Binary files /dev/null and b/vec2.png differ