""" VQA dataset """ import torch from torch.nn.utils.rnn import pad_sequence from toolz.sandbox import unzip from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index def _get_vqa_target(example, num_answers): target = torch.zeros(num_answers) labels = example['target']['labels'] scores = example['target']['scores'] if labels and scores: target.scatter_(0, torch.tensor(labels), torch.tensor(scores)) return target class VqaDataset(DetectFeatTxtTokDataset): """ NOTE: This handels distributed inside """ def __init__(self, num_answers, *args, **kwargs): super().__init__(*args, **kwargs) self.num_answers = num_answers def __getitem__(self, i): example = super().__getitem__(i) qid = self.ids[i] img_feat, img_pos_feat, num_bb = self._get_img_feat( example['img_fname']) # text input input_ids = example['input_ids'] input_ids = self.txt_db.combine_inputs(input_ids) img_input_ids = torch.Tensor([101]).long() target = _get_vqa_target(example, self.num_answers) attn_masks_txt = torch.ones(len(input_ids), dtype=torch.long) attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) return qid, input_ids, attn_masks_txt, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target def vqa_collate(inputs): (qids, input_ids, attn_masks_txt, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets ) = map(list, unzip(inputs)) txt_lens = [i.size(0) for i in input_ids] input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long ).unsqueeze(0) attn_masks_txt = pad_sequence(attn_masks_txt, batch_first=True, padding_value=0) attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) targets = torch.stack(targets, dim=0) num_bbs = [f.size(0) for f in img_feats] img_feat = pad_tensors(img_feats, num_bbs) img_pos_feat = pad_tensors(img_pos_feats, num_bbs) img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) bs, max_tl = input_ids.size() out_size = attn_masks_img.size(1) gather_index_teacher = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) batch = {'qids': qids, 'txts': { 'input_ids': input_ids, 'position_ids': position_ids, 'attention_mask': attn_masks_txt, 'img_feat': None, 'img_pos_feat': None, 'img_masks': None, 'gather_index': None }, 'imgs': { 'input_ids': img_input_ids, 'position_ids': img_position_ids, 'attention_mask': attn_masks_img, 'img_feat': img_feat, 'img_pos_feat': img_pos_feat, 'img_masks': None, 'gather_index': gather_index }, 'gather_index_teacher': gather_index_teacher, 'targets': targets} return batch class VqaEvalDataset(VqaDataset): def __getitem__(self, i): qid = self.ids[i] example = DetectFeatTxtTokDataset.__getitem__(self, i) img_feat, img_pos_feat, num_bb = self._get_img_feat( example['img_fname']) # text input input_ids = example['input_ids'] input_ids = self.txt_db.combine_inputs(input_ids) if 'target' in example: target = _get_vqa_target(example, self.num_answers) else: target = None attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) return qid, input_ids, img_feat, img_pos_feat, attn_masks, target def vqa_eval_collate(inputs): (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets ) = map(list, unzip(inputs)) txt_lens = [i.size(0) for i in input_ids] input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long ).unsqueeze(0) attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) if targets[0] is None: targets = None else: targets = torch.stack(targets, dim=0) num_bbs = [f.size(0) for f in img_feats] img_feat = pad_tensors(img_feats, num_bbs) img_pos_feat = pad_tensors(img_pos_feats, num_bbs) bs, max_tl = input_ids.size() out_size = attn_masks.size(1) gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) batch = {'qids': qids, 'input_ids': input_ids, 'position_ids': position_ids, 'img_feat': img_feat, 'img_pos_feat': img_pos_feat, 'attn_masks': attn_masks, 'gather_index': gather_index, 'targets': targets} return batch