# Video-Text Retrieval Embedding with MDMMT

*author: Chen Zhang*

<br />

## Description

This operator extracts features for video or text with [MDMMT: Multidomain Multimodal Transformer for Video Retrieval](https://arxiv.org/pdf/2103.10699.pdf), which generates embeddings for text and video by jointly training a video encoder and a text encoder to maximize the cosine similarity between matching video-text pairs.

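Because the two encoders map into a shared space, retrieval comes down to ranking candidates by cosine similarity between their embeddings. The snippet below is a minimal sketch of that ranking step; the embedding dimension and the random vectors are placeholders standing in for real operator outputs.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two embedding vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

# Placeholder embeddings; in practice these come from the mdmmt operator
# (one video embedding and several candidate caption embeddings).
dim = 512  # illustrative only
video_emb = np.random.rand(dim).astype(np.float32)
text_embs = [np.random.rand(dim).astype(np.float32) for _ in range(3)]

# Rank candidate captions by similarity to the video embedding.
scores = [cosine_similarity(video_emb, t) for t in text_embs]
best = int(np.argmax(scores))
print(f"best caption index: {best}, score: {scores[best]:.4f}")
```

In a real pipeline the video side comes from the video-modality operator and the text side from the text-modality operator, both loaded from the same checkpoint.
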
<br />

## Code Example

Load video embeddings extracted from different upstream expert networks, such as video, RGB, and audio experts, to generate a video embedding.

Read the text to generate a text embedding.

*Write the pipeline code*:

```python
import towhee
import torch

torch.manual_seed(42)

# features are embeddings extracted from the upstream expert models.
features = {
    "VIDEO": torch.rand(30, 2048),
    "CLIP": torch.rand(30, 512),
    "tf_vggish": torch.rand(30, 128),
}

# features_t holds the timestamps of the features, usually uniformly sampled.
features_t = {
    "VIDEO": torch.linspace(1, 30, steps=30),
    "CLIP": torch.linspace(1, 30, steps=30),
    "tf_vggish": torch.linspace(1, 30, steps=30),
}

# features_ind is the mask of the features (1 = valid, 0 = padding).
features_ind = {
    "VIDEO": torch.as_tensor([1] * 25 + [0] * 5),
    "CLIP": torch.as_tensor([1] * 25 + [0] * 5),
    "tf_vggish": torch.as_tensor([1] * 25 + [0] * 5),
}

video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}

towhee.dc([video_input_dict]).video_text_embedding.mdmmt(modality='video', device='cpu').show()

towhee.dc(['Hello world.']).video_text_embedding.mdmmt(modality='text', device='cpu').show()
```

*Write the same pipeline with explicit input and output name specifications:*

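The sketch below assumes Towhee's named-schema DataCollection syntax (column names in brackets plus `select`); the column names `'vid'`, `'text'` and `'vec'` are only illustrative.

```python
import towhee
import torch

torch.manual_seed(42)

# Reuse the same placeholder inputs as in the block above.
features = {
    "VIDEO": torch.rand(30, 2048),
    "CLIP": torch.rand(30, 512),
    "tf_vggish": torch.rand(30, 128),
}
features_t = {k: torch.linspace(1, 30, steps=30) for k in features}
features_ind = {k: torch.as_tensor([1] * 25 + [0] * 5) for k in features}
video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}

towhee.dc['vid']([video_input_dict]) \
      .video_text_embedding.mdmmt['vid', 'vec'](modality='video', device='cpu') \
      .select['vec']() \
      .show()

towhee.dc['text'](['Hello world.']) \
      .video_text_embedding.mdmmt['text', 'vec'](modality='text', device='cpu') \
      .select['text', 'vec']() \
      .show()
```

Naming the columns only changes how intermediate results are addressed and displayed; the underlying operator call is the same as in the first example.
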
<br />

## Factory Constructor

Create the operator via the following factory method:

***mdmmt(modality: str)***

**Parameters:**

   ***modality:*** *str*

   Which modality (*video* or *text*) is used to generate the embedding.

   ***weight_path:*** *Optional[str]*

   Path to the pretrained model weights. If not given, the checkpoint bundled with the operator (`mdmmt_3mod.pth`) is used.

   ***device:*** *Optional[str]*

   Device on which to run the model, 'cpu' or 'cuda'. If not given, CUDA is used when available, otherwise CPU.

   ***mmtvid_params:*** *Optional[dict]*

   MMTVID model parameters for building a custom video model.

   ***mmttxt_params:*** *Optional[dict]*

   MMTTXT model parameters for building a custom text model.

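For reference, the operator can also be constructed from the factory and called on its own. The sketch below assumes it resolves through Towhee's `ops` interface under the `video_text_embedding` namespace used in the pipeline examples; the exact resolution call is an assumption, not a documented guarantee.

```python
from towhee import ops

# Assumed direct resolution of the operator through Towhee's `ops` interface;
# the namespace follows the pipeline examples above.
text_op = ops.video_text_embedding.mdmmt(modality='text', device='cpu')

text_vec = text_op('Hello world.')  # numpy.ndarray text embedding
print(text_vec.shape)
```

The same pattern applies for the video modality with `modality='video'` and a feature dict like `video_input_dict` from the code example.
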

<br />

## Interface

For the video modality, load video embeddings extracted from different upstream expert networks, such as video, RGB, and audio experts.
For the text modality, read the text to generate a text embedding.

**Parameters:**

   ***data:*** *dict* or *str*

   The dict of embeddings extracted from different upstream expert networks, or the text string, depending on the specified modality.

**Returns:** *numpy.ndarray*

   The data embedding extracted by the model.

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .mdmmt import MDMMT


def mdmmt(modality: str, **kwargs):
    return MDMMT(modality, **kwargs)

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from typing import Dict, Union
from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
from towhee.operator.base import NNOperator
from towhee import register
from pathlib import Path
from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
from transformers import AutoTokenizer

import warnings
warnings.filterwarnings('ignore')

@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.
    """

    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        self.modality = modality
        if weight_path is None:
            # Default to the checkpoint bundled next to this file.
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.mmtvid_model = None
        self.mmttxt_model = None
        state = torch.load(weight_path, map_location='cpu')
        if self.modality == 'video':
            if mmtvid_params is None:
                expert_dims = {
                    "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
                    "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
                    "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
                }
                vid_bert_params = {
                    "vocab_size_or_config_json_file": 10,
                    "hidden_size": 512,
                    "num_hidden_layers": 9,
                    "intermediate_size": 3072,
                    "hidden_act": "gelu",
                    "hidden_dropout_prob": 0.2,
                    "attention_probs_dropout_prob": 0.2,
                    "max_position_embeddings": 32,
                    "type_vocab_size": 19,
                    "initializer_range": 0.02,
                    "layer_norm_eps": 1e-12,
                    "num_attention_heads": 8,
                }

                class Struct:
                    def __init__(self, **entries):
                        self.__dict__.update(entries)

                config = Struct(**vid_bert_params)
                self.mmtvid_model = MMTVID(
                    expert_dims=expert_dims,
                    same_dim=512,
                    hidden_size=512,
                    vid_bert_config=config
                )
            else:
                self.mmtvid_model = MMTVID(**mmtvid_params)
            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
            self.mmtvid_model.to(self.device)
            self.mmtvid_model.eval()
        elif self.modality == 'text':
            if mmttxt_params is None:
                txt_bert_params = {
                    'hidden_dropout_prob': 0.2,
                    'attention_probs_dropout_prob': 0.2,
                }
                self.mmttxt_model = MMTTXT(
                    txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
                    tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
                    max_length=30,
                    modalities=["CLIP", "tf_vggish", "VIDEO"],
                    add_special_tokens=True,
                    add_dot=True,
                    same_dim=512,
                    dout_prob=0.2,
                )
            else:
                self.mmttxt_model = MMTTXT(**mmttxt_params)
            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model.to(self.device)
            self.mmttxt_model.eval()

    def __call__(self, data: Union[Dict, str]):
        if self.modality == 'video':
            # data is a dict: {"features": ..., "features_t": ..., "features_ind": ...}
            vec = self._inference_from_video(**data)
        elif self.modality == 'text':
            # data is a plain string
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text: str):
        self.mmttxt_model.eval()
        output = self.mmttxt_model([text])
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        self.mmtvid_model.eval()
        output = self.mmtvid_model(
            features=self._preprocess_video_input(features),
            features_t=self._preprocess_video_input(features_t),
            features_ind=self._preprocess_video_input(features_ind),
            features_maxp=None,
        )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        # Add a batch dimension and move each expert tensor to the target device.
        for k, v in data.items():
            data[k] = v.unsqueeze(0).to(self.device)
        return data
transformers
torch
towhee.models
towhee