""" VQA dataset """ import torch from torch.nn.utils.rnn import pad_sequence from toolz.sandbox import unzip from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index def _get_vqa_target(example, num_answers): target = torch.zeros(num_answers) labels = example['target']['labels'] scores = example['target']['scores'] if labels and scores: target.scatter_(0, torch.tensor(labels), torch.tensor(scores)) return target class VqaDataset(DetectFeatTxtTokDataset): """ NOTE: This handels distributed inside """ def __init__(self, num_answers, *args, **kwargs): super().__init__(*args, **kwargs) self.num_answers = num_answers def __getitem__(self, i): example = super().__getitem__(i) img_feat, img_pos_feat, num_bb = self._get_img_feat( example['img_fname']) # text input input_ids = example['input_ids'] input_ids = self.txt_db.combine_inputs(input_ids) target = _get_vqa_target(example, self.num_answers) attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) return input_ids, img_feat, img_pos_feat, attn_masks, target def vqa_collate(inputs): (input_ids, img_feats, img_pos_feats, attn_masks, targets ) = map(list, unzip(inputs)) txt_lens = [i.size(0) for i in input_ids] input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long ).unsqueeze(0) attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) targets = torch.stack(targets, dim=0) num_bbs = [f.size(0) for f in img_feats] img_feat = pad_tensors(img_feats, num_bbs) img_pos_feat = pad_tensors(img_pos_feats, num_bbs) bs, max_tl = input_ids.size() out_size = attn_masks.size(1) gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) batch = {'input_ids': input_ids, 'position_ids': position_ids, 'img_feat': img_feat, 'img_pos_feat': img_pos_feat, 'attn_masks': attn_masks, 'gather_index': gather_index, 'targets': targets} return batch class VqaEvalDataset(VqaDataset): def __getitem__(self, i): qid = self.ids[i] example = DetectFeatTxtTokDataset.__getitem__(self, i) img_feat, img_pos_feat, num_bb = self._get_img_feat( example['img_fname']) # text input input_ids = example['input_ids'] input_ids = self.txt_db.combine_inputs(input_ids) if 'target' in example: target = _get_vqa_target(example, self.num_answers) else: target = None attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) return qid, input_ids, img_feat, img_pos_feat, attn_masks, target def vqa_eval_collate(inputs): (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets ) = map(list, unzip(inputs)) txt_lens = [i.size(0) for i in input_ids] input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long ).unsqueeze(0) attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) if targets[0] is None: targets = None else: targets = torch.stack(targets, dim=0) num_bbs = [f.size(0) for f in img_feats] img_feat = pad_tensors(img_feats, num_bbs) img_pos_feat = pad_tensors(img_pos_feats, num_bbs) bs, max_tl = input_ids.size() out_size = attn_masks.size(1) gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) batch = {'qids': qids, 'input_ids': input_ids, 'position_ids': position_ids, 'img_feat': img_feat, 'img_pos_feat': img_pos_feat, 'attn_masks': attn_masks, 'gather_index': gather_index, 'targets': targets} return batch