logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

138 lines
5.4 KiB

2 years ago
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Dict, Union
from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
from towhee.operator.base import NNOperator
from towhee import register
from pathlib import Path
from transformers.models.bert.modeling_bert import BertModel as TxtBertModel
from transformers import AutoTokenizer
import warnings
# NOTE(review): this runs at import time and ignores *all* Python warnings for the
# whole process (no category or module filter is given), not just warnings raised
# by this operator. Consider scoping it with warnings.catch_warnings() around the
# noisy calls instead.
warnings.filterwarnings('ignore')
@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.

    Depending on ``modality`` the operator builds either the video-side encoder
    (``MMTVID``) or the text-side encoder (``MMTTXT``), loads its weights from a
    checkpoint, and returns one flattened numpy embedding per call.

    Args:
        modality: ``'video'`` or ``'text'``.
        weight_path: Path to the checkpoint. Defaults to ``mdmmt_3mod.pth`` next
            to this file. The checkpoint must contain ``'vid_state_dict'`` (video)
            or ``'txt_state_dict'`` (text).
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        mmtvid_params: Keyword arguments forwarded to ``MMTVID``; when ``None`` a
            built-in three-expert (VIDEO / CLIP / tf_vggish) configuration is used.
        mmttxt_params: Keyword arguments forwarded to ``MMTTXT``; when ``None`` a
            built-in ``bert-base-cased`` configuration is used.

    Raises:
        ValueError: if ``modality`` is neither ``'video'`` nor ``'text'``.
    """

    _SUPPORTED_MODALITIES = ('video', 'text')

    def __init__(self, modality: str, weight_path: str = None, device: str = None, mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        if modality not in self._SUPPORTED_MODALITIES:
            # Fail fast instead of loading a checkpoint for an operator that can never be called.
            raise ValueError("modality[{}] not implemented.".format(modality))
        self.modality = modality
        if weight_path is None:
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        self.mmtvid_model = None
        self.mmttxt_model = None
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
        # only load trusted checkpoints.
        state = torch.load(weight_path, map_location='cpu')
        if self.modality == 'video':
            self.mmtvid_model = self._build_video_model(mmtvid_params)
            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
            # Use the resolved self.device (the raw `device` argument may be None) so the
            # model lives on the same device _preprocess_video_input moves inputs to.
            self.mmtvid_model.to(self.device)
            self.mmtvid_model.eval()
        else:
            self.mmttxt_model = self._build_text_model(mmttxt_params)
            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model.to(self.device)
            self.mmttxt_model.eval()

    @staticmethod
    def _build_video_model(mmtvid_params: Dict = None):
        """Build an MMTVID from ``mmtvid_params``, or from the default configuration when it is None."""
        if mmtvid_params is not None:
            return MMTVID(**mmtvid_params)
        expert_dims = {
            "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
            "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
            "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
        }
        vid_bert_params = {
            "vocab_size_or_config_json_file": 10,
            "hidden_size": 512,
            "num_hidden_layers": 9,
            "intermediate_size": 3072,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.2,
            "attention_probs_dropout_prob": 0.2,
            "max_position_embeddings": 32,
            "type_vocab_size": 19,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-12,
            "num_attention_heads": 8,
        }

        class Struct:
            # Exposes the dict entries as attributes; presumably MMTVID reads
            # vid_bert_config via attribute access (TODO confirm).
            def __init__(self, **entries):
                self.__dict__.update(entries)

        return MMTVID(
            expert_dims=expert_dims,
            same_dim=512,
            hidden_size=512,
            vid_bert_config=Struct(**vid_bert_params),
        )

    @staticmethod
    def _build_text_model(mmttxt_params: Dict = None):
        """Build an MMTTXT from ``mmttxt_params``, or from the default bert-base-cased setup when it is None."""
        if mmttxt_params is not None:
            return MMTTXT(**mmttxt_params)
        txt_bert_params = {
            'hidden_dropout_prob': 0.2,
            'attention_probs_dropout_prob': 0.2,
        }
        return MMTTXT(
            txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
            tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
            max_length=30,
            modalities=["CLIP", "tf_vggish", "VIDEO"],
            add_special_tokens=True,
            add_dot=True,
            same_dim=512,
            dout_prob=0.2,
        )

    def __call__(self, data: Union[Dict, str]):
        """
        Embed a single input.

        Args:
            data: For ``'video'``, a dict with keys ``features``, ``features_t`` and
                ``features_ind``, each itself a dict of tensors (keys presumably expert
                names — TODO confirm). For ``'text'``, a single string.

        Returns:
            The embedding as a flattened numpy array.

        Raises:
            ValueError: if ``self.modality`` is not supported.
        """
        if self.modality == 'video':
            # data = {"features": ..., "features_t": ..., "features_ind": ...}
            vec = self._inference_from_video(**data)
        elif self.modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text: str):
        """Encode one string with the text model and return a flat numpy vector."""
        self.mmttxt_model.eval()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output = self.mmttxt_model([text])
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        """Encode one clip's feature dicts with the video model and return a flat numpy vector."""
        self.mmtvid_model.eval()
        with torch.no_grad():
            output = self.mmtvid_model(
                features=self._preprocess_video_input(features),
                features_t=self._preprocess_video_input(features_t),
                features_ind=self._preprocess_video_input(features_ind),
                features_maxp=None,
            )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        """
        Return a new dict with every tensor given a leading batch dimension of 1
        and moved to ``self.device``.

        The caller's dict is left untouched, so the same features can be passed to
        the operator repeatedly without being unsqueezed more than once.
        """
        return {k: v.unsqueeze(0).to(self.device) for k, v in data.items()}