mdmmt

6 changed files with 281 additions and 1 deletions
@@ -1,2 +1,121 @@
# Video-Text Retrieval Embedding with MDMMT

*author: Chen Zhang*

<br />

## Description

This operator extracts features for video or text with [MDMMT: Multidomain Multimodal Transformer for Video Retrieval](https://arxiv.org/pdf/2103.10699.pdf), which generates embeddings for text and video by jointly training a video encoder and a text encoder to maximize the cosine similarity between matching video-text pairs.

<br />

## Code Example

For the video modality, load video embeddings extracted from different upstream expert networks, such as motion (VIDEO), appearance (CLIP) and audio (tf_vggish) experts.

For the text modality, read the text to generate a text embedding.

*Write the pipeline code*:

```python
import towhee
import torch

torch.manual_seed(42)

# features are embeddings extracted from the upstream expert models.
features = {
    "VIDEO": torch.rand(30, 2048),
    "CLIP": torch.rand(30, 512),
    "tf_vggish": torch.rand(30, 128),
}

# features_t holds the timestamps of the features, usually uniformly sampled.
features_t = {
    "VIDEO": torch.linspace(1, 30, steps=30),
    "CLIP": torch.linspace(1, 30, steps=30),
    "tf_vggish": torch.linspace(1, 30, steps=30),
}

# features_ind is the mask of the features (1 for valid positions, 0 for padding).
features_ind = {
    "VIDEO": torch.as_tensor([1] * 25 + [0] * 5),
    "CLIP": torch.as_tensor([1] * 25 + [0] * 5),
    "tf_vggish": torch.as_tensor([1] * 25 + [0] * 5),
}

video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}

towhee.dc([video_input_dict]).video_text_embedding.mdmmt(modality='video', device='cpu').show()

towhee.dc(['Hello world.']).video_text_embedding.mdmmt(modality='text', device='cpu').show()
```
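Because matching video-text pairs are trained toward high cosine similarity, retrieval simply ranks candidates by the cosine similarity of these vectors. The snippet below is a minimal sketch (assuming the DataCollection `to_list()` accessor) of how to collect the raw `numpy.ndarray` outputs instead of rendering them with `.show()`:

```python
# Sketch: collect the raw numpy vectors from the pipelines above.
video_vec = towhee.dc([video_input_dict]) \
    .video_text_embedding.mdmmt(modality='video', device='cpu') \
    .to_list()[0]

text_vec = towhee.dc(['Hello world.']) \
    .video_text_embedding.mdmmt(modality='text', device='cpu') \
    .to_list()[0]

print(video_vec.shape, text_vec.shape)
```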

*Write the same pipeline with explicit input/output name specifications:*
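A minimal sketch of the named form follows, reusing `video_input_dict` from the block above and assuming the subscript-style `towhee.dc['col']` / `operator['in', 'out']` DataCollection syntax; the column names are illustrative:

```python
import towhee

# Named-column form of the same pipeline (sketch).
towhee.dc['vid']([video_input_dict]) \
      .video_text_embedding.mdmmt['vid', 'vec'](modality='video', device='cpu') \
      .select['vec']() \
      .show()

towhee.dc['text'](['Hello world.']) \
      .video_text_embedding.mdmmt['text', 'vec'](modality='text', device='cpu') \
      .select['text', 'vec']() \
      .show()
```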

<br />

## Factory Constructor

Create the operator via the following factory method:

***mdmmt(modality: str)***

**Parameters:**

***modality:*** *str*

Which modality (*video* or *text*) is used to generate the embedding.

***weight_path:*** *Optional[str]*

Path to the pretrained model weights. Defaults to the `mdmmt_3mod.pth` checkpoint shipped with the operator.

***device:*** *Optional[str]*

Device on which to run the model, 'cpu' or 'cuda'. Defaults to 'cuda' if available, otherwise 'cpu'.

***mmtvid_params:*** *Optional[dict]*

MMTVID model parameters for building a custom video encoder.

***mmttxt_params:*** *Optional[dict]*

MMTTXT model parameters for building a custom text encoder.
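As a quick illustration, the optional arguments are forwarded through the pipeline call to this factory; the sketch below passes a custom checkpoint path (the path itself is hypothetical):

```python
# Sketch: same video pipeline as above, with explicit weight_path and device.
towhee.dc([video_input_dict]).video_text_embedding.mdmmt(
    modality='video',
    weight_path='/path/to/mdmmt_3mod.pth',  # hypothetical local checkpoint path
    device='cuda',
).show()
```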

<br />

## Interface

For the video modality, the operator takes a dict of embeddings extracted from different upstream expert networks (such as motion, appearance and audio experts) and returns a video embedding.
For the text modality, it takes a piece of text and returns a text embedding.

**Parameters:**

***data:*** *dict* or *str*

The dict of embeddings extracted from different upstream expert networks, or the input text, depending on the specified modality.

**Returns:** *numpy.ndarray*

The embedding extracted by the model.
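A rough sketch of a direct call, under the assumption that the `towhee.ops` accessor yields a callable operator for this hub namespace:

```python
from towhee import ops

# Sketch: build the text-modality operator and call it directly.
op = ops.video_text_embedding.mdmmt(modality='text', device='cpu')
vec = op('Hello world.')     # a flattened numpy.ndarray embedding
print(type(vec), vec.shape)
```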
| @ -0,0 +1,20 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | 
 | ||||
|  | from .mdmmt import MDMMT | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def mdmmt(modality: str, **kwargs): | ||||
|  |     return MDMMT(modality, **kwargs) | ||||
|  | 
 | ||||
| @ -0,0 +1,137 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | 
 | ||||
|  | import torch | ||||
|  | 
 | ||||
|  | from typing import Dict, Union | ||||
|  | from towhee.models.mdmmt.mmt import MMTVID, MMTTXT | ||||
|  | from towhee.operator.base import NNOperator | ||||
|  | from towhee import register | ||||
|  | from pathlib import Path | ||||
|  | from transformers.models.bert.modeling_bert import BertModel as TxtBertModel | ||||
|  | from transformers import AutoTokenizer | ||||
|  | 
 | ||||
|  | import warnings | ||||
|  | warnings.filterwarnings('ignore') | ||||
|  | 
 | ||||
|  | 
 | ||||
@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.
    """

    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        self.modality = modality
        if weight_path is None:
            # Default to the checkpoint shipped alongside this operator.
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.mmtvid_model = None
        self.mmttxt_model = None
        state = torch.load(weight_path, map_location='cpu')
        if self.modality == 'video':
            if mmtvid_params is None:
                expert_dims = {
                    "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
                    "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
                    "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
                }
                vid_bert_params = {
                    "vocab_size_or_config_json_file": 10,
                    "hidden_size": 512,
                    "num_hidden_layers": 9,
                    "intermediate_size": 3072,
                    "hidden_act": "gelu",
                    "hidden_dropout_prob": 0.2,
                    "attention_probs_dropout_prob": 0.2,
                    "max_position_embeddings": 32,
                    "type_vocab_size": 19,
                    "initializer_range": 0.02,
                    "layer_norm_eps": 1e-12,
                    "num_attention_heads": 8,
                }

                class Struct:
                    """Lightweight attribute wrapper used as the video BERT config object."""
                    def __init__(self, **entries):
                        self.__dict__.update(entries)

                config = Struct(**vid_bert_params)
                self.mmtvid_model = MMTVID(
                    expert_dims=expert_dims,
                    same_dim=512,
                    hidden_size=512,
                    vid_bert_config=config
                )
            else:
                self.mmtvid_model = MMTVID(**mmtvid_params)
            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
            self.mmtvid_model.to(self.device)
            self.mmtvid_model.eval()
        elif self.modality == 'text':
            if mmttxt_params is None:
                txt_bert_params = {
                    'hidden_dropout_prob': 0.2,
                    'attention_probs_dropout_prob': 0.2,
                }
                self.mmttxt_model = MMTTXT(
                    txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
                    tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
                    max_length=30,
                    modalities=["CLIP", "tf_vggish", "VIDEO"],
                    add_special_tokens=True,
                    add_dot=True,
                    same_dim=512,
                    dout_prob=0.2,
                )
            else:
                self.mmttxt_model = MMTTXT(**mmttxt_params)
            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model.to(self.device)
            self.mmttxt_model.eval()

    def __call__(self, data: Union[Dict, str]):
        if self.modality == 'video':
            # data is a dict: {"features": ..., "features_t": ..., "features_ind": ...}
            vec = self._inference_from_video(**data)
        elif self.modality == 'text':
            # data is a plain string
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text: str):
        self.mmttxt_model.eval()
        output = self.mmttxt_model([text])
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        self.mmtvid_model.eval()
        output = self.mmtvid_model(
            features=self._preprocess_video_input(features),
            features_t=self._preprocess_video_input(features_t),
            features_ind=self._preprocess_video_input(features_ind),
            features_maxp=None,
        )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        # Add a batch dimension and move each expert tensor to the target device.
        for k, v in data.items():
            data[k] = v.unsqueeze(0).to(self.device)
        return data
| @ -0,0 +1,4 @@ | |||||
|  | transformers | ||||
|  | torch | ||||
|  | towhee.models | ||||
|  | towhee | ||||
Two image files added (7.3 KiB and 7.6 KiB).