logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

146 lines
5.1 KiB

"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
def _get_vqa_target(example, num_answers):
target = torch.zeros(num_answers)
labels = example['target']['labels']
scores = example['target']['scores']
if labels and scores:
target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
return target
class VqaDataset(DetectFeatTxtTokDataset):
    """VQA dataset that yields separate text and image inputs per example.

    NOTE (kept from the original): this handles distributed inside.
    """

    def __init__(self, num_answers, *args, **kwargs):
        """Remember the answer-vector length; forward the rest to the base.

        Args:
            num_answers: length of the target vector built for each example.
            *args, **kwargs: passed unchanged to ``DetectFeatTxtTokDataset``.
        """
        super().__init__(*args, **kwargs)
        self.num_answers = num_answers

    def __getitem__(self, i):
        """Return the tensors for the ``i``-th example.

        Returns:
            An 8-tuple, in this order: question id, text input ids, text
            attention mask, image-stream input ids, image features, image
            position features, image attention mask, target vector.
        """
        example = super().__getitem__(i)
        question_id = self.ids[i]
        fname = example['img_fname']
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # Text stream: combined token ids and an all-ones mask of equal length.
        txt_ids = self.txt_db.combine_inputs(example['input_ids'])
        txt_mask = torch.ones(len(txt_ids), dtype=torch.long)

        # Image stream: one leading token id (101 -- presumably the BERT
        # [CLS] id; confirm against the tokenizer). The mask therefore covers
        # num_bb boxes plus that one token.
        lead_ids = torch.Tensor([101]).long()
        img_mask = torch.ones(num_bb + 1, dtype=torch.long)

        target = _get_vqa_target(example, self.num_answers)
        return (question_id, txt_ids, txt_mask, lead_ids,
                img_feat, img_pos_feat, img_mask, target)
def vqa_collate(inputs):
    """Collate ``VqaDataset`` items into one padded batch dictionary.

    Args:
        inputs: sequence of 8-tuples as returned by
            ``VqaDataset.__getitem__``.

    Returns:
        A dict with keys ``'qids'``, ``'txts'``, ``'imgs'``,
        ``'gather_index_teacher'`` and ``'targets'``. ``'txts'`` and
        ``'imgs'`` share the same key set; the text entry leaves all
        image-related fields as ``None``.
    """
    columns = [list(column) for column in unzip(inputs)]
    (qids, txt_ids, txt_masks, lead_ids,
     feats, pos_feats, img_masks, target_list) = columns

    # ---- text stream: zero-pad to the longest sequence in the batch ----
    txt_lens = [ids.size(0) for ids in txt_ids]
    txt_ids = pad_sequence(txt_ids, batch_first=True, padding_value=0)
    txt_masks = pad_sequence(txt_masks, batch_first=True, padding_value=0)
    bs, max_tl = txt_ids.size()
    # Position ids 0..max_tl-1 with a leading axis of size 1.
    txt_pos_ids = torch.arange(max_tl, dtype=torch.long).unsqueeze(0)

    # ---- image stream ----
    num_bbs = [feat.size(0) for feat in feats]
    img_feat = pad_tensors(feats, num_bbs)
    img_pos_feat = pad_tensors(pos_feats, num_bbs)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    lead_ids = pad_sequence(lead_ids, batch_first=True, padding_value=0)
    img_pos_ids = torch.arange(
        lead_ids.size(1), dtype=torch.long).unsqueeze(0)

    targets = torch.stack(target_list, dim=0)

    # get_gather_index is defined outside this file. Both indices are sized
    # by the padded image-mask width.
    out_size = img_masks.size(1)
    gather_index_teacher = get_gather_index(
        txt_lens, num_bbs, bs, max_tl, out_size)
    # The image stream uses a text length of 1 per example (presumably for
    # its single leading token -- confirm against get_gather_index).
    gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, out_size)

    text_inputs = {
        'input_ids': txt_ids,
        'position_ids': txt_pos_ids,
        'attention_mask': txt_masks,
        'img_feat': None,
        'img_pos_feat': None,
        'img_masks': None,
        'gather_index': None,
    }
    image_inputs = {
        'input_ids': lead_ids,
        'position_ids': img_pos_ids,
        'attention_mask': img_masks,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'img_masks': None,
        'gather_index': gather_index,
    }
    return {
        'qids': qids,
        'txts': text_inputs,
        'imgs': image_inputs,
        'gather_index_teacher': gather_index_teacher,
        'targets': targets,
    }
class VqaEvalDataset(VqaDataset):
    """Evaluation variant of ``VqaDataset``.

    Items carry one attention mask spanning text and image, and the target
    is optional.
    """

    def __getitem__(self, i):
        """Return the tensors for the ``i``-th evaluation example.

        Returns:
            A 6-tuple, in this order: question id, text input ids, image
            features, image position features, joint attention mask, and
            the target vector (``None`` when the example has no
            ``'target'`` entry).
        """
        question_id = self.ids[i]
        # Call the grandparent implementation directly, skipping
        # VqaDataset.__getitem__ and its two-stream tuple.
        example = DetectFeatTxtTokDataset.__getitem__(self, i)
        fname = example['img_fname']
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # Text input.
        txt_ids = self.txt_db.combine_inputs(example['input_ids'])

        target = (_get_vqa_target(example, self.num_answers)
                  if 'target' in example else None)

        # One all-ones mask covering every text token and every box.
        joint_mask = torch.ones(len(txt_ids) + num_bb, dtype=torch.long)
        return question_id, txt_ids, img_feat, img_pos_feat, joint_mask, target
def vqa_eval_collate(inputs):
    """Collate ``VqaEvalDataset`` items into one padded batch dictionary.

    Args:
        inputs: sequence of 6-tuples as returned by
            ``VqaEvalDataset.__getitem__``.

    Returns:
        A dict with keys ``'qids'``, ``'input_ids'``, ``'position_ids'``,
        ``'img_feat'``, ``'img_pos_feat'``, ``'attn_masks'``,
        ``'gather_index'`` and ``'targets'``. ``'targets'`` is ``None``
        when the first item has no target.
    """
    columns = [list(column) for column in unzip(inputs)]
    qids, txt_ids, feats, pos_feats, masks, target_list = columns

    # Text: zero-pad to the longest sequence in the batch.
    txt_lens = [ids.size(0) for ids in txt_ids]
    txt_ids = pad_sequence(txt_ids, batch_first=True, padding_value=0)
    bs, max_tl = txt_ids.size()
    # Position ids 0..max_tl-1 with a leading axis of size 1.
    pos_ids = torch.arange(max_tl, dtype=torch.long).unsqueeze(0)

    masks = pad_sequence(masks, batch_first=True, padding_value=0)

    # Only the first item is inspected to decide whether targets exist.
    targets = (None if target_list[0] is None
               else torch.stack(target_list, dim=0))

    # Image features.
    num_bbs = [feat.size(0) for feat in feats]
    img_feat = pad_tensors(feats, num_bbs)
    img_pos_feat = pad_tensors(pos_feats, num_bbs)

    # get_gather_index is defined outside this file; out_size is the padded
    # width of the joint attention mask.
    out_size = masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    return {
        'qids': qids,
        'input_ids': txt_ids,
        'position_ids': pos_ids,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'attn_masks': masks,
        'gather_index': gather_index,
        'targets': targets,
    }