"""
MRM Datasets
"""
import random

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index


def _get_img_mask(mask_prob, num_bb):
    img_mask = [random.random() < mask_prob for _ in range(num_bb)]
    if not any(img_mask):
        # at least mask 1
        img_mask[random.choice(range(num_bb))] = True
    img_mask = torch.tensor(img_mask)
    return img_mask
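# Illustrative note: each region is masked independently with probability
# `mask_prob`, and one region is forced on if none were drawn, so e.g.
# _get_img_mask(0.15, 4) could yield tensor([False, True, False, False])
# but never an all-False mask.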


def _get_img_tgt_mask(img_mask, txt_len):
    z = torch.zeros(txt_len, dtype=torch.bool)
    img_mask_tgt = torch.cat([z, img_mask], dim=0)
    return img_mask_tgt
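# Illustrative note: the `txt_len` leading False entries align the region mask
# with the concatenated [text tokens; image regions] sequence, so the result
# marks masked-region positions in the joint input rather than in the image alone.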


def _get_feat_target(img_feat, img_masks):
    img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)  # (n, m, d)
    feat_dim = img_feat.size(-1)
    feat_targets = img_feat[img_masks_ext].contiguous().view(
        -1, feat_dim)  # (s, d)
    return feat_targets
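# Illustrative note: this gathers the original, un-masked features of the masked
# regions into shape (total_masked_regions, d); the collate functions below return
# them as 'feat_targets' for the feature-regression objective.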


def _mask_img_feat(img_feat, img_masks):
    img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
    img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
    return img_feat_masked
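# Illustrative note: masked rows of `img_feat` are replaced with zeros, so the
# model sees the positions of masked regions (via img_pos_feat) but not their
# detector features.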


class MrfrDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
        - img_feat     : (num_bb, d)
        - img_pos_feat : (num_bb, 7)
        - attn_masks   : (L + num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        - img_mask     : (num_bb, ) between {0, 1}
        """
        example = super().__getitem__(i)
        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        # image input features
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])
        img_mask = _get_img_mask(self.mask_prob, num_bb)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                attn_masks, img_mask, img_mask_tgt)


def mrfr_collate(inputs):
    """
    Return:
    - input_ids    : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
    - position_ids : (n, max_L)
    - txt_lens     : list of [input_len]
    - img_feat     : (n, max_num_bb, d)
    - img_pos_feat : (n, max_num_bb, 7)
    - num_bbs      : list of [num_bb]
    - attn_masks   : (n, max_{L + num_bb}), i.e., [1, 1, ..., 0, 0, 1, 1]
    - img_masks    : (n, max_num_bb) between {0, 1}
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat = _mask_img_feat(img_feat, img_masks)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_mask_tgt': img_mask_tgt}
    return batch
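# Illustrative usage sketch; `txt_db` and `img_db` are placeholders for the
# project's text/image feature DB handles, and the constructor arguments are an
# assumption based on DetectFeatTxtTokDataset:
#
#   from torch.utils.data import DataLoader
#   dataset = MrfrDataset(0.15, txt_db, img_db)
#   loader = DataLoader(dataset, batch_size=32, collate_fn=mrfr_collate)
#   batch = next(iter(loader))
#   # batch['img_feat']     -> (n, max_num_bb, d), masked rows zeroed
#   # batch['feat_targets'] -> (total masked regions in the batch, d)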


class OnlyImgMrfrDataset(Dataset):
    """ an image-only MRM """
    def __init__(self, mask_prob, img_db):
        # keep the mask probability and image DB handle; both are used below
        self.mask_prob = mask_prob
        self.img_db = img_db
        self.ids, self.lens = map(list, unzip(self.img_db.name2nbb.items()))

    def __getitem__(self, i):
        id_ = self.ids[i]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(id_)
        attn_masks = torch.ones(num_bb, dtype=torch.long)
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        return img_feat, img_pos_feat, attn_masks, img_mask

    def _get_img_feat(self, fname):
        img_feat, bb = self.img_db[fname]
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb
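# Illustrative note: this image-only variant enumerates the image DB directly
# (`img_db.name2nbb` maps image name -> number of boxes), so each item carries
# image tensors only and no text fields.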


def mrfr_only_img_collate(inputs):
    img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs))

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat = _mask_img_feat(img_feat, img_masks)

    batch = {'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_mask_tgt': img_masks}
    return batch
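# Illustrative note: with no text tokens in front of the regions, the padded
# `img_masks` already indexes masked positions in the model input, which is why it
# is reused as 'img_mask_tgt' above.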


def _get_targets(img_masks, img_soft_label):
    soft_label_dim = img_soft_label.size(-1)
    img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
    label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
        -1, soft_label_dim)
    return label_targets
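# Illustrative note: `label_targets` has shape
# (total_masked_regions, num_detector_classes); the MRC objective compares the
# model's prediction for each masked region against these detector soft labels.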


class MrcDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def _get_img_feat(self, fname):
        img_dump = self.img_db.get_dump(fname)
        num_bb = self.img_db.name2nbb[fname]
        img_feat = torch.tensor(img_dump['features'])
        bb = torch.tensor(img_dump['norm_bb'])
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        img_soft_label = torch.tensor(img_dump['soft_labels'])
        return img_feat, img_bb, img_soft_label, num_bb

    def __getitem__(self, i):
        example = super().__getitem__(i)
        img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
            example['img_fname'])

        # image input features
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                img_soft_labels, attn_masks, img_mask, img_mask_tgt)


def mrc_collate(inputs):
    (input_ids, img_feats, img_pos_feats, img_soft_labels,
     attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_soft_label = pad_tensors(img_soft_labels, num_bbs)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    label_targets = _get_targets(img_masks, img_soft_label)

    img_feat = _mask_img_feat(img_feat, img_masks)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'img_masks': img_masks,
             'img_mask_tgt': img_mask_tgt,
             'label_targets': label_targets}
    return batch
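# Illustrative usage sketch; constructor arguments mirror the MrfrDataset example
# above and the DB handles remain placeholders:
#
#   dataset = MrcDataset(0.15, txt_db, img_db)
#   loader = DataLoader(dataset, batch_size=32, collate_fn=mrc_collate)
#   # each batch carries 'label_targets' (detector soft labels of the masked
#   # regions) instead of MRFR's 'feat_targets'.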


class OnlyImgMrcDataset(OnlyImgMrfrDataset):
    """ an image-only MRC """
    def __getitem__(self, i):
        id_ = self.ids[i]
        (img_feat, img_pos_feat, img_soft_labels, num_bb
         ) = self._get_img_feat(id_)
        attn_masks = torch.ones(num_bb, dtype=torch.long)
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask

    def _get_img_feat(self, fname):
        img_dump = self.img_db.get_dump(fname)
        num_bb = self.img_db.name2nbb[fname]
        img_feat = torch.tensor(img_dump['features'])
        bb = torch.tensor(img_dump['norm_bb'])
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        img_soft_labels = torch.tensor(img_dump['soft_labels'])
        return img_feat, img_bb, img_soft_labels, num_bb


def mrc_only_img_collate(inputs):
    (img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks
     ) = map(list, unzip(inputs))

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    num_bbs = [f.size(0) for f in img_feats]

    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_soft_label = pad_tensors(img_soft_labels, num_bbs)
    label_targets = _get_targets(img_masks, img_soft_label)

    # mask features
    img_feat = _mask_img_feat(img_feat, img_masks)

    batch = {'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'img_masks': img_masks,
             'img_mask_tgt': img_masks,
             'label_targets': label_targets}
    return batch