logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

391 lines
13 KiB

"""
MLM datasets
"""
import math
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, get_ids_and_lens, pad_tensors,
get_gather_index, get_gather_index_uniter)
def random_word(tokens, vocab_range, mask):
"""
Masking some random tokens for Language Model task with probabilities as in
the original BERT paper.
:param tokens: list of int, tokenized sentence.
:param vocab_range: for choosing a random word
:return: (list of int, list of int), masked tokens and related labels for
LM prediction
"""
output_label = []
for i, token in enumerate(tokens):
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
tokens[i] = mask
# 10% randomly change token to random token
elif prob < 0.9:
tokens[i] = random.choice(list(range(*vocab_range)))
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
output_label.append(token)
else:
# no masking token (will be ignored by loss function later)
output_label.append(-1)
if all(o == -1 for o in output_label):
# at least mask 1
output_label[0] = tokens[0]
tokens[0] = mask
return tokens, output_label
class MlmDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
"""
Return:
- input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- img_feat : (num_bb, d)
- img_pos_feat : (num_bb, 7)
- attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
- txt_labels : (L, ), [-1, -1, wid, -1, -1, -1]
0's padded so that (L + num_bb) % 8 == 0
"""
example = super().__getitem__(i)
# text input
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
# img input
img_input_ids = torch.Tensor([101]).long()
img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'])
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, txt_labels, attn_masks_teacher
def create_mlm_io(self, input_ids):
input_ids, txt_labels = random_word(input_ids,
self.txt_db.v_range,
self.txt_db.mask)
input_ids = torch.tensor([self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep])
txt_labels = torch.tensor([-1] + txt_labels + [-1])
return input_ids, txt_labels
def mlm_collate(inputs):
"""
Return:
:input_ids (n, max_L) padded with 0
:position_ids (n, max_L) padded with 0
:txt_lens list of [txt_len]
:img_feat (n, max_num_bb, feat_dim)
:img_pos_feat (n, max_num_bb, 7)
:num_bbs list of [num_bb]
:attn_masks (n, max_{L + num_bb}) padded with 0
:txt_labels (n, max_L) padded with -1
"""
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, txt_labels, attn_masks_teacher
) = map(list, unzip(inputs))
# text batches
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': None,
'gather_index': gather_index
},
'txt_labels': txt_labels,
'teacher': {
'txt_lens': txt_lens,
'num_bbs': num_bbs,
'bs': bs,
'max_tl': max_tl,
'out_size': out_size,
'gather_index': gather_index_teacher,
'attn_masks': attn_masks_teacher
}
}
return batch
class BlindMlmDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels
def mlm_blind_collate(inputs):
input_ids, attn_masks, txt_labels = map(list, unzip(inputs))
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'txt_labels': txt_labels}
return batch
def eval_mask(len_, num_samples=7):
""" build the mask for evaluating MLM
circularly mask 1 word out of every x words
"""
# build the random masks
if len_ <= num_samples:
masks = torch.eye(len_).bool()
num_samples = len_
else:
mask_inds = [list(range(i, len_, num_samples))
for i in range(num_samples)]
masks = torch.zeros(num_samples, len_).bool()
for i, indices in enumerate(mask_inds):
for j in indices:
masks.data[i, j] = 1
assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
assert masks.sum().item() == len_
return masks
def eval_gather_inds(len_, num_samples=7):
""" get the gather indices """
inds = torch.arange(0, num_samples, dtype=torch.long)
mul = math.ceil(len_ / num_samples)
output = inds.repeat(mul)[:len_]
return output
def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
"""N x [B_i, T, ...]"""
if ns is None:
ns = [t.size(0) for t in tensors]
if lens is None:
lens = [t.size(1) for t in tensors]
max_len = max(lens)
bs = sum(ns)
hid_dims = tensors[0].size()[2:]
dtype = tensors[0].dtype
output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
if pad:
output.data.fill_(pad)
i = 0
for t, l, n in zip(tensors, lens, ns):
output.data[i:i+n, :l, ...] = t.data
i += n
return output
def expand_tensors(tensors, ns):
return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim()))
for t, n in zip(tensors, ns)]
class MlmEvalDataset(DetectFeatTxtTokDataset):
""" For evaluating MLM training task """
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
example = super().__getitem__(i)
# text input
(input_ids, txt_labels, gather_inds
) = self.create_mlm_eval_io(example['input_ids'])
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat, attn_masks,
txt_labels, gather_inds)
def create_mlm_eval_io(self, input_ids):
txt_labels = torch.tensor(input_ids)
masks = eval_mask(len(input_ids))
n_mask = masks.size(0)
masks = torch.cat([torch.zeros(n_mask, 1).bool(),
masks,
torch.zeros(n_mask, 1).bool()],
dim=1)
input_ids = torch.tensor([[self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep]
for _ in range(n_mask)])
input_ids.data.masked_fill_(masks, self.txt_db.mask)
gather_inds = eval_gather_inds(len(txt_labels))
return input_ids, txt_labels, gather_inds
def _batch_gather_tgt(gather_inds, n_masks):
gather_tgts = []
offset = 0
for g, n in zip(gather_inds, n_masks):
gather_tgts.append(g + offset)
offset += n
gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
return gather_tgt
def mlm_eval_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
num_bbs, n_masks)
img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
num_bbs, n_masks)
bs, max_tl = input_ids.size()
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
out_size = attn_masks.size(1)
# repeat txt_lens, num_bbs
txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch
class BlindMlmEvalDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids = example['input_ids']
# text input
input_ids = example['input_ids']
(input_ids, txt_labels, gather_inds
) = self.txt_db.create_mlm_eval_io(input_ids)
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels, gather_inds
def mlm_blind_eval_collate(inputs):
(input_ids, position_ids, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch