|
|
|
# Copyright 2021 Zilliz. All rights reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
import os
|
|
|
|
import json
|
|
|
|
import torch
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from transformers.tokenization_bert import BertTokenizer
|
|
|
|
|
|
|
|
from towhee.types.image_utils import to_pil
|
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag
|
|
|
|
from towhee.types.arg import arg, to_image_color
|
|
|
|
from towhee import register
|
|
|
|
|
|
|
|
from .utils import Configs, get_gather_index
|
|
|
|
|
|
|
|
|
|
|
|
def arg_process(args):
    """Anchor ``args.img_model_config`` to the directory containing this module.

    The config path read from JSON is relative to the operator package; this
    prefixes it with ``os.path.dirname(__file__)`` joined by '/'. The ``args``
    object is modified in place and also returned for convenience.
    """
    here = os.path.dirname(__file__)
    args.img_model_config = '/'.join((here, args.img_model_config))
    return args
|
|
|
|
|
|
|
|
@register(output_schema=['vec'])
class LightningDOT(NNOperator):
    """
    LightningDOT multi-modal embedding operator.

    Embeds either an image or a text string with the LightningDOT bi-encoder,
    depending on ``modality``. Image regions are produced by a Faster R-CNN
    detector before being fed to the image branch of the bi-encoder.

    Args:
        model_name (str): one of the keys returned by ``_configs()``
            ('lightningdot_base', 'lightningdot_coco_ft', 'lightningdot_flickr_ft').
            An unknown name raises ``KeyError``.
        modality (str): 'image' or 'text'. Any other value makes ``__call__``
            raise ``ValueError``.
    """

    def __init__(self, model_name: str, modality: str):
        # The vendored model code (dvl, detector, utils) is imported as top-level
        # packages, so the operator directory must be on sys.path. Only add it once
        # so repeated instantiation does not keep growing sys.path.
        op_dir = str(Path(__file__).parent)
        if op_dir not in sys.path:
            sys.path.append(op_dir)
        from dvl.models.bi_encoder import BiEncoder
        from detector.faster_rcnn import Net, process_img
        from utils import download_file

        op_root = os.path.dirname(__file__)
        model_dir = op_root + '/data/model/'
        model_cfg = self._configs()[model_name]

        # 'config' entries start with '/', so plain concatenation yields a path
        # inside the operator directory.
        config_path = op_root + model_cfg['config']
        model_url = model_cfg['weights']
        weight_path = model_dir + os.path.basename(model_url)

        # Download the bi-encoder checkpoint on first use.
        if not os.path.exists(weight_path):
            download_file(model_url, model_dir)

        with open(config_path, encoding='utf-8') as f:
            args = arg_process(Configs(json.load(f)))

        self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self._load_bi_encoder_weights(weight_path)

        self.bi_encoder.img_model.eval()
        self.bi_encoder.txt_model.eval()

        self.faster_rcnn_preprocess = process_img
        self.faster_rcnn = Net()
        # map_location='cpu' (as for the bi-encoder checkpoint) so weights saved
        # from a GPU process still load on a CPU-only host.
        self.faster_rcnn.load_state_dict(
            torch.load(model_dir + 'resnet101_faster_rcnn_final.pth', map_location='cpu'))
        self.faster_rcnn.eval()
        self.modality = modality

    def _load_bi_encoder_weights(self, weight_path):
        """Load the checkpoint at ``weight_path`` into ``self.bi_encoder``.

        If the file contains a ``'model_dict'`` entry, that entry is used as the
        state dict; otherwise the loaded object itself is. When that load raises
        ``RuntimeError``, fall back to treating the file as a pre-trained
        checkpoint: keep only ``'bert.'``-prefixed entries of the loaded object,
        strip the prefix, and load them with ``strict=True``.
        """
        logger = logging.getLogger(__name__)
        checkpoint = torch.load(weight_path, map_location='cpu')
        state_dict = checkpoint['model_dict'] if 'model_dict' in checkpoint else checkpoint
        try:
            self.bi_encoder.load_state_dict(state_dict)
        except RuntimeError:
            logger.info('loading from pre-trained model instead')
            # Build a fresh dict instead of renaming keys in place: popping and
            # re-inserting while walking the keys could drop a renamed entry whose
            # stripped name collides with a later un-prefixed key.
            prefix = 'bert.'
            remapped = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
            self.bi_encoder.load_state_dict(remapped, strict=True)

    def img_detfeat_extract(self, img):
        """Run the Faster R-CNN detector on ``img`` and return region features.

        Args:
            img: image array where ``img.shape[0]`` is the height and
                ``img.shape[1]`` the width (presumably an HxWxC numpy array --
                confirm against ``process_img``).

        Returns:
            tuple: ``(img_bb, feat, confidence)``. ``img_bb`` has one row per
            detected box: the six normalized features from ``bbox_feat_process``
            followed by the normalized box area (width * height). ``feat`` and
            ``confidence`` are passed through from the detector unchanged.
        """
        # (width, height) of the image before preprocessing, used to normalize boxes.
        orig_im_scale = [img.shape[1], img.shape[0]]
        img, im_scale = self.faster_rcnn_preprocess(img)
        # Move the last axis to the front (assumes channels-last output from
        # process_img) and add a batch dimension of 1.
        img = torch.FloatTensor(np.expand_dims(img.transpose((2, 0, 1)), 0))
        bboxes, feat, confidence = self.faster_rcnn(img, im_scale)
        bboxes = self.bbox_feat_process(bboxes, orig_im_scale)
        # Append the normalized area (column 4 * column 5) as a 7th column.
        img_bb = torch.cat([bboxes, bboxes[:, 4:5] * bboxes[:, 5:]], dim=-1)
        return img_bb, feat, confidence

    def bbox_feat_process(self, bboxes, im_scale):
        """Normalize ``(x1, y1, x2, y2, ...)`` boxes by the image size.

        Args:
            bboxes: tensor whose first four columns are x1, y1, x2, y2 in pixels.
            im_scale: ``(image_w, image_h)``.

        Returns:
            Tensor with one row per box and six columns:
            ``x1, y1, x1 + w, y1 + h, w, h``, all divided by the image
            width/height, where ``w = (x2 - x1) / image_w`` and
            ``h = (y2 - y1) / image_h``.
        """
        image_w, image_h = im_scale
        scaled_x = bboxes[:, 0] / image_w
        scaled_y = bboxes[:, 1] / image_h
        scaled_width = (bboxes[:, 2] - bboxes[:, 0]) / image_w
        scaled_height = (bboxes[:, 3] - bboxes[:, 1]) / image_h
        return torch.stack((scaled_x, scaled_y,
                            scaled_x + scaled_width,
                            scaled_y + scaled_height,
                            scaled_width, scaled_height), dim=1)

    def get_img_feat(self, data):
        """Return ``(img_feat, img_pos_feat, img_input_ids)`` for the image branch.

        ``img_pos_feat`` is the 7-column box tensor from ``img_detfeat_extract``
        and ``img_feat`` the detector's region features.
        """
        img_pos_feat, img_feat, _ = self.img_detfeat_extract(data)
        # Single input token for the image branch; 101 is presumably the BERT
        # [CLS] id of the 'bert-base-cased' vocabulary -- confirm if changed.
        img_input_ids = torch.Tensor([101]).long()
        return img_feat, img_pos_feat, img_input_ids

    def __call__(self, data):
        """Embed ``data`` according to ``self.modality``.

        Returns:
            numpy.ndarray: the embedding flattened to 1-D.

        Raises:
            ValueError: if ``self.modality`` is neither 'image' nor 'text'.
        """
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            if self.modality == 'image':
                vec = self._inference_from_image(data)
            elif self.modality == 'text':
                vec = self._inference_from_text(data)
            else:
                raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec.detach().cpu().numpy().flatten()

    def _inference_from_text(self, data):
        """Tokenize ``data`` and return the pooled text vector from the text model."""
        token_ids = self.tokenizer.encode(data)
        # Take the length before adding the batch dimension so the mask and
        # position ids cover every token (shape (1, seq_len), matching ids).
        seq_len = len(token_ids)
        ids = torch.LongTensor(token_ids).unsqueeze(0)
        attn_mask = torch.ones(seq_len, dtype=torch.long).unsqueeze(0)
        pos_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
        _, query_vector, _ = self.bi_encoder.txt_model(ids, None, attn_mask, pos_ids)
        return query_vector

    def _inference_from_image(self, data):
        """Detect regions in ``data`` and return the pooled image vector."""
        # Run the detector once and derive everything from that single result.
        img_feat, img_pos_feat, img_input_ids = self.get_img_feat(data)
        num_bb = img_pos_feat.shape[0]

        bs = 1
        # One mask slot per detected box plus one for the single input token.
        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)
        out_size = attn_masks_img.size(0)
        gather_index = get_gather_index([1] * bs, [num_bb], bs, 1, out_size)
        fix_txt_encoder = False

        position_ids = torch.arange(0, img_input_ids.size(0), dtype=torch.long).unsqueeze(0)

        # Add the batch dimension of 1 to every input.
        img_input_ids = img_input_ids.unsqueeze(0)
        attn_masks_img = attn_masks_img.unsqueeze(0)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        _, img_pooled, _ = self.bi_encoder.get_representation(
            self.bi_encoder.img_model, img_input_ids,
            attn_masks_img, position_ids,
            img_feat, img_pos_feat,
            None,
            gather_index, fix_txt_encoder)
        return img_pooled

    def _configs(self):
        """Return checkpoint URL and config path for each supported model name.

        'config' values are appended to the operator directory, so they begin with '/'.
        """
        return {
            'lightningdot_base': {
                'weights': 'https://convaisharables.blob.core.windows.net/lightningdot/LightningDot.pt',
                'config': '/config/pretrain-alldata-base.json',
            },
            'lightningdot_coco_ft': {
                'weights': 'https://convaisharables.blob.core.windows.net/lightningdot/coco-ft.pt',
                'config': '/config/coco_eval_config.json',
            },
            'lightningdot_flickr_ft': {
                'weights': 'https://convaisharables.blob.core.windows.net/lightningdot/flickr-ft.pt',
                'config': '/config/flickr30k_eval_config.json',
            },
        }
|