"""
|
|
Referring Expression Comprehension dataset
|
|
"""
|
|
import sys
|
|
import json
|
|
import random
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
from toolz.sandbox import unzip
|
|
|
|
from .data import TxtLmdb
|
|
|
|
|
|
class ReImageFeatDir(object):
    """Directory of pre-extracted .npz image feature dumps."""
    def __init__(self, img_dir):
        self.img_dir = img_dir

    def __getitem__(self, file_name):
        img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
        img_feat = torch.tensor(img_dump['features'])
        img_bb = torch.tensor(img_dump['norm_bb'])
        return img_feat, img_bb


class ReDetectFeatDir(object):
    """Directory of detector feature dumps; keeps boxes whose detection
    confidence exceeds conf_th, with the count clamped to [min_bb, max_bb]."""
    def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
                 format_='npz'):
        assert format_ == 'npz', 'only support npz for now.'
        assert isinstance(img_dir, str), 'img_dir is path, not db.'
        self.img_dir = img_dir
        self.conf_th = conf_th
        self.max_bb = max_bb
        self.min_bb = min_bb
        self.num_bb = num_bb

    def _compute_num_bb(self, img_dump):
        num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum())
        num_bb = min(self.max_bb, num_bb)
        return num_bb

    def __getitem__(self, file_name):
        # image input features
        img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
        num_bb = self._compute_num_bb(img_dump)
        img_feat = torch.tensor(img_dump['features'][:num_bb, :])
        img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :])
        return img_feat, img_bb


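# A minimal usage sketch for the feature dirs above (the path and file name
# are placeholders; the .npz keys 'features', 'norm_bb', and 'conf' are the
# ones the classes actually read):
#
#   img_dir = ReDetectFeatDir('/path/to/det_feat_npz')
#   img_feat, img_bb = img_dir['coco_train2014_000000000001.npz']

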
class ReferringExpressionDataset(Dataset):
    def __init__(self, db_dir, img_dir, max_txt_len=60):
        assert isinstance(img_dir, (ReImageFeatDir, ReDetectFeatDir))
        self.img_dir = img_dir

        # load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}]
        refs = json.load(open(f'{db_dir}/refs.json', 'r'))
        self.ref_ids = [ref['ref_id'] for ref in refs]
        self.Refs = {ref['ref_id']: ref for ref in refs}

        # load annotations = [{id, area, bbox, image_id, category_id}]
        anns = json.load(open(f'{db_dir}/annotations.json', 'r'))
        self.Anns = {ann['id']: ann for ann in anns}

        # load categories = [{id, name, supercategory}]
        categories = json.load(open(f'{db_dir}/categories.json', 'r'))
        self.Cats = {cat['id']: cat['name'] for cat in categories}

        # load images = [{id, file_name, ann_ids, height, width}]
        images = json.load(open(f'{db_dir}/images.json', 'r'))
        self.Images = {img['id']: img for img in images}

        # id2len: sent_id -> sent_len
        id2len = json.load(open(f'{db_dir}/id2len.json', 'r'))
        self.id2len = {int(_id): _len for _id, _len in id2len.items()}
        self.max_txt_len = max_txt_len
        self.sent_ids = self._get_sent_ids()

        # db[str(sent_id)] =
        # {sent_id, sent, ref_id, ann_id, image_id,
        #  bbox, input_ids, toked_sent}
        self.db = TxtLmdb(db_dir, readonly=True)

        # meta
        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
        self.cls_ = meta['CLS']
        self.sep = meta['SEP']
        self.mask = meta['MASK']
        self.v_range = meta['v_range']

    def shuffle(self):
        # we shuffle ref_ids and rebuild sent_ids according to ref_ids
        random.shuffle(self.ref_ids)
        self.sent_ids = self._get_sent_ids()

    def _get_sent_ids(self):
        sent_ids = []
        for ref_id in self.ref_ids:
            for sent_id in self.Refs[ref_id]['sent_ids']:
                sent_len = self.id2len[sent_id]
                if self.max_txt_len == -1 or sent_len < self.max_txt_len:
                    sent_ids.append(sent_id)
        return sent_ids

    def _get_img_feat(self, fname):
        img_feat, bb = self.img_dir[fname]
        # norm_bb is laid out as [x1, y1, x2, y2, w, h] (normalized; see the
        # xyxywha comment in the eval dataset below); append the box area
        # w*h as the 7th position feature
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb

    def __len__(self):
        return len(self.sent_ids)

    def __getitem__(self, i):
        """
        Return:
        :input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
        :position_ids : range(L)
        :img_feat     : (num_bb, d)
        :img_pos_feat : (num_bb, 7)
        :attn_masks   : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        :obj_masks    : (num_bb, ) all 0's
        :target       : (1, )
        """
        # db entry: {sent_id, sent, ref_id, ann_id, image_id,
        #            bbox, input_ids, toked_sent}
        sent_id = self.sent_ids[i]
        txt_dump = self.db[str(sent_id)]
        image_id = txt_dump['image_id']
        fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # text input
        input_ids = txt_dump['input_ids']
        input_ids = [self.cls_] + input_ids + [self.sep]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)

        # target bbox
        img = self.Images[image_id]
        assert len(img['ann_ids']) == num_bb, \
            'Please use visual_grounding_coco_gt'
        target = img['ann_ids'].index(txt_dump['ann_id'])
        target = torch.tensor([target])

        # obj_masks, to be padded with 1, for masking out non-object prob.
        obj_masks = torch.tensor([0]*len(img['ann_ids'])).bool()

        return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
                obj_masks, target)


def re_collate(inputs):
    """
    Return:
    :input_ids    : (n, max_L) padded with 0
    :position_ids : (n, max_L) padded with 0
    :txt_lens     : list of [txt_len]
    :img_feat     : (n, max_num_bb, feat_dim)
    :img_pos_feat : (n, max_num_bb, 7)
    :num_bbs      : list of [num_bb]
    :attn_masks   : (n, max_{L+num_bb}) padded with 0
    :obj_masks    : (n, max_num_bb) padded with 1
    :targets      : (n, )
    """
    (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
     targets) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    obj_masks = pad_sequence(obj_masks,
                             batch_first=True, padding_value=1).bool()

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, obj_masks, targets)


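# A minimal training wiring sketch (the db/feature paths and batch size are
# placeholders, not part of this module):
#
#   from torch.utils.data import DataLoader
#   img_dir = ReImageFeatDir('/path/to/visual_grounding_coco_gt')
#   train_set = ReferringExpressionDataset('/path/to/refcoco_db', img_dir)
#   train_loader = DataLoader(train_set, batch_size=32,
#                             collate_fn=re_collate)
#   train_set.shuffle()  # reshuffle ref_ids/sent_ids between epochs

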
class ReferringExpressionEvalDataset(ReferringExpressionDataset):
    def __getitem__(self, i):
        """
        Return:
        :input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
        :position_ids : range(L)
        :img_feat     : (num_bb, d)
        :img_pos_feat : (num_bb, 7)
        :attn_masks   : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        :obj_masks    : (num_bb, ) all 0's
        :tgt_box      : ndarray (4, ) xywh
        :obj_boxes    : ndarray (num_bb, 4) xywh
        :sent_id
        """
        # db entry: {sent_id, sent, ref_id, ann_id, image_id,
        #            bbox, input_ids, toked_sent}
        sent_id = self.sent_ids[i]
        txt_dump = self.db[str(sent_id)]
        image_id = txt_dump['image_id']
        if isinstance(self.img_dir, ReImageFeatDir):
            if '_gt' in self.img_dir.img_dir:
                fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
            elif '_det' in self.img_dir.img_dir:
                fname = f'visual_grounding_det_coco_{int(image_id):012}.npz'
            else:
                # guard against fname being left undefined
                sys.exit('%s not supported.' % self.img_dir.img_dir)
        elif isinstance(self.img_dir, ReDetectFeatDir):
            fname = f'coco_train2014_{int(image_id):012}.npz'
        else:
            sys.exit('%s not supported.' % self.img_dir)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # image info
        img = self.Images[image_id]
        im_width, im_height = img['width'], img['height']

        # object boxes, img_pos_feat (xyxywha) -> xywh
        obj_boxes = np.stack([img_pos_feat[:, 0]*im_width,
                              img_pos_feat[:, 1]*im_height,
                              img_pos_feat[:, 4]*im_width,
                              img_pos_feat[:, 5]*im_height], axis=1)
        obj_masks = torch.tensor([0]*num_bb).bool()

        # target box
        tgt_box = np.array(txt_dump['bbox'])  # xywh

        # text input
        input_ids = txt_dump['input_ids']
        input_ids = [self.cls_] + input_ids + [self.sep]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)

        return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
                obj_masks, tgt_box, obj_boxes, sent_id)

    # IoU function
    def computeIoU(self, box1, box2):
        # each box is of [x1, y1, w, h]
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
        inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)

        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            inter = 0
        union = box1[2]*box1[3] + box2[2]*box2[3] - inter
        return float(inter)/union

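    # Worked example of the inclusive-pixel convention above (illustrative
    # values, not from the dataset): for box1 = [0, 0, 10, 10] and
    # box2 = [5, 5, 10, 10], the boxes overlap on pixels 5..9 along both
    # axes, so inter = 5*5 = 25, union = 100 + 100 - 25 = 175, and
    # IoU = 25/175 ≈ 0.143.

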
def re_eval_collate(inputs):
    """
    Return:
    :input_ids    : (n, max_L)
    :position_ids : (n, max_L)
    :txt_lens     : list of [txt_len]
    :img_feat     : (n, max_num_bb, d)
    :img_pos_feat : (n, max_num_bb, 7)
    :num_bbs      : list of [num_bb]
    :attn_masks   : (n, max_{L+num_bb})
    :obj_masks    : (n, max_num_bb)
    :tgt_box      : list of n [xywh]
    :obj_boxes    : list of n [[xywh, xywh, ...]]
    :sent_ids     : list of n [sent_id]
    """
    (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
     tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    obj_masks = pad_sequence(obj_masks,
                             batch_first=True, padding_value=1).bool()

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids)


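# A minimal evaluation sketch (paths, batch size, the model, and the 0.5 IoU
# threshold for a correct prediction are assumptions, not part of this
# module):
#
#   from torch.utils.data import DataLoader
#   img_dir = ReImageFeatDir('/path/to/visual_grounding_det_coco')
#   val_set = ReferringExpressionEvalDataset('/path/to/refcoco_val_db',
#                                            img_dir)
#   val_loader = DataLoader(val_set, batch_size=32,
#                           collate_fn=re_eval_collate)
#   n_correct, n_total = 0, 0
#   for batch in val_loader:
#       (input_ids, position_ids, txt_lens, img_feat, img_pos_feat, num_bbs,
#        attn_masks, obj_masks, tgt_boxes, obj_boxes, sent_ids) = batch
#       scores = model(...)  # hypothetical scorer: (n, max_num_bb)
#       preds = scores.masked_fill(obj_masks, -1e4).argmax(dim=1)
#       for j, p in enumerate(preds.tolist()):
#           if val_set.computeIoU(obj_boxes[j][p], tgt_boxes[j]) >= 0.5:
#               n_correct += 1
#           n_total += 1
#   print('accuracy:', n_correct / n_total)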