mdmmt
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
138 lines
5.4 KiB
138 lines
5.4 KiB
2 years ago
|
# Copyright 2021 Zilliz. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from typing import Dict, Union
|
||
|
from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
|
||
|
from towhee.operator.base import NNOperator
|
||
|
from towhee import register
|
||
|
from pathlib import Path
|
||
|
from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
|
||
|
from transformers import AutoTokenizer
|
||
|
|
||
|
import warnings
|
||
|
# Ignore every warning category as soon as this module is imported.
# simplefilter() with no category installs the same match-everything filter
# as filterwarnings('ignore').
# NOTE(review): this is process-wide, so it also hides warnings raised by
# unrelated libraries; consider scoping it with warnings.catch_warnings()
# around the noisy calls instead.
warnings.simplefilter('ignore')
|
||
|
|
||
|
|
||
|
@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.

    Builds either the video encoder (``MMTVID``) or the text encoder (``MMTTXT``),
    selected by ``modality``, loads its weights from a checkpoint and, when called,
    returns the embedding flattened into a numpy array.

    Args:
        modality (str):
            ``'video'`` or ``'text'``; only the matching encoder is built.
        weight_path (str, optional):
            Checkpoint path. The video encoder reads ``state['vid_state_dict']``
            and the text encoder reads ``state['txt_state_dict']``. Defaults to
            ``mdmmt_3mod.pth`` located next to this file.
        device (str, optional):
            Torch device for the model and its inputs. Defaults to ``'cuda'``
            when available, otherwise ``'cpu'``.
        mmtvid_params (Dict, optional):
            Keyword arguments forwarded to ``MMTVID``; a built-in default
            configuration is used when omitted.
        mmttxt_params (Dict, optional):
            Keyword arguments forwarded to ``MMTTXT``; a built-in default
            configuration (``bert-base-cased``) is used when omitted.

    Raises:
        ValueError: if ``modality`` is neither ``'video'`` nor ``'text'``.
    """

    _SUPPORTED_MODALITIES = ('video', 'text')

    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        if modality not in self._SUPPORTED_MODALITIES:
            # Fail fast instead of loading weights, building no model and only
            # failing later when the operator is called.
            raise ValueError('modality[{}] not implemented.'.format(modality))
        self.modality = modality

        if weight_path is None:
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.mmtvid_model = None
        self.mmttxt_model = None

        # NOTE: torch.load unpickles the checkpoint; only load trusted weight files.
        state = torch.load(weight_path, map_location='cpu')

        if self.modality == 'video':
            model = self._build_video_model(mmtvid_params)
            model.load_state_dict(state['vid_state_dict'])
            self.mmtvid_model = model
        else:
            model = self._build_text_model(mmttxt_params)
            model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model = model

        # Use the resolved self.device (never None) so the model lives on the same
        # device that _preprocess_video_input moves the inputs to.
        model.to(self.device)
        model.eval()

    @staticmethod
    def _build_video_model(mmtvid_params: Dict = None):
        """Build ``MMTVID`` from ``mmtvid_params``, or from the default 3-expert configuration."""
        from types import SimpleNamespace  # stdlib attribute bag for the BERT config

        if mmtvid_params is not None:
            return MMTVID(**mmtvid_params)

        expert_dims = {
            'VIDEO': {'dim': 2048, 'idx': 1, 'max_tok': 30},
            'CLIP': {'dim': 512, 'idx': 2, 'max_tok': 30},
            'tf_vggish': {'dim': 128, 'idx': 3, 'max_tok': 30},
        }
        vid_bert_params = {
            'vocab_size_or_config_json_file': 10,
            'hidden_size': 512,
            'num_hidden_layers': 9,
            'intermediate_size': 3072,
            'hidden_act': 'gelu',
            'hidden_dropout_prob': 0.2,
            'attention_probs_dropout_prob': 0.2,
            'max_position_embeddings': 32,
            'type_vocab_size': 19,
            'initializer_range': 0.02,
            'layer_norm_eps': 1e-12,
            'num_attention_heads': 8,
        }
        return MMTVID(
            expert_dims=expert_dims,
            same_dim=512,
            hidden_size=512,
            # MMTVID reads the config through attribute access (config.hidden_size, ...).
            vid_bert_config=SimpleNamespace(**vid_bert_params),
        )

    @staticmethod
    def _build_text_model(mmttxt_params: Dict = None):
        """Build ``MMTTXT`` from ``mmttxt_params``, or from the default ``bert-base-cased`` configuration."""
        if mmttxt_params is not None:
            return MMTTXT(**mmttxt_params)

        txt_bert_params = {
            'hidden_dropout_prob': 0.2,
            'attention_probs_dropout_prob': 0.2,
        }
        return MMTTXT(
            txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
            tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
            max_length=30,
            modalities=['CLIP', 'tf_vggish', 'VIDEO'],
            add_special_tokens=True,
            add_dot=True,
            same_dim=512,
            dout_prob=0.2,
        )

    def __call__(self, data: Union[Dict, str]):
        """
        Embed ``data`` with the encoder chosen at construction time.

        Args:
            data: for ``'video'``, a dict with keys ``features``, ``features_t``
                and ``features_ind`` (each itself a dict of tensors, presumably
                keyed by expert name -- TODO confirm); for ``'text'``, a string.

        Returns:
            The embedding as a flattened numpy array.

        Raises:
            ValueError: if the operator's modality is unsupported.
        """
        if self.modality == 'video':
            return self._inference_from_video(**data)
        if self.modality == 'text':
            return self._inference_from_text(data)
        raise ValueError('modality[{}] not implemented.'.format(self.modality))

    def _inference_from_text(self, text: str):
        """Encode a single string as a batch of one and return the flattened embedding."""
        self.mmttxt_model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            output = self.mmttxt_model([text])
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        """Encode one video's expert features as a batch of one and return the flattened embedding."""
        self.mmtvid_model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            output = self.mmtvid_model(
                features=self._preprocess_video_input(features),
                features_t=self._preprocess_video_input(features_t),
                features_ind=self._preprocess_video_input(features_ind),
                features_maxp=None,
            )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        """
        Return a new dict where every value gets a leading batch dimension of 1 and
        is moved to ``self.device``.

        Values must support ``.unsqueeze`` and ``.to`` (i.e. tensors). The caller's
        dict is not modified, so the same input can safely be passed more than once.
        """
        return {key: value.unsqueeze(0).to(self.device) for key, value in data.items()}
|