diff --git a/README.md b/README.md
index 1a92208..dc22870 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,121 @@
-# mdmmt
+# Video-Text Retrieval Embedding with MDMMT
+
+*author: Chen Zhang*
+
+
+
+
+
+
+## Description
+
+This operator extracts features for video or text with [MDMMT: Multidomain Multimodal Transformer for Video Retrieval](https://arxiv.org/pdf/2103.10699.pdf), which generates embeddings for video and text by jointly training a video encoder and a text encoder to maximize the cosine similarity between matching video-text pairs.
+
+
+
+
+
+## Code Example
+
+Load video embeddings extracted from different upstream expert networks, such as video, RGB and audio experts, to generate a video embedding.
+
+Read a piece of text to generate a text embedding.
+
+*Write the pipeline code*:
+
+```python
+import towhee
+import torch
+
+torch.manual_seed(42)
+
+# features are embeddings extracted from the upstream models.
+features = {
+ "VIDEO": torch.rand(30, 2048),
+ "CLIP": torch.rand(30, 512),
+ "tf_vggish": torch.rand(30, 128),
+}
+
+# features_t is the time series of the features, usually uniformly sampled.
+features_t = {
+ "VIDEO": torch.linspace(1, 30, steps=30),
+ "CLIP": torch.linspace(1, 30, steps=30),
+ "tf_vggish": torch.linspace(1, 30, steps=30),
+}
+
+# features_ind is the mask of the features (1 marks a valid timestep, 0 marks padding).
+features_ind = {
+ "VIDEO": torch.as_tensor([1] * 25 + [0] * 5),
+ "CLIP": torch.as_tensor([1] * 25 + [0] * 5),
+ "tf_vggish": torch.as_tensor([1] * 25 + [0] * 5),
+}
+
+video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}
+
+towhee.dc([video_input_dict]).video_text_embedding.mdmmt(modality='video', device='cpu').show()
+
+towhee.dc(['Hello world.']).video_text_embedding.mdmmt(modality='text', device='cpu').show()
+```
+![](vect_simplified_video.png)
+![](vect_simplified_text.png)
+
+*Write the same pipeline with explicitly specified input and output names:*
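+
+Below is a minimal sketch of the same pipeline written with Towhee's named-schema DataCollection syntax. The `['video', 'vec']` and `['text', 'vec']` column names are illustrative, and the exact named-schema syntax may vary with your Towhee version:
+
+```python
+import towhee
+import torch
+
+torch.manual_seed(42)
+
+# Dummy upstream features, identical in shape to the example above.
+features = {
+    "VIDEO": torch.rand(30, 2048),
+    "CLIP": torch.rand(30, 512),
+    "tf_vggish": torch.rand(30, 128),
+}
+features_t = {k: torch.linspace(1, 30, steps=30) for k in features}
+features_ind = {k: torch.as_tensor([1] * 25 + [0] * 5) for k in features}
+video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}
+
+towhee.dc['video']([video_input_dict]) \
+    .video_text_embedding.mdmmt['video', 'vec'](modality='video', device='cpu') \
+    .select['vec']() \
+    .show()
+
+towhee.dc['text'](['Hello world.']) \
+    .video_text_embedding.mdmmt['text', 'vec'](modality='text', device='cpu') \
+    .select['text', 'vec']() \
+    .show()
+```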
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***mdmmt(modality: str, weight_path: str = None, device: str = None, mmtvid_params: dict = None, mmttxt_params: dict = None)***
+
+**Parameters:**
+
+ ***modality:*** *str*
+
+ Which modality (*video* or *text*) is used to generate the embedding.
+
+ ***weight_path:*** *Optional[str]*
+
+ Path to the pretrained model weights. If not given, it defaults to `mdmmt_3mod.pth` in the operator directory.
+
+ ***device:*** *Optional[str]*
+
+ The device to run the model on, `'cpu'` or `'cuda'`. If not given, `'cuda'` is used when available, otherwise `'cpu'`.
+
+ ***mmtvid_params:*** *Optional[dict]*
+
+ Parameters for a custom MMTVID (video) model. If not given, the default MDMMT video configuration is used.
+
+ ***mmttxt_params:*** *Optional[dict]*
+
+ Parameters for a custom MMTTXT (text) model. If not given, the default MDMMT text configuration is used.
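+
+For reference, here is a minimal sketch of constructing the operator directly from this repository rather than through a pipeline. It assumes the pretrained weight file `mdmmt_3mod.pth` (the default path used by `mdmmt.py`) has been placed next to the operator and that the `bert-base-cased` weights can be downloaded:
+
+```python
+from mdmmt import MDMMT
+
+# Build the text-side encoder on CPU; weight_path falls back to
+# 'mdmmt_3mod.pth' in the operator directory.
+text_op = MDMMT(modality='text', device='cpu')
+
+vec = text_op('Hello world.')  # 1-D numpy.ndarray embedding
+print(vec.shape)
+```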
+
+
+
+
+
+
+## Interface
+
+For the video modality, the operator takes video embeddings extracted from different upstream expert networks, such as video, RGB and audio experts, and generates a video embedding.
+For the text modality, it reads a piece of text and generates a text embedding.
+
+
+**Parameters:**
+
+ ***data:*** *dict* or *str*
+
+ The embedding dict extracted from different upstream expert networks, or a piece of text, depending on the specified modality.
+
+
+
+**Returns:** *numpy.ndarray*
+
+ The data embedding extracted by the model.
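+
+Because the video and text encoders are trained to maximize the cosine similarity of matching pairs, retrieval amounts to ranking the returned vectors by cosine similarity. Below is a minimal NumPy sketch; the vectors and their dimensionality are placeholders for real operator outputs:
+
+```python
+import numpy as np
+
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
+    # Cosine similarity between two 1-D embeddings returned by the operator.
+    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+rng = np.random.default_rng(0)
+text_vec = rng.standard_normal(512)   # stand-in for a text embedding
+video_vec = rng.standard_normal(512)  # stand-in for a video embedding
+print(cosine_similarity(text_vec, video_vec))
+```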
+
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..eeededc
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mdmmt import MDMMT
+
+
+def mdmmt(modality: str, **kwargs):
+ return MDMMT(modality, **kwargs)
+
diff --git a/mdmmt.py b/mdmmt.py
new file mode 100644
index 0000000..ac03b06
--- /dev/null
+++ b/mdmmt.py
@@ -0,0 +1,137 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from typing import Dict, Union
+from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
+from towhee.operator.base import NNOperator
+from towhee import register
+from pathlib import Path
+from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
+from transformers import AutoTokenizer
+
+import warnings
+warnings.filterwarnings('ignore')
+
+
+@register(output_schema=['vec'])
+class MDMMT(NNOperator):
+ """
+ MDMMT multi-modal embedding operator
+ """
+
+ def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
+ mmttxt_params: Dict = None):
+ super().__init__()
+ self.modality = modality
+ if weight_path is None:
+ weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
+ # print('weight_path is None, use default path: {}'.format(weight_path))
+ if device is None:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ self.device = device
+ self.mmtvid_model = None
+ self.mmttxt_model = None
+ state = torch.load(weight_path, map_location='cpu')
+ if self.modality == 'video':
+ if mmtvid_params is None:
+ expert_dims = {
+ "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
+ "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
+ "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
+ }
+ vid_bert_params = {
+ "vocab_size_or_config_json_file": 10,
+ "hidden_size": 512,
+ "num_hidden_layers": 9,
+ "intermediate_size": 3072,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.2,
+ "attention_probs_dropout_prob": 0.2,
+ "max_position_embeddings": 32,
+ "type_vocab_size": 19,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-12,
+ "num_attention_heads": 8,
+ }
+
+ class Struct:
+ def __init__(self, **entries):
+ self.__dict__.update(entries)
+
+ config = Struct(**vid_bert_params)
+ self.mmtvid_model = MMTVID(
+ expert_dims=expert_dims,
+ same_dim=512,
+ hidden_size=512,
+ vid_bert_config=config
+ )
+ else:
+ self.mmtvid_model = MMTVID(**mmtvid_params)
+ self.mmtvid_model.load_state_dict(state['vid_state_dict'])
+            self.mmtvid_model.to(self.device)
+ self.mmtvid_model.eval()
+ elif self.modality == 'text':
+ if mmttxt_params is None:
+ txt_bert_params = {
+ 'hidden_dropout_prob': 0.2,
+ 'attention_probs_dropout_prob': 0.2,
+ }
+ self.mmttxt_model = MMTTXT(
+ txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
+ tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
+ max_length=30,
+ modalities=["CLIP", "tf_vggish", "VIDEO"],
+ add_special_tokens=True,
+ add_dot=True,
+ same_dim=512,
+ dout_prob=0.2,
+ )
+ else:
+ self.mmttxt_model = MMTTXT(**mmttxt_params)
+ self.mmttxt_model.load_state_dict(state['txt_state_dict'])
+            self.mmttxt_model.to(self.device)
+ self.mmttxt_model.eval()
+
+ def __call__(self, data: Union[Dict, str]):
+ if self.modality == 'video':
+            vec = self._inference_from_video(**data)  # data: {"features": ..., "features_t": ..., "features_ind": ...}
+ elif self.modality == 'text':
+ vec = self._inference_from_text(data) # str
+ else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+ return vec
+
+ def _inference_from_text(self, text: str):
+ self.mmttxt_model.eval()
+ output = self.mmttxt_model([text])
+        # output has shape (1, embedding_dim) for a single piece of text
+ return output.detach().flatten().cpu().numpy()
+
+ def _inference_from_video(self, features, features_t, features_ind):
+ self.mmtvid_model.eval()
+ output = self.mmtvid_model(
+ features=self._preprocess_video_input(features),
+ features_t=self._preprocess_video_input(features_t),
+ features_ind=self._preprocess_video_input(features_ind),
+ features_maxp=None,
+ )
+ return output.detach().flatten().cpu().numpy()
+
+ def _preprocess_video_input(self, data: Dict):
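+        # Add a batch dimension to each expert tensor and move it to the target device.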
+ for k, v in data.items():
+ data[k] = v.unsqueeze(0).to(self.device)
+ return data
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1bf57e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+transformers
+torch
+towhee.models
+towhee
\ No newline at end of file
diff --git a/vect_simplified_text.png b/vect_simplified_text.png
new file mode 100644
index 0000000..91004b6
Binary files /dev/null and b/vect_simplified_text.png differ
diff --git a/vect_simplified_video.png b/vect_simplified_video.png
new file mode 100644
index 0000000..056d876
Binary files /dev/null and b/vect_simplified_video.png differ