"""
Referring Expression Comprehension dataset
"""
import sys
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import TxtLmdb
class ReImageFeatDir(object):
def __init__(self, img_dir):
self.img_dir = img_dir
def __getitem__(self, file_name):
img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
img_feat = torch.tensor(img_dump['features'])
img_bb = torch.tensor(img_dump['norm_bb'])
return img_feat, img_bb
class ReDetectFeatDir(object):
def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
format_='npz'):
        assert format_ == 'npz', 'only npz is supported for now.'
        assert isinstance(img_dir, str), 'img_dir should be a path, not a db.'
self.img_dir = img_dir
self.conf_th = conf_th
self.max_bb = max_bb
self.min_bb = min_bb
self.num_bb = num_bb
    def _compute_num_bb(self, img_dump):
        # number of boxes above the confidence threshold,
        # clamped to [min_bb, max_bb]
        num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum())
        num_bb = min(self.max_bb, num_bb)
        return num_bb
def __getitem__(self, file_name):
# image input features
img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
num_bb = self._compute_num_bb(img_dump)
img_feat = torch.tensor(img_dump['features'][:num_bb, :])
img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :])
return img_feat, img_bb
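# Illustrative usage of the two feature readers above; the directory and file
# names are placeholders, not paths shipped with this code:
#     img_dir = ReImageFeatDir('/path/to/visual_grounding_coco_gt')
#     img_feat, img_bb = img_dir['visual_grounding_coco_gt_000000000139.npz']
# `img_feat` is (num_bb, feat_dim) and `img_bb` holds normalized box
# coordinates; _get_img_feat below appends w*h as a 7th positional column.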
class ReferringExpressionDataset(Dataset):
def __init__(self, db_dir, img_dir, max_txt_len=60):
assert isinstance(img_dir, ReImageFeatDir) or \
isinstance(img_dir, ReDetectFeatDir)
self.img_dir = img_dir
# load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}]
refs = json.load(open(f'{db_dir}/refs.json', 'r'))
self.ref_ids = [ref['ref_id'] for ref in refs]
self.Refs = {ref['ref_id']: ref for ref in refs}
# load annotations = [{id, area, bbox, image_id, category_id}]
anns = json.load(open(f'{db_dir}/annotations.json', 'r'))
self.Anns = {ann['id']: ann for ann in anns}
# load categories = [{id, name, supercategory}]
categories = json.load(open(f'{db_dir}/categories.json', 'r'))
self.Cats = {cat['id']: cat['name'] for cat in categories}
# load images = [{id, file_name, ann_ids, height, width}]
images = json.load(open(f'{db_dir}/images.json', 'r'))
self.Images = {img['id']: img for img in images}
# id2len: sent_id -> sent_len
id2len = json.load(open(f'{db_dir}/id2len.json', 'r'))
self.id2len = {int(_id): _len for _id, _len in id2len.items()}
self.max_txt_len = max_txt_len
self.sent_ids = self._get_sent_ids()
# db[str(sent_id)] =
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
self.db = TxtLmdb(db_dir, readonly=True)
# meta
meta = json.load(open(f'{db_dir}/meta.json', 'r'))
self.cls_ = meta['CLS']
self.sep = meta['SEP']
self.mask = meta['MASK']
self.v_range = meta['v_range']
    def shuffle(self):
        # shuffle ref_ids and rebuild sent_ids in the new ref order
        random.shuffle(self.ref_ids)
        self.sent_ids = self._get_sent_ids()
def _get_sent_ids(self):
sent_ids = []
for ref_id in self.ref_ids:
for sent_id in self.Refs[ref_id]['sent_ids']:
sent_len = self.id2len[sent_id]
if self.max_txt_len == -1 or sent_len < self.max_txt_len:
sent_ids.append(sent_id)
return sent_ids
    def _get_img_feat(self, fname):
        img_feat, bb = self.img_dir[fname]
        # append box area (w*h) as a 7th positional feature -> (num_bb, 7)
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb
def __len__(self):
return len(self.sent_ids)
def __getitem__(self, i):
"""
Return:
:input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
:position_ids : range(L)
:img_feat : (num_bb, d)
:img_pos_feat : (num_bb, 7)
:attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
:obj_masks : (num_bb, ) all 0's
:target : (1, )
"""
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
sent_id = self.sent_ids[i]
txt_dump = self.db[str(sent_id)]
image_id = txt_dump['image_id']
fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
# text input
input_ids = txt_dump['input_ids']
input_ids = [self.cls_] + input_ids + [self.sep]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
# target bbox
img = self.Images[image_id]
assert len(img['ann_ids']) == num_bb, \
'Please use visual_grounding_coco_gt'
target = img['ann_ids'].index(txt_dump['ann_id'])
target = torch.tensor([target])
        # obj_masks: all zeros here; padded with ones at collate time so the
        # probabilities of padded (non-object) positions can be masked out
        obj_masks = torch.tensor([0]*len(img['ann_ids'])).bool()
return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
obj_masks, target)
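# A minimal sketch of pulling one training example from the dataset above and
# checking the shapes documented in __getitem__ (paths are placeholders):
#     ds = ReferringExpressionDataset('/path/to/refcoco_db',
#                                     ReImageFeatDir('/path/to/feat_gt'))
#     (input_ids, position_ids, img_feat, img_pos_feat,
#      attn_masks, obj_masks, target) = ds[0]
#     assert attn_masks.size(0) == input_ids.size(0) + img_feat.size(0)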
def re_collate(inputs):
"""
Return:
:input_ids : (n, max_L) padded with 0
:position_ids : (n, max_L) padded with 0
:txt_lens : list of [txt_len]
:img_feat : (n, max_num_bb, feat_dim)
:img_pos_feat : (n, max_num_bb, 7)
:num_bbs : list of [num_bb]
:attn_masks : (n, max_{L+num_bb}) padded with 0
:obj_masks : (n, max_num_bb) padded with 1
:targets : (n, )
"""
(input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
targets) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
obj_masks = pad_sequence(obj_masks,
batch_first=True, padding_value=1).bool()
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, obj_masks, targets)
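# Hypothetical helper (not part of the original code): one way to wire the
# dataset and re_collate into a torch DataLoader; db_dir/img_dir are
# placeholder paths.
def build_re_train_dataloader(db_dir, img_dir, batch_size=32, max_txt_len=60):
    from torch.utils.data import DataLoader
    dataset = ReferringExpressionDataset(
        db_dir, ReImageFeatDir(img_dir), max_txt_len=max_txt_len)
    # reshuffle ref_ids (and hence sent_ids); alternatively pass shuffle=True
    # to the DataLoader instead
    dataset.shuffle()
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      collate_fn=re_collate)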
class ReferringExpressionEvalDataset(ReferringExpressionDataset):
def __getitem__(self, i):
"""
Return:
:input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
:position_ids : range(L)
:img_feat : (num_bb, d)
:img_pos_feat : (num_bb, 7)
:attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
:obj_masks : (num_bb, ) all 0's
:tgt_box : ndarray (4, ) xywh
:obj_boxes : ndarray (num_bb, 4) xywh
:sent_id
"""
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
sent_id = self.sent_ids[i]
txt_dump = self.db[str(sent_id)]
image_id = txt_dump['image_id']
if isinstance(self.img_dir, ReImageFeatDir):
if '_gt' in self.img_dir.img_dir:
fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
elif '_det' in self.img_dir.img_dir:
fname = f'visual_grounding_det_coco_{int(image_id):012}.npz'
elif isinstance(self.img_dir, ReDetectFeatDir):
fname = f'coco_train2014_{int(image_id):012}.npz'
else:
sys.exit('%s not supported.' % self.img_dir)
img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
# image info
img = self.Images[image_id]
im_width, im_height = img['width'], img['height']
# object boxes, img_pos_feat (xyxywha) -> xywh
obj_boxes = np.stack([img_pos_feat[:, 0]*im_width,
img_pos_feat[:, 1]*im_height,
img_pos_feat[:, 4]*im_width,
img_pos_feat[:, 5]*im_height], axis=1)
obj_masks = torch.tensor([0]*num_bb).bool()
# target box
tgt_box = np.array(txt_dump['bbox']) # xywh
# text input
input_ids = txt_dump['input_ids']
input_ids = [self.cls_] + input_ids + [self.sep]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
obj_masks, tgt_box, obj_boxes, sent_id)
# IoU function
def computeIoU(self, box1, box2):
# each box is of [x1, y1, w, h]
inter_x1 = max(box1[0], box2[0])
inter_y1 = max(box1[1], box2[1])
inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
else:
inter = 0
union = box1[2]*box1[3] + box2[2]*box2[3] - inter
return float(inter)/union
def re_eval_collate(inputs):
"""
Return:
:input_ids : (n, max_L)
:position_ids : (n, max_L)
:txt_lens : list of [txt_len]
:img_feat : (n, max_num_bb, d)
:img_pos_feat : (n, max_num_bb, 7)
:num_bbs : list of [num_bb]
    :attn_masks : (n, max_{L+num_bb}) padded with 0
    :obj_masks : (n, max_num_bb) padded with 1
:tgt_box : list of n [xywh]
:obj_boxes : list of n [[xywh, xywh, ...]]
:sent_ids : list of n [sent_id]
"""
(input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
obj_masks = pad_sequence(obj_masks,
batch_first=True, padding_value=1).bool()
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids)
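# Hypothetical helper (not part of the original code): scoring predictions
# with the dataset's computeIoU. `pred_boxes` is assumed to be a list of
# predicted xywh boxes (e.g. the obj_boxes row picked per example) aligned
# with the tgt_box list returned by re_eval_collate; a prediction counts as
# correct when its IoU with the ground-truth box is at least `iou_thresh`.
def re_accuracy(eval_dataset, pred_boxes, tgt_boxes, iou_thresh=0.5):
    n_correct = sum(
        1 for pred, tgt in zip(pred_boxes, tgt_boxes)
        if eval_dataset.computeIoU(pred, tgt) >= iou_thresh)
    return n_correct / max(len(tgt_boxes), 1)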