@@ -1,2 +1,121 @@
# Video-Text Retrieval Embedding with MDMMT

*author: Chen Zhang*

<br />

## Description

This operator extracts features for video or text with [MDMMT: Multidomain Multimodal Transformer for Video Retrieval](https://arxiv.org/pdf/2103.10699.pdf), which generates embeddings for video and text by jointly training a video encoder and a text encoder to maximize their cosine similarity.

<br />
## Code Example

Load video embeddings extracted from different upstream expert networks, such as video, RGB, and audio.

Read the text to generate a text embedding.

*Write the pipeline code*:
```python
import towhee
import torch

torch.manual_seed(42)

# features are embeddings extracted from the upstream expert models.
features = {
    "VIDEO": torch.rand(30, 2048),
    "CLIP": torch.rand(30, 512),
    "tf_vggish": torch.rand(30, 128),
}

# features_t holds the timestamps of the features, usually uniformly sampled.
features_t = {
    "VIDEO": torch.linspace(1, 30, steps=30),
    "CLIP": torch.linspace(1, 30, steps=30),
    "tf_vggish": torch.linspace(1, 30, steps=30),
}

# features_ind is the mask of the features (1 = valid, 0 = padding).
features_ind = {
    "VIDEO": torch.as_tensor([1] * 25 + [0] * 5),
    "CLIP": torch.as_tensor([1] * 25 + [0] * 5),
    "tf_vggish": torch.as_tensor([1] * 25 + [0] * 5),
}

video_input_dict = {"features": features, "features_t": features_t, "features_ind": features_ind}

towhee.dc([video_input_dict]).video_text_embedding.mdmmt(modality='video', device='cpu').show()

towhee.dc(['Hello world.']).video_text_embedding.mdmmt(modality='text', device='cpu').show()
```
![](vect_simplified_video.png)
![](vect_simplified_text.png)
*Write the same pipeline with explicit input/output name specifications:*
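Below is a minimal sketch of that named-column variant, assuming Towhee's bracketed DataCollection syntax (`dc['col']`, `operator['in', 'out']`, `select['col']`); the column names `vid`, `text`, and `vec` are illustrative, and `video_input_dict` is the dict built in the snippet above:

```python
import towhee

# Video side: read the input dict from column 'vid', write the embedding to column 'vec'.
towhee.dc['vid']([video_input_dict]) \
    .video_text_embedding.mdmmt['vid', 'vec'](modality='video', device='cpu') \
    .select['vec']() \
    .show()

# Text side: read the sentence from column 'text', write the embedding to column 'vec'.
towhee.dc['text'](['Hello world.']) \
    .video_text_embedding.mdmmt['text', 'vec'](modality='text', device='cpu') \
    .select['text', 'vec']() \
    .show()
```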
<br />
## Factory Constructor

Create the operator via the following factory method:

***mdmmt(modality: str)***

**Parameters:**

***modality:*** *str*

Which modality (*video* or *text*) is used to generate the embedding.

***weight_path:*** *Optional[str]*

Path to the pretrained model weights.

***device:*** *Optional[str]*

The device to run the model on: 'cpu' or 'cuda'.

***mmtvid_params:*** *Optional[dict]*

MMTVID model parameters for a custom video model.

***mmttxt_params:*** *Optional[dict]*

MMTTXT model parameters for a custom text model.
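For reference, here is a sketch of constructing the operator directly through `towhee.ops` instead of inside a DataCollection pipeline; the weight path shown is purely illustrative, and leaving `weight_path` unset falls back to the bundled `mdmmt_3mod.pth` checkpoint:

```python
from towhee import ops

# Text-side operator on CPU with the default bundled weights.
text_op = ops.video_text_embedding.mdmmt(modality='text', device='cpu')

# Video-side operator pointing at a custom checkpoint (path is illustrative).
video_op = ops.video_text_embedding.mdmmt(
    modality='video',
    device='cpu',
    weight_path='/path/to/mdmmt_3mod.pth',
)
```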
<br />
## Interface

For the video modality, the operator takes video embeddings extracted from different upstream expert networks, such as video, RGB, and audio.
For the text modality, it reads the text and generates a text embedding.

**Parameters:**

***data:*** *dict* or *str*

The embedding dict extracted from different upstream expert networks, or the text string, depending on the specified modality.

**Returns:** *numpy.ndarray*

The data embedding extracted by the model.
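As a rough usage sketch of this interface (again assuming the `towhee.ops` construction path; the exact embedding length depends on the loaded checkpoint):

```python
import numpy as np
from towhee import ops

# The text modality takes a plain string and returns a flattened 1-D embedding.
op = ops.video_text_embedding.mdmmt(modality='text', device='cpu')
vec = op('Hello world.')

assert isinstance(vec, np.ndarray)
print(vec.shape)
```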
@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .mdmmt import MDMMT


def mdmmt(modality: str, **kwargs):
    return MDMMT(modality, **kwargs)
@@ -0,0 +1,137 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from typing import Dict, Union
from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
from towhee.operator.base import NNOperator
from towhee import register
from pathlib import Path
from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
from transformers import AutoTokenizer

import warnings
warnings.filterwarnings('ignore')


@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.
    """

    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        self.modality = modality
        if weight_path is None:
            # Fall back to the checkpoint bundled with the operator.
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.mmtvid_model = None
        self.mmttxt_model = None
        state = torch.load(weight_path, map_location='cpu')
        if self.modality == 'video':
            if mmtvid_params is None:
                expert_dims = {
                    "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
                    "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
                    "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
                }
                vid_bert_params = {
                    "vocab_size_or_config_json_file": 10,
                    "hidden_size": 512,
                    "num_hidden_layers": 9,
                    "intermediate_size": 3072,
                    "hidden_act": "gelu",
                    "hidden_dropout_prob": 0.2,
                    "attention_probs_dropout_prob": 0.2,
                    "max_position_embeddings": 32,
                    "type_vocab_size": 19,
                    "initializer_range": 0.02,
                    "layer_norm_eps": 1e-12,
                    "num_attention_heads": 8,
                }

                class Struct:
                    """Wrap a plain dict so it can be accessed like a BERT config object."""
                    def __init__(self, **entries):
                        self.__dict__.update(entries)

                config = Struct(**vid_bert_params)
                self.mmtvid_model = MMTVID(
                    expert_dims=expert_dims,
                    same_dim=512,
                    hidden_size=512,
                    vid_bert_config=config
                )
            else:
                self.mmtvid_model = MMTVID(**mmtvid_params)
            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
            self.mmtvid_model.to(self.device)
            self.mmtvid_model.eval()
        elif self.modality == 'text':
            if mmttxt_params is None:
                txt_bert_params = {
                    'hidden_dropout_prob': 0.2,
                    'attention_probs_dropout_prob': 0.2,
                }
                self.mmttxt_model = MMTTXT(
                    txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
                    tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
                    max_length=30,
                    modalities=["CLIP", "tf_vggish", "VIDEO"],
                    add_special_tokens=True,
                    add_dot=True,
                    same_dim=512,
                    dout_prob=0.2,
                )
            else:
                self.mmttxt_model = MMTTXT(**mmttxt_params)
            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model.to(self.device)
            self.mmttxt_model.eval()

    def __call__(self, data: Union[Dict, str]):
        if self.modality == 'video':
            # data = {"features": ..., "features_t": ..., "features_ind": ...}
            vec = self._inference_from_video(**data)
        elif self.modality == 'text':
            # data is a plain string
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text: str):
        self.mmttxt_model.eval()
        output = self.mmttxt_model([text])  # (batch_size, embed_dim)
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        self.mmtvid_model.eval()
        output = self.mmtvid_model(
            features=self._preprocess_video_input(features),
            features_t=self._preprocess_video_input(features_t),
            features_ind=self._preprocess_video_input(features_ind),
            features_maxp=None,
        )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        # Add a batch dimension and move each expert tensor to the target device.
        for k, v in data.items():
            data[k] = v.unsqueeze(0).to(self.device)
        return data
@@ -0,0 +1,4 @@
transformers
torch
towhee.models
towhee