diff --git a/README.md b/README.md
index 1a92208..dc22870 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,121 @@
-# mdmmt
+# Video-Text Retrieval Embedding with MDMMT
+
+*author: Chen Zhang*
+
+
+
+
+
+## Description
+
+This operator extracts features from video or text with [MDMMT: Multidomain Multimodal Transformer for Video Retrieval](https://arxiv.org/pdf/2103.10699.pdf), which generates embeddings for text and video by jointly training a video encoder and a text encoder to maximize the cosine similarity between matching video-text pairs.
+
+
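+Because both encoders map into a shared embedding space, video-text retrieval reduces to ranking candidates by cosine similarity. A minimal sketch of the scoring step (the embedding values and the 512-dimension size below are placeholders, not real model outputs):
+
+```python
+import numpy as np
+
+# Placeholder vectors standing in for the operator's video and text embeddings.
+video_emb = np.random.rand(512).astype(np.float32)
+text_emb = np.random.rand(512).astype(np.float32)
+
+# Cosine similarity: larger values mean the text matches the video better.
+score = float(np.dot(video_emb, text_emb) /
+              (np.linalg.norm(video_emb) * np.linalg.norm(text_emb)))
+print(score)
+```
+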
+
+
+## Code Example
+
+Load video embeddings extracted by different upstream expert networks, such as video, RGB, and audio models.
+
+Read the text to generate a text embedding.
+
+*Write the pipeline code*:
+
+```python
+import towhee
+import torch
+
+torch.manual_seed(42)
+
+# features are embeddings extracted from the upstream models.
+features = {
+    "VIDEO": torch.rand(30, 2048),
+    "CLIP": torch.rand(30, 512),
+    "tf_vggish": torch.rand(30, 128),
+}
+
+# features_t is the time series of the features, usually uniformly sampled.
+features_t = {
+    "VIDEO": torch.linspace(1, 30, steps=30),
+    "CLIP": torch.linspace(1, 30, steps=30),
+    "tf_vggish": torch.linspace(1, 30, steps=30),
+}
+
+# features_ind is the mask of the features.
+features_ind = {
+    "VIDEO": torch.as_tensor([1] * 25 + [0] * 5),
+    "CLIP": torch.as_tensor([1] * 25 + [0] * 5),
+    "tf_vggish": torch.as_tensor([1] * 25 + [0] * 5),
+}
+
+video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}
+
+towhee.dc([video_input_dict]).video_text_embedding.mdmmt(modality='video', device='cpu').show()
+
+towhee.dc(['Hello world.']).video_text_embedding.mdmmt(modality='text', device='cpu').show()
+```
+![](vect_simplified_video.png)
+![](vect_simplified_text.png)
+
+*Write the same pipeline with explicit input/output name specifications:*
+
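+One possible version, using towhee's bracketed DataCollection syntax (a sketch reusing `video_input_dict` from the block above; the column names `'video'`, `'text'` and `'vec'` are illustrative, not required by the operator):
+
+```python
+import towhee
+
+towhee.dc['video']([video_input_dict]) \
+      .video_text_embedding.mdmmt['video', 'vec'](modality='video', device='cpu') \
+      .select['vec']() \
+      .show()
+
+towhee.dc['text'](['Hello world.']) \
+      .video_text_embedding.mdmmt['text', 'vec'](modality='text', device='cpu') \
+      .select['text', 'vec']() \
+      .show()
+```
+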
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***mdmmt(modality: str, weight_path: str = None, device: str = None, mmtvid_params: dict = None, mmttxt_params: dict = None)***
+
+**Parameters:**
+
+***modality:*** *str*
+
+Which modality (*video* or *text*) is used to generate the embedding.
+
+***weight_path:*** *Optional[str]*
+
+Path to the pretrained model weights. Defaults to `mdmmt_3mod.pth` next to the operator file.
+
+***device:*** *Optional[str]*
+
+Device on which to run the model, `'cpu'` or `'cuda'`. Defaults to CUDA if available, otherwise CPU.
+
+***mmtvid_params:*** *Optional[dict]*
+
+Parameters for a custom MMTVID (video) model; the built-in defaults are used when omitted.
+
+***mmttxt_params:*** *Optional[dict]*
+
+Parameters for a custom MMTTXT (text) model; the built-in defaults are used when omitted.
+
+
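+For reference, the operator class added in this PR can also be constructed directly (a sketch; it assumes you run it from the operator's directory and that the default `mdmmt_3mod.pth` weight file is present next to `mdmmt.py`):
+
+```python
+from mdmmt import MDMMT
+
+# Build the text-side operator on CPU and embed a sentence.
+op = MDMMT(modality='text', device='cpu')
+vec = op('Hello world.')  # numpy.ndarray text embedding
+print(vec.shape)
+```
+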
+
+
+
+## Interface
+
+For the video modality, load video embeddings extracted by different upstream expert networks, such as video, RGB, and audio models.
+For the text modality, read the text to generate a text embedding.
+
+
+**Parameters:**
+
+***data:*** *dict* or *str*
+
+The embedding dict extracted from different upstream expert networks, or a text string, depending on the specified modality.
+
+
+
+**Returns:** *numpy.ndarray*
+
+The data embedding extracted by the model.
+
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..eeededc
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mdmmt import MDMMT
+
+
+def mdmmt(modality: str, **kwargs):
+    return MDMMT(modality, **kwargs)
+
diff --git a/mdmmt.py b/mdmmt.py
new file mode 100644
index 0000000..ac03b06
--- /dev/null
+++ b/mdmmt.py
@@ -0,0 +1,137 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from typing import Dict, Union
+from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
+from towhee.operator.base import NNOperator
+from towhee import register
+from pathlib import Path
+from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
+from transformers import AutoTokenizer
+
+import warnings
+warnings.filterwarnings('ignore')
+
+
+@register(output_schema=['vec'])
+class MDMMT(NNOperator):
+    """
+    MDMMT multi-modal embedding operator.
+    """
+
+    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
+                 mmttxt_params: Dict = None):
+        super().__init__()
+        self.modality = modality
+        if weight_path is None:
+            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
+            # print('weight_path is None, use default path: {}'.format(weight_path))
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+        self.mmtvid_model = None
+        self.mmttxt_model = None
+        # The checkpoint holds both video ('vid_state_dict') and text ('txt_state_dict') weights.
+        state = torch.load(weight_path, map_location='cpu')
+        if self.modality == 'video':
+            if mmtvid_params is None:
+                expert_dims = {
+                    "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
+                    "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
+                    "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
+                }
+                vid_bert_params = {
+                    "vocab_size_or_config_json_file": 10,
+                    "hidden_size": 512,
+                    "num_hidden_layers": 9,
+                    "intermediate_size": 3072,
+                    "hidden_act": "gelu",
+                    "hidden_dropout_prob": 0.2,
+                    "attention_probs_dropout_prob": 0.2,
+                    "max_position_embeddings": 32,
+                    "type_vocab_size": 19,
+                    "initializer_range": 0.02,
+                    "layer_norm_eps": 1e-12,
+                    "num_attention_heads": 8,
+                }
+
+                class Struct:
+                    def __init__(self, **entries):
+                        self.__dict__.update(entries)
+
+                config = Struct(**vid_bert_params)
+                self.mmtvid_model = MMTVID(
+                    expert_dims=expert_dims,
+                    same_dim=512,
+                    hidden_size=512,
+                    vid_bert_config=config
+                )
+            else:
+                self.mmtvid_model = MMTVID(**mmtvid_params)
+            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
+            self.mmtvid_model.to(self.device)
+            self.mmtvid_model.eval()
+        elif self.modality == 'text':
+            if mmttxt_params is None:
+                txt_bert_params = {
+                    'hidden_dropout_prob': 0.2,
+                    'attention_probs_dropout_prob': 0.2,
+                }
+                self.mmttxt_model = MMTTXT(
+                    txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
+                    tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
+                    max_length=30,
+                    modalities=["CLIP", "tf_vggish", "VIDEO"],
+                    add_special_tokens=True,
+                    add_dot=True,
+                    same_dim=512,
+                    dout_prob=0.2,
+                )
+            else:
+                self.mmttxt_model = MMTTXT(**mmttxt_params)
+            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
+            self.mmttxt_model.to(self.device)
+            self.mmttxt_model.eval()
+
+    def __call__(self, data: Union[Dict, str]):
+        if self.modality == 'video':
+            # data: {"features": ..., "features_t": ..., "features_ind": ...}
+            vec = self._inference_from_video(**data)
+        elif self.modality == 'text':
+            # data: str
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality [{}] not implemented.".format(self.modality))
+        return vec
+
+    def _inference_from_text(self, text: str):
+        self.mmttxt_model.eval()
+        output = self.mmttxt_model([text])
+        return output.detach().flatten().cpu().numpy()
+
+    def _inference_from_video(self, features, features_t, features_ind):
+        self.mmtvid_model.eval()
+        output = self.mmtvid_model(
+            features=self._preprocess_video_input(features),
+            features_t=self._preprocess_video_input(features_t),
+            features_ind=self._preprocess_video_input(features_ind),
+            features_maxp=None,
+        )
+        return output.detach().flatten().cpu().numpy()
+
+    def _preprocess_video_input(self, data: Dict):
+        for k, v in data.items():
+            data[k] = v.unsqueeze(0).to(self.device)
+        return data
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1bf57e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+transformers
+torch
+towhee.models
+towhee
\ No newline at end of file
diff --git a/vect_simplified_text.png b/vect_simplified_text.png
new file mode 100644
index 0000000..91004b6
Binary files /dev/null and b/vect_simplified_text.png differ
diff --git a/vect_simplified_video.png b/vect_simplified_video.png
new file mode 100644
index 0000000..056d876
Binary files /dev/null and b/vect_simplified_video.png differ