# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from pathlib import Path
from typing import Dict, Union

import torch
from transformers import AutoTokenizer
from transformers.models.bert.modeling_bert import BertModel as TxtBertModel

from towhee import register
from towhee.models.mdmmt.mmt import MMTVID, MMTTXT
from towhee.operator.base import NNOperator

warnings.filterwarnings('ignore')


@register(output_schema=['vec'])
class MDMMT(NNOperator):
    """
    MDMMT multi-modal embedding operator.

    Args:
        modality (`str`):
            Either 'video' or 'text'; selects which branch of the model is built.
        weight_path (`str`):
            Path to the pretrained checkpoint; defaults to 'mdmmt_3mod.pth' next to this file.
        device (`str`):
            Device to run on; defaults to 'cuda' when available, otherwise 'cpu'.
        mmtvid_params (`Dict`):
            Optional keyword arguments passed to MMTVID; built-in defaults are used when None.
        mmttxt_params (`Dict`):
            Optional keyword arguments passed to MMTTXT; built-in defaults are used when None.
    """
    def __init__(self,
                 modality: str,
                 weight_path: str = None,
                 device: str = None,
                 mmtvid_params: Dict = None,
                 mmttxt_params: Dict = None):
        super().__init__()
        self.modality = modality
        if weight_path is None:
            weight_path = str(Path(__file__).parent / 'mdmmt_3mod.pth')
            # print('weight_path is None, use default path: {}'.format(weight_path))
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.mmtvid_model = None
        self.mmttxt_model = None
        state = torch.load(weight_path, map_location='cpu')

        if self.modality == 'video':
            if mmtvid_params is None:
                expert_dims = {
                    "VIDEO": {"dim": 2048, "idx": 1, "max_tok": 30},
                    "CLIP": {"dim": 512, "idx": 2, "max_tok": 30},
                    "tf_vggish": {"dim": 128, "idx": 3, "max_tok": 30},
                }
                vid_bert_params = {
                    "vocab_size_or_config_json_file": 10,
                    "hidden_size": 512,
                    "num_hidden_layers": 9,
                    "intermediate_size": 3072,
                    "hidden_act": "gelu",
                    "hidden_dropout_prob": 0.2,
                    "attention_probs_dropout_prob": 0.2,
                    "max_position_embeddings": 32,
                    "type_vocab_size": 19,
                    "initializer_range": 0.02,
                    "layer_norm_eps": 1e-12,
                    "num_attention_heads": 8,
                }

                class Struct:
                    """Wrap a dict so its keys are accessible as attributes (BERT-config style)."""
                    def __init__(self, **entries):
                        self.__dict__.update(entries)

                config = Struct(**vid_bert_params)
                self.mmtvid_model = MMTVID(
                    expert_dims=expert_dims,
                    same_dim=512,
                    hidden_size=512,
                    vid_bert_config=config
                )
            else:
                self.mmtvid_model = MMTVID(**mmtvid_params)
            self.mmtvid_model.load_state_dict(state['vid_state_dict'])
            self.mmtvid_model.to(self.device)
            self.mmtvid_model.eval()
        elif self.modality == 'text':
            if mmttxt_params is None:
                txt_bert_params = {
                    'hidden_dropout_prob': 0.2,
                    'attention_probs_dropout_prob': 0.2,
                }
                self.mmttxt_model = MMTTXT(
                    txt_bert=TxtBertModel.from_pretrained('bert-base-cased', **txt_bert_params),
                    tokenizer=AutoTokenizer.from_pretrained('bert-base-cased'),
                    max_length=30,
                    modalities=["CLIP", "tf_vggish", "VIDEO"],
                    add_special_tokens=True,
                    add_dot=True,
                    same_dim=512,
                    dout_prob=0.2,
                )
            else:
                self.mmttxt_model = MMTTXT(**mmttxt_params)
            self.mmttxt_model.load_state_dict(state['txt_state_dict'])
            self.mmttxt_model.to(self.device)
            self.mmttxt_model.eval()

    def __call__(self, data: Union[Dict, str]):
        if self.modality == 'video':
            # data is a dict with keys 'features', 'features_t' and 'features_ind'
            vec = self._inference_from_video(**data)
        elif self.modality == 'text':
            # data is a plain query string
            vec = self._inference_from_text(data)
        else:
            raise ValueError('modality [{}] not implemented.'.format(self.modality))
        return vec

    def _inference_from_text(self, text: str):
        self.mmttxt_model.eval()
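        # The text branch consumes a batch (list) of strings, so the single query is wrapped in a list.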
        output = self.mmttxt_model([text])  # one embedding row per input string
        return output.detach().flatten().cpu().numpy()

    def _inference_from_video(self, features, features_t, features_ind):
        self.mmtvid_model.eval()
        output = self.mmtvid_model(
            features=self._preprocess_video_input(features),
            features_t=self._preprocess_video_input(features_t),
            features_ind=self._preprocess_video_input(features_ind),
            features_maxp=None,
        )
        return output.detach().flatten().cpu().numpy()

    def _preprocess_video_input(self, data: Dict):
        # Add a batch dimension to every per-expert tensor and move it to the target device.
        for k, v in data.items():
            data[k] = v.unsqueeze(0).to(self.device)
        return data
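
# A minimal usage sketch (illustrative, not part of the operator itself). It assumes the
# default checkpoint 'mdmmt_3mod.pth' sits next to this file, and the video input follows
# the default expert_dims above; the exact semantics of 'features_t' and 'features_ind'
# (timestamps and validity indicators) are assumptions based on how they are preprocessed.
#
#     op = MDMMT(modality='text')
#     vec = op('a man is playing the guitar')   # 1-D numpy embedding
#
#     op = MDMMT(modality='video')
#     vec = op({
#         'features': {'VIDEO': ..., 'CLIP': ..., 'tf_vggish': ...},      # per-expert feature tensors
#         'features_t': {'VIDEO': ..., 'CLIP': ..., 'tf_vggish': ...},    # per-expert timestamps (assumed)
#         'features_ind': {'VIDEO': ..., 'CLIP': ..., 'tf_vggish': ...},  # per-expert validity indicators (assumed)
#     })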