logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

125 lines
4.0 KiB

"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
def _get_vqa_target(example, num_answers):
target = torch.zeros(num_answers)
labels = example['target']['labels']
scores = example['target']['scores']
if labels and scores:
target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
return target
class VqaDataset(DetectFeatTxtTokDataset):
    """VQA training dataset layered on ``DetectFeatTxtTokDataset``.

    NOTE: This handles distributed inside.

    Each item is the tuple
    ``(input_ids, img_feat, img_pos_feat, attn_masks, target)``, where
    ``target`` is the dense answer-score vector from ``_get_vqa_target``.
    """

    def __init__(self, num_answers, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Length of every target vector produced by __getitem__.
        self.num_answers = num_answers

    def __getitem__(self, i):
        """Load example ``i`` and return its model inputs and answer target."""
        example = super().__getitem__(i)
        img_fname = example['img_fname']
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)

        # text input
        raw_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(raw_ids)

        target = _get_vqa_target(example, self.num_answers)

        # The mask covers the combined text ids plus the ``num_bb`` image
        # features, all set to 1.
        mask_len = len(input_ids) + num_bb
        attn_masks = torch.ones(mask_len, dtype=torch.long)
        return input_ids, img_feat, img_pos_feat, attn_masks, target
def vqa_collate(inputs):
    """Collate ``VqaDataset`` items into one padded training batch.

    Args:
        inputs: list of ``(input_ids, img_feat, img_pos_feat, attn_masks,
            target)`` tuples.

    Returns:
        dict with keys ``'input_ids'``, ``'position_ids'``, ``'img_feat'``,
        ``'img_pos_feat'``, ``'attn_masks'``, ``'gather_index'`` and
        ``'targets'``.
    """
    ids_list, feat_list, pos_list, mask_list, target_list = map(
        list, unzip(inputs))

    # Text side: record true lengths before zero-padding to a common length.
    txt_lens = [ids.size(0) for ids in ids_list]
    padded_ids = pad_sequence(ids_list, batch_first=True, padding_value=0)
    batch_size, max_txt_len = padded_ids.size()
    # Single row of shape (1, max_txt_len); presumably broadcast over the
    # batch by the consumer.
    position_ids = torch.arange(
        0, max_txt_len, dtype=torch.long).unsqueeze(0)
    padded_masks = pad_sequence(mask_list, batch_first=True, padding_value=0)
    stacked_targets = torch.stack(target_list, dim=0)

    # Image side: pad per-example features using their row counts.
    num_bbs = [feat.size(0) for feat in feat_list]
    img_feat = pad_tensors(feat_list, num_bbs)
    img_pos_feat = pad_tensors(pos_list, num_bbs)

    gather_index = get_gather_index(
        txt_lens, num_bbs, batch_size, max_txt_len, padded_masks.size(1))

    return {
        'input_ids': padded_ids,
        'position_ids': position_ids,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'attn_masks': padded_masks,
        'gather_index': gather_index,
        'targets': stacked_targets,
    }
class VqaEvalDataset(VqaDataset):
    """Evaluation variant of ``VqaDataset`` that also returns an id.

    Each item is
    ``(qid, input_ids, img_feat, img_pos_feat, attn_masks, target)``, where
    ``qid`` is ``self.ids[i]`` (presumably the question id). ``target`` is
    ``None`` when the example has no ``'target'`` entry.
    """

    def __getitem__(self, i):
        """Load example ``i``; its target may be absent (e.g. unlabeled)."""
        qid = self.ids[i]
        # Call the base-class loader directly. VqaDataset.__getitem__ always
        # builds a target, which needs example['target'] to exist.
        example = DetectFeatTxtTokDataset.__getitem__(self, i)
        img_fname = example['img_fname']
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)

        # text input
        raw_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(raw_ids)

        has_target = 'target' in example
        target = (_get_vqa_target(example, self.num_answers)
                  if has_target else None)

        mask_len = len(input_ids) + num_bb
        attn_masks = torch.ones(mask_len, dtype=torch.long)
        return qid, input_ids, img_feat, img_pos_feat, attn_masks, target
def vqa_eval_collate(inputs):
    """Collate ``VqaEvalDataset`` items into one padded evaluation batch.

    Args:
        inputs: list of ``(qid, input_ids, img_feat, img_pos_feat,
            attn_masks, target)`` tuples. ``target`` may be ``None``.

    Returns:
        dict with keys ``'qids'``, ``'input_ids'``, ``'position_ids'``,
        ``'img_feat'``, ``'img_pos_feat'``, ``'attn_masks'``,
        ``'gather_index'`` and ``'targets'``. ``'targets'`` is ``None`` when
        the first item's target is ``None``; otherwise it is the stacked
        targets.
    """
    qid_list, ids_list, feat_list, pos_list, mask_list, target_list = map(
        list, unzip(inputs))

    # Text side: record true lengths before zero-padding to a common length.
    txt_lens = [ids.size(0) for ids in ids_list]
    padded_ids = pad_sequence(ids_list, batch_first=True, padding_value=0)
    batch_size, max_txt_len = padded_ids.size()
    # Single row of shape (1, max_txt_len); presumably broadcast over the
    # batch by the consumer.
    position_ids = torch.arange(
        0, max_txt_len, dtype=torch.long).unsqueeze(0)
    padded_masks = pad_sequence(mask_list, batch_first=True, padding_value=0)
    # Only the first item is checked: a batch is assumed to be all-labeled
    # or all-unlabeled.
    stacked_targets = (None if target_list[0] is None
                       else torch.stack(target_list, dim=0))

    # Image side: pad per-example features using their row counts.
    num_bbs = [feat.size(0) for feat in feat_list]
    img_feat = pad_tensors(feat_list, num_bbs)
    img_pos_feat = pad_tensors(pos_list, num_bbs)

    gather_index = get_gather_index(
        txt_lens, num_bbs, batch_size, max_txt_len, padded_masks.size(1))

    return {
        'qids': qid_list,
        'input_ids': padded_ids,
        'position_ids': position_ids,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'attn_masks': padded_masks,
        'gather_index': gather_index,
        'targets': stacked_targets,
    }