"""
VCR dataset
"""
import json
import copy
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from torch.utils.data import Dataset

from .data import DetectFeatLmdb, TxtLmdb, random_word
from .mrc import DetectFeatDir_for_mrc


class ImageTextDataset(Dataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=120, task="qa"):
        self.txt_lens = []
        self.ids = []
        self.task = task
        for id_, len_ in json.load(open(f'{db_dir}/id2len_{task}.json')
                                   ).items():
            if max_txt_len == -1 or len_ <= max_txt_len:
                self.txt_lens.append(len_)
                self.ids.append(id_)

        self.db = TxtLmdb(db_dir, readonly=True)
        self.img_dir = img_dir
        self.img_dir_gt = img_dir_gt

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_dump_gt, img_dump = None, None
        img_fname_gt, img_fname = txt_dump['img_fname']
        if self.img_dir_gt:
            img_dump_gt = self.img_dir_gt[img_fname_gt]
        if self.img_dir:
            img_dump = self.img_dir[img_fname]
        return img_dump_gt, img_dump, txt_dump
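

# Usage sketch (illustrative, not part of the original module; paths and
# variable names are hypothetical): the text db directory must contain the
# side files read above (`id2len_{task}.json`, and later `txt2img.json` and
# `meta.json`) alongside the text LMDB.
#
#     gt_img_db = DetectFeatLmdb(...)   # features over ground-truth boxes
#     det_img_db = DetectFeatLmdb(...)  # features over detected boxes
#     ds = ImageTextDataset('/db/vcr_txt', img_dir_gt=gt_img_db,
#                           img_dir=det_img_db, max_txt_len=120, task='qa')
#     img_dump_gt, img_dump, txt_dump = ds[0]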


class DetectFeatBertTokDataset(ImageTextDataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=60, task="qa"):
        assert not (img_dir_gt is None and img_dir is None),\
            "img_dir_gt and img_dir cannot both be None"
        assert task == "qa" or task == "qar",\
            "VCR only allows two tasks: qa or qar"
        assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
        assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)

        super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
        txt2img = json.load(open(f'{db_dir}/txt2img.json'))
        if self.img_dir and self.img_dir_gt:
            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]] +
                         self.img_dir.name2nbb[txt2img[id_][1]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]
        elif self.img_dir:
            self.lens = [tl+self.img_dir.name2nbb[txt2img[id_][1]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]
        else:
            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]

        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
        self.cls_ = meta['CLS']
        self.sep = meta['SEP']
        self.mask = meta['MASK']
        self.v_range = meta['v_range']

    def _get_img_feat(self, fname_gt, fname):
        if self.img_dir and self.img_dir_gt:
            img_feat_gt, bb_gt = self.img_dir_gt[fname_gt]
            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)

            img_feat, bb = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)

            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
            num_bb = img_feat.size(0)
        elif self.img_dir:
            img_feat, bb = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        elif self.img_dir_gt:
            img_feat, bb = self.img_dir_gt[fname_gt]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb
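

# Position-feature note (assumption flagged: DetectFeatLmdb is taken to
# return 6 normalised box values per region, (x1, y1, x2, y2, w, h)).  The
# torch.cat above then appends w*h, giving the 7-dim position feature.
# Minimal sketch:
#
#     bb = torch.tensor([[0.1, 0.2, 0.5, 0.8, 0.4, 0.6]])
#     img_bb = torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1)
#     assert img_bb.shape == (1, 7)   # last column is w*h = 0.24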


class VcrDataset(DetectFeatBertTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        if self.task == "qar":
            input_ids_rs = txt_dump['input_ids_rs']
            answer_label = txt_dump['qa_target']
            assert answer_label >= 0, "answer_label < 0"
            input_ids_gt_a = [self.sep] + copy.deepcopy(
                input_ids_as[answer_label])
            type_ids_gt_a = [2] * len(input_ids_gt_a)
            type_ids_q += type_ids_gt_a
            input_ids_q += input_ids_gt_a
            input_ids_for_choices = input_ids_rs
        else:
            input_ids_for_choices = input_ids_as
        return input_ids_q, input_ids_for_choices, type_ids_q
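
    # Layout note (added for clarity; a sketch of what _get_input_ids builds):
    #   task == "qa":  question tokens only; the choices are the 4 answers.
    #   task == "qar": question + [SEP] + ground-truth answer; the choices
    #                  are the 4 rationales.
    # Token type ids: 0 = question, 2 = answer, 3 = rationale; the candidate
    # appended in __getitem__ below gets 2 or 3 via `type_id_for_choice`.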

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])
        object_targets = txt_dump["object_ids"]
        input_ids_q, input_ids_for_choices, type_ids_q = self._get_input_ids(
            txt_dump)
        label = txt_dump['%s_target' % (self.task)]

        choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
            [], [], [])
        (choice_txt_lens, choice_input_ids, choice_txt_type_ids,
         choice_attn_masks, choice_position_ids, choice_targets) = (
            [], [], [], [], [], [])
        choice_obj_targets, choice_img_masks = ([], [])

        for index, input_ids_a in enumerate(input_ids_for_choices):
            if index == label:
                target = torch.tensor([1]).long()
            else:
                target = torch.tensor([0]).long()
            input_ids = [self.cls_] + copy.deepcopy(input_ids_q) +\
                [self.sep] + input_ids_a + [self.sep]
            type_id_for_choice = 3 if type_ids_q[-1] == 2 else 2
            txt_type_ids = [0] + type_ids_q + [type_id_for_choice]*(
                len(input_ids_a)+2)
            attn_masks = [1] * len(input_ids)
            position_ids = list(range(len(input_ids)))
            attn_masks += [1] * num_bb

            input_ids = torch.tensor(input_ids)
            position_ids = torch.tensor(position_ids)
            attn_masks = torch.tensor(attn_masks)
            txt_type_ids = torch.tensor(txt_type_ids)

            choice_txt_lens.append(len(input_ids))
            choice_input_ids.append(input_ids)
            choice_attn_masks.append(attn_masks)
            choice_position_ids.append(position_ids)
            choice_txt_type_ids.append(txt_type_ids)

            choice_num_bbs.append(num_bb)
            choice_img_feats.append(img_feat)
            choice_img_pos_feats.append(img_pos_feat)
            choice_targets.append(target)

            # mask image input features
            num_gt_bb = len(object_targets)
            num_det_bb = num_bb - num_gt_bb
            # only mask gt features
            img_mask = [random.random() < self.mask_prob
                        for _ in range(num_gt_bb)]
            if not any(img_mask):
                # at least mask 1
                img_mask[0] = True
            img_mask += [False]*num_det_bb
            img_mask = torch.tensor(img_mask)
            object_targets += [0]*num_det_bb
            obj_targets = torch.tensor(object_targets)

            choice_obj_targets.append(obj_targets)
            choice_img_masks.append(img_mask)

        return (choice_input_ids, choice_position_ids, choice_txt_lens,
                choice_txt_type_ids,
                choice_img_feats, choice_img_pos_feats, choice_num_bbs,
                choice_attn_masks, choice_targets, choice_obj_targets,
                choice_img_masks)


def vcr_collate(inputs):
    (input_ids, position_ids, txt_lens, txt_type_ids, img_feats,
     img_pos_feats, num_bbs, attn_masks, targets,
     obj_targets, img_masks) = map(list, unzip(inputs))

    all_num_bbs, all_img_feats, all_img_pos_feats = (
        [], [], [])
    all_txt_lens, all_input_ids, all_attn_masks,\
        all_position_ids, all_txt_type_ids = (
            [], [], [], [], [])
    all_obj_targets = []
    all_targets = []
    # all_targets = targets
    all_img_masks = []
    for i in range(len(num_bbs)):
        all_input_ids += input_ids[i]
        all_position_ids += position_ids[i]
        all_txt_lens += txt_lens[i]
        all_txt_type_ids += txt_type_ids[i]
        all_img_feats += img_feats[i]
        all_img_pos_feats += img_pos_feats[i]
        all_num_bbs += num_bbs[i]
        all_attn_masks += attn_masks[i]
        all_obj_targets += obj_targets[i]
        all_img_masks += img_masks[i]
        all_targets += targets[i]

    all_input_ids = pad_sequence(all_input_ids,
                                 batch_first=True, padding_value=0)
    all_position_ids = pad_sequence(all_position_ids,
                                    batch_first=True, padding_value=0)
    all_txt_type_ids = pad_sequence(all_txt_type_ids,
                                    batch_first=True, padding_value=0)
    all_attn_masks = pad_sequence(all_attn_masks,
                                  batch_first=True, padding_value=0)
    all_img_masks = pad_sequence(all_img_masks,
                                 batch_first=True, padding_value=0)
    # all_targets = pad_sequence(all_targets,
    #                            batch_first=True, padding_value=0)
    all_targets = torch.stack(all_targets, dim=0)

    batch_size = len(all_img_feats)
    num_bb = max(all_num_bbs)
    feat_dim = all_img_feats[0].size(1)
    pos_dim = all_img_pos_feats[0].size(1)
    all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    all_obj_target = torch.zeros(batch_size, num_bb)
    for i, (im, pos, label) in enumerate(zip(
            all_img_feats, all_img_pos_feats, all_obj_targets)):
        len_ = im.size(0)
        all_img_feat.data[i, :len_, :] = im.data
        all_img_pos_feat.data[i, :len_, :] = pos.data
        all_obj_target.data[i, :len_] = label.data

    obj_targets = all_obj_target[all_img_masks].contiguous()
    return (all_input_ids, all_position_ids, all_txt_lens,
            all_txt_type_ids,
            all_img_feat, all_img_pos_feat, all_num_bbs,
            all_attn_masks, all_targets, obj_targets, all_img_masks)
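

# Usage sketch (hypothetical setup; only vcr_collate itself is from the
# original file): VcrDataset returns per-example lists over the answer (or
# rationale) choices, and vcr_collate flattens them, so a batch of B
# examples becomes B * num_choices padded rows.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(vcr_train_dataset, batch_size=4, shuffle=True,
#                         collate_fn=vcr_collate)
#     (input_ids, position_ids, txt_lens, txt_type_ids, img_feat,
#      img_pos_feat, num_bbs, attn_masks, targets, obj_targets,
#      img_masks) = next(iter(loader))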


class VcrEvalDataset(DetectFeatBertTokDataset):
    def __init__(self, split, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.split = split
        del self.txt_lens

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_for_choices = []
        type_ids_for_choices = []
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        input_ids_rs = txt_dump['input_ids_rs']
        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
                [self.sep] + input_ids_a + [self.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2]*(
                len(input_ids_a)+2)
            input_ids_for_choices.append(curr_input_ids_qa)
            type_ids_for_choices.append(curr_type_ids_qa)
        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
                [self.sep] + input_ids_a + [self.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2]*(
                len(input_ids_a)+1)
            if (self.split == "val" and index == txt_dump["qa_target"]) or\
                    self.split == "test":
                for input_ids_r in input_ids_rs:
                    curr_input_ids_qar = copy.deepcopy(curr_input_ids_qa) +\
                        input_ids_r + [self.sep]
                    curr_type_ids_qar = copy.deepcopy(curr_type_ids_qa) +\
                        [3]*(len(input_ids_r)+2)
                    input_ids_for_choices.append(curr_input_ids_qar)
                    type_ids_for_choices.append(curr_type_ids_qar)
        return input_ids_for_choices, type_ids_for_choices

    def __getitem__(self, i):
        qid = self.ids[i]
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])
        object_targets = txt_dump["object_ids"]
        input_ids_for_choices, type_ids_for_choices = self._get_input_ids(
            txt_dump)
        qa_target = torch.tensor([int(txt_dump["qa_target"])])
        qar_target = torch.tensor([int(txt_dump["qar_target"])])

        choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
            [], [], [])
        (choice_txt_lens, choice_input_ids, choice_attn_masks,
         choice_position_ids, choice_txt_type_ids) = (
            [], [], [], [], [])
        choice_obj_targets = []
        for index, input_ids in enumerate(input_ids_for_choices):
            txt_type_ids = type_ids_for_choices[index]
            attn_masks = [1] * len(input_ids)
            position_ids = list(range(len(input_ids)))
            attn_masks += [1] * num_bb

            input_ids = torch.tensor(input_ids)
            position_ids = torch.tensor(position_ids)
            attn_masks = torch.tensor(attn_masks)
            txt_type_ids = torch.tensor(txt_type_ids)

            choice_txt_lens.append(len(input_ids))
            choice_input_ids.append(input_ids)
            choice_attn_masks.append(attn_masks)
            choice_position_ids.append(position_ids)
            choice_txt_type_ids.append(txt_type_ids)

            choice_num_bbs.append(num_bb)
            choice_img_feats.append(img_feat)
            choice_img_pos_feats.append(img_pos_feat)

            obj_targets = torch.tensor(object_targets)
            choice_obj_targets.append(obj_targets)

        return (qid, choice_input_ids, choice_position_ids, choice_txt_lens,
                choice_txt_type_ids,
                choice_img_feats, choice_img_pos_feats, choice_num_bbs,
                choice_attn_masks, qa_target, qar_target, choice_obj_targets)


def vcr_eval_collate(inputs):
    (qids, input_ids, position_ids, txt_lens, txt_type_ids,
     img_feats, img_pos_feats,
     num_bbs, attn_masks, qa_targets, qar_targets,
     obj_targets) = map(list, unzip(inputs))

    all_num_bbs, all_img_feats, all_img_pos_feats = (
        [], [], [])
    all_txt_lens, all_input_ids, all_attn_masks, all_position_ids,\
        all_txt_type_ids = (
            [], [], [], [], [])
    # all_qa_targets = qa_targets
    # all_qar_targets = qar_targets
    all_obj_targets = []
    for i in range(len(num_bbs)):
        all_input_ids += input_ids[i]
        all_position_ids += position_ids[i]
        all_txt_lens += txt_lens[i]
        all_img_feats += img_feats[i]
        all_img_pos_feats += img_pos_feats[i]
        all_num_bbs += num_bbs[i]
        all_attn_masks += attn_masks[i]
        all_txt_type_ids += txt_type_ids[i]
        all_obj_targets += obj_targets[i]

    all_input_ids = pad_sequence(all_input_ids,
                                 batch_first=True, padding_value=0)
    all_position_ids = pad_sequence(all_position_ids,
                                    batch_first=True, padding_value=0)
    all_txt_type_ids = pad_sequence(all_txt_type_ids,
                                    batch_first=True, padding_value=0)
    all_attn_masks = pad_sequence(all_attn_masks,
                                  batch_first=True, padding_value=0)
    all_obj_targets = pad_sequence(all_obj_targets,
                                   batch_first=True, padding_value=0)
    all_qa_targets = torch.stack(qa_targets, dim=0)
    all_qar_targets = torch.stack(qar_targets, dim=0)

    batch_size = len(all_img_feats)
    num_bb = max(all_num_bbs)
    feat_dim = all_img_feats[0].size(1)
    pos_dim = all_img_pos_feats[0].size(1)
    all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(
            all_img_feats, all_img_pos_feats)):
        len_ = im.size(0)
        all_img_feat.data[i, :len_, :] = im.data
        all_img_pos_feat.data[i, :len_, :] = pos.data

    return (qids, all_input_ids, all_position_ids, all_txt_lens,
            all_txt_type_ids,
            all_img_feat, all_img_pos_feat, all_num_bbs,
            all_attn_masks, all_qa_targets, all_qar_targets, all_obj_targets)
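

# Evaluation note (a sketch inferred from VcrEvalDataset._get_input_ids
# above): each item yields the 4 Q->A rows plus Q+A->R rows; for "val" only
# the ground-truth answer is paired with the 4 rationales (8 rows per
# example), while "test" pairs every answer with every rationale (20 rows
# per example).  The collated batch therefore mixes qa and qar rows, with
# one qa_target and one qar_target kept per original example for scoring.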


class MlmDatasetForVCR(DetectFeatBertTokDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        if mask:
            input_ids_q, txt_labels_q = random_word(
                input_ids_q, self.v_range, self.mask)
        else:
            txt_labels_q = input_ids_q

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)
        if mask:
            input_ids_a, txt_labels_a = random_word(
                input_ids_a, self.v_range, self.mask)
        else:
            txt_labels_a = input_ids_a

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a
        txt_labels = txt_labels_q + [-1] + txt_labels_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)
            if mask:
                input_ids_r, txt_labels_r = random_word(
                    input_ids_r, self.v_range, self.mask)
            else:
                txt_labels_r = input_ids_r

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
            txt_labels += [-1] + txt_labels_r
        return input_ids, type_ids, txt_labels

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # txt inputs
        input_ids, type_ids, txt_labels = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        txt_labels = [-1] + txt_labels + [-1]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        txt_labels = torch.tensor(txt_labels)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                attn_masks, txt_labels)


def mlm_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats, attn_masks,
     txt_labels) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, txt_labels)
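

# MLM note (assumptions flagged): random_word is expected to return the
# (possibly masked) token ids plus per-position labels that are -1 wherever
# the token was left unmasked; padding above also uses -1, so a loss wired
# as in this hypothetical sketch only scores the masked positions:
#
#     import torch.nn.functional as F
#     loss = F.cross_entropy(scores.view(-1, scores.size(-1)),
#                            txt_labels.view(-1), ignore_index=-1)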


class MrmDatasetForVCR(DetectFeatBertTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
        return input_ids, type_ids

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # image input features
        img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
        if not any(img_mask):
            # at least mask 1
            img_mask[0] = True
        img_mask = torch.tensor(img_mask)

        # text input
        input_ids, type_ids = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                attn_masks, img_mask)


def mrm_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats,
     attn_masks, img_masks) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, img_masks)
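

# MRM note (speculative sketch; the actual consumer lives in the model
# code): the masked feature positions can be recovered from img_masks the
# same way mrc_collate_for_vcr gathers soft labels below, e.g.
#
#     mask_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
#     feat_targets = img_feat[mask_ext].contiguous().view(
#         -1, img_feat.size(-1))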


class DetectFeatBertTokDataset_for_mrc_vcr(DetectFeatBertTokDataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=60, task="qa"):
        assert not (img_dir_gt is None and img_dir is None),\
            "img_dir_gt and img_dir cannot both be None"
        assert task == "qa" or task == "qar",\
            "VCR only allows two tasks: qa or qar"
        assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
        assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)
        super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
        if self.img_dir:
            self.img_dir = DetectFeatDir_for_mrc(img_dir)
        if self.img_dir_gt:
            self.img_dir_gt = DetectFeatDir_for_mrc(img_dir_gt)

    def _get_img_feat(self, fname_gt, fname):
        if self.img_dir and self.img_dir_gt:
            img_feat_gt, bb_gt,\
                img_soft_labels_gt = self.img_dir_gt[fname_gt]
            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)

            img_feat, bb, img_soft_labels = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)

            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
            img_soft_labels = torch.cat(
                [img_soft_labels_gt, img_soft_labels], dim=0)
            num_bb = img_feat.size(0)
        elif self.img_dir:
            img_feat, bb, img_soft_labels = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        elif self.img_dir_gt:
            img_feat, bb, img_soft_labels = self.img_dir_gt[fname_gt]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        return img_feat, img_bb, img_soft_labels, num_bb


class MrcDatasetForVCR(DetectFeatBertTokDataset_for_mrc_vcr):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
        return input_ids, type_ids

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # image input features
        img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
        if not any(img_mask):
            # at least mask 1
            img_mask[0] = True
        img_mask = torch.tensor(img_mask)

        # text input
        input_ids, type_ids = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                img_soft_labels, attn_masks, img_mask)


def mrc_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats,
     img_soft_labels, attn_masks, img_masks
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    soft_label_dim = img_soft_labels[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    img_soft_label = torch.zeros(batch_size, num_bb, soft_label_dim)
    for i, (im, pos, label) in enumerate(zip(img_feats,
                                             img_pos_feats,
                                             img_soft_labels)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data
        img_soft_label.data[i, :len_, :] = label.data

    img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
    label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
        -1, soft_label_dim)
    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, (img_masks, label_targets))
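

# MRC gather, end-to-end (illustrative, with hypothetical sizes; 1601 stands
# in for the detector's class count):
#
#     img_soft_label = torch.rand(2, 3, 1601)   # (batch, num_bb, classes)
#     img_masks = torch.tensor([[True, False, True],
#                               [False, True, False]])
#     ext = img_masks.unsqueeze(-1).expand_as(img_soft_label)
#     label_targets = img_soft_label[ext].contiguous().view(-1, 1601)
#     assert label_targets.shape == (3, 1601)   # one row per masked box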