# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import json
import torch
from pathlib import Path
import numpy as np
from transformers.tokenization_bert import BertTokenizer

from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register

from .utils import Configs, get_gather_index


def arg_process(args):
    """Resolve the image-model paths in ``args`` against this file's directory.

    Args:
        args: config object exposing ``img_checkpoint`` and
            ``img_model_config`` attributes holding paths relative to the
            directory containing this module.

    Returns:
        The same ``args`` object, mutated in place so both attributes hold
        paths rooted at this module's directory.
    """
    dirname = os.path.dirname(__file__)
    args.img_checkpoint = os.path.join(dirname, args.img_checkpoint)
    args.img_model_config = os.path.join(dirname, args.img_model_config)
    return args


@register(output_schema=['vec'])
class LightningDOT(NNOperator):
    """
    LightningDOT multi-modal embedding operator.

    Embeds either an image or a text, selected by ``modality``, with the
    LightningDOT bi-encoder. Images are first run through a Faster R-CNN
    detector whose region boxes and features feed the image tower.

    Args:
        modality (str):
            ``'image'`` or ``'text'``. Any other value makes ``__call__``
            raise ``ValueError``.
    """

    # Single input id fed to the image tower ahead of the region features.
    # Presumably BERT's ``[CLS]`` id for the bert-base-cased vocab; confirm.
    _IMG_INPUT_ID = 101

    def __init__(self, modality: str):
        # The bundled ``dvl`` and ``detector`` packages sit next to this file
        # and are imported as top-level packages, so this directory must be on
        # sys.path. Only add it once, even if several operators are created.
        op_dir = str(Path(__file__).parent)
        if op_dir not in sys.path:
            sys.path.append(op_dir)
        from dvl.models.bi_encoder import BiEncoder
        from detector.faster_rcnn import Net, process_img

        base_dir = os.path.dirname(__file__)
        config_path = os.path.join(base_dir, 'config', 'flickr30k_ft_config.json')
        with open(config_path, encoding='utf-8') as f:
            args = Configs(json.load(f))
        args = arg_process(args)

        self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.bi_encoder.img_model.eval()
        self.bi_encoder.txt_model.eval()

        self.faster_rcnn_preprocess = process_img
        self.faster_rcnn = Net()
        # All inference tensors below are created on the CPU, so map the
        # checkpoint there too; this also allows GPU-saved weights to load.
        weights_path = os.path.join(base_dir, 'data', 'model', 'resnet101_faster_rcnn_final.pth')
        self.faster_rcnn.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.faster_rcnn.eval()
        self.modality = modality

    def img_detfeat_extract(self, img):
        """Run the Faster R-CNN detector on a single image.

        Args:
            img: image array whose ``shape[1]`` / ``shape[0]`` are treated as
                the original width / height. After ``faster_rcnn_preprocess``
                it is transposed with (2, 0, 1), i.e. presumably HWC -> CHW.

        Returns:
            tuple: ``(img_bb, feat, confidence)``. ``img_bb`` is the output
            of ``bbox_feat_process`` with the normalized box area
            (``w * h``) appended as an extra column. ``feat`` and
            ``confidence`` are passed through from the detector unchanged.
        """
        orig_im_scale = [img.shape[1], img.shape[0]]  # (width, height)
        img, im_scale = self.faster_rcnn_preprocess(img)
        # Channel-first with a leading batch dimension of 1.
        img = torch.FloatTensor(np.expand_dims(img.transpose((2, 0, 1)), 0))
        bboxes, feat, confidence = self.faster_rcnn(img, im_scale)
        bboxes = self.bbox_feat_process(bboxes, orig_im_scale)
        img_bb = torch.cat([bboxes, bboxes[:, 4:5] * bboxes[:, 5:]], dim=-1)
        return img_bb, feat, confidence

    def bbox_feat_process(self, bboxes, im_scale):
        """Normalize detector boxes by the original image size.

        Args:
            bboxes: 2-D tensor whose first four columns are read as
                ``(x1, y1, x2, y2)`` in original-image pixels.
            im_scale: ``(image_w, image_h)`` of the original image.

        Returns:
            torch.Tensor: ``(N, 6)`` tensor of
            ``(x1, y1, x1 + w, y1 + h, w, h)``. x/width values are divided
            by ``image_w`` and y/height values by ``image_h``.
        """
        image_w, image_h = im_scale
        scaled_width = (bboxes[:, 2] - bboxes[:, 0]) / image_w
        scaled_height = (bboxes[:, 3] - bboxes[:, 1]) / image_h
        scaled_x = bboxes[:, 0] / image_w
        scaled_y = bboxes[:, 1] / image_h
        return torch.stack(
            (
                scaled_x,
                scaled_y,
                scaled_x + scaled_width,
                scaled_y + scaled_height,
                scaled_width,
                scaled_height,
            ),
            dim=1,
        )

    def get_img_feat(self, data):
        """Extract detector features for ``data`` plus the image-tower input ids.

        Returns:
            tuple: ``(img_feat, img_pos_feat, img_input_ids)``.
            ``img_input_ids`` is a length-1 LongTensor holding
            ``_IMG_INPUT_ID``.
        """
        img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
        img_input_ids = torch.tensor([self._IMG_INPUT_ID], dtype=torch.long)
        return img_feat, img_pos_feat, img_input_ids

    def __call__(self, data):
        """Embed ``data`` according to ``self.modality``.

        Args:
            data: an image when ``modality == 'image'``, or a text string when
                ``modality == 'text'``.

        Returns:
            numpy.ndarray: the embedding, detached and moved to the CPU.

        Raises:
            ValueError: if ``self.modality`` is neither ``'image'`` nor
                ``'text'``.
        """
        if self.modality == 'image':
            inference = self._inference_from_image
        elif self.modality == 'text':
            inference = self._inference_from_text
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        # Pure inference: avoid building an autograd graph.
        with torch.no_grad():
            vec = inference(data)
        return vec.detach().cpu().numpy()

    def _inference_from_text(self, data):
        """Encode ``data`` with the text tower.

        Returns the second output of ``txt_model``, presumably the pooled /
        projected query vector (confirm against ``dvl``).
        """
        token_ids = self.tokenizer.encode(data)
        seq_len = len(token_ids)
        ids = torch.LongTensor(token_ids).unsqueeze(0)  # (1, seq_len)
        # Mask and positions must cover every token. Using len() on the
        # batched ``ids`` would yield the batch size (1) instead of seq_len.
        attn_mask = torch.ones(seq_len, dtype=torch.long).unsqueeze(0)
        pos_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
        _, query_vector, _ = self.bi_encoder.txt_model(ids, None, attn_mask, pos_ids)
        return query_vector

    def _inference_from_image(self, data):
        """Encode ``data`` with the image tower and return ``img_pooled``."""
        # Run the detector once. Its outputs feed both the gather index and
        # the encoder, so they always agree on the number of boxes.
        img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
        img_input_ids = torch.tensor([self._IMG_INPUT_ID], dtype=torch.long)

        bs = 1
        num_bb = img_pos_feat.shape[0]
        # The mask spans the single input id plus one entry per detected box.
        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)
        gather_index = get_gather_index([1] * bs, [num_bb], bs, 1, attn_masks_img.size(0))
        position_ids = torch.arange(0, img_input_ids.size(0), dtype=torch.long).unsqueeze(0)

        fix_txt_encoder = False
        _, img_pooled, _ = self.bi_encoder.get_representation(
            self.bi_encoder.img_model,
            img_input_ids.unsqueeze(0),
            attn_masks_img.unsqueeze(0),
            position_ids,
            img_feat.unsqueeze(0),
            img_pos_feat.unsqueeze(0),
            None,
            gather_index,
            fix_txt_encoder,
        )
        return img_pooled