"""MLM datasets."""
import math
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, get_ids_and_lens, pad_tensors,
get_gather_index, get_gather_index_uniter)
def random_word(tokens, vocab_range, mask):
    """
    Masking some random tokens for Language Model task with probabilities as in
        the original BERT paper.

    ``tokens`` is modified in place and is also the first returned value.

    :param tokens: list of int, tokenized sentence.
    :param vocab_range: arguments of the id range (e.g. ``(start, stop)``)
        from which a random replacement word is drawn
    :param mask: id written at positions that get masked
    :return: (list of int, list of int), masked tokens and related labels for
        LM prediction; a label of -1 marks a position that is not predicted
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            # rescale to [0, 1) to pick one of the three replacement modes
            prob /= 0.15

            if prob < 0.8:
                # 80% randomly change token to mask token
                tokens[i] = mask
            elif prob < 0.9:
                # 10% randomly change token to random token; randrange avoids
                # materialising the whole vocabulary range on every draw
                tokens[i] = random.randrange(*vocab_range)
            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            output_label.append(token)
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    # an empty sentence has nothing to mask, so skip the forced mask
    if tokens and all(o == -1 for o in output_label):
        # at least mask 1
        output_label[0] = tokens[0]
        tokens[0] = mask

    return tokens, output_label
class MlmDataset(DetectFeatTxtTokDataset):
    """MLM training dataset pairing a masked sentence with its image features.

    Every item carries separate text-side and image-side inputs plus one
    attention mask covering text and image regions together.
    """

    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        """
        Return (in this order):
        - input_ids          : (L, ), i.e., [cls, wd, wd, ..., sep] after
                               random masking
        - attn_masks         : (L, ), all ones
        - img_input_ids      : (1, ), the single token id 101
        - img_feat           : (num_bb, d)
        - img_pos_feat       : (num_bb, 7)
        - attn_masks_img     : (num_bb + 1, ), all ones
        - txt_labels         : (L, ), [-1, -1, wid, -1, -1, -1]
        - attn_masks_teacher : (L + num_bb, ), all ones
        """
        example = super().__getitem__(i)

        # text input
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])

        # img input
        # NOTE(review): 101 is presumably the tokenizer's [CLS] id -- confirm.
        img_input_ids = torch.tensor([101], dtype=torch.long)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        # one extra slot for the single token in img_input_ids
        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)
        attn_masks_teacher = torch.ones(len(input_ids) + num_bb,
                                        dtype=torch.long)

        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat,
                attn_masks_img, txt_labels, attn_masks_teacher)

    def create_mlm_io(self, input_ids):
        """Randomly mask ``input_ids`` and wrap them with cls / sep ids.

        :param input_ids: list of int, token ids without special tokens
        :return: (input_ids, txt_labels) as 1-D tensors of equal length; the
            labels at the cls / sep positions are -1
        """
        # NOTE(review): assumes txt_db exposes the replacement-vocabulary
        # range as `v_range` -- confirm against TxtTokLmdb.
        input_ids, txt_labels = random_word(input_ids,
                                            self.txt_db.v_range,
                                            self.txt_db.mask)
        input_ids = torch.tensor([self.txt_db.cls_]
                                 + input_ids
                                 + [self.txt_db.sep])
        txt_labels = torch.tensor([-1] + txt_labels + [-1])
        return input_ids, txt_labels
def mlm_collate(inputs):
    """
    Collate ``MlmDataset`` items into one batch dict.

    Return:
    :txts        text-side inputs
        input_ids       (n, max_L) padded with 0
        position_ids    (1, max_L)
        attention_mask  (n, max_L) padded with 0
        img_feat / img_pos_feat / img_masks / gather_index are None
    :imgs        image-side inputs
        input_ids       (n, 1)
        position_ids    (1, 1)
        attention_mask  (n, max_num_bb + 1) padded with 0
        img_feat        (n, max_num_bb, feat_dim)
        img_pos_feat    (n, max_num_bb, 7)
        img_masks       None
        gather_index    built with one text token per example
    :txt_labels  (n, max_L) padded with -1
    :teacher     joint text + image bookkeeping
        txt_lens        list of [txt_len]
        num_bbs         list of [num_bb]
        bs, max_tl      size of the padded input_ids
        out_size        width of the image-side attention mask
        gather_index    built from txt_lens and num_bbs
        attn_masks      (n, max_{L + num_bb}) padded with 0
    """
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats,
     attn_masks_img, txt_labels, attn_masks_teacher
     ) = map(list, zip(*inputs))

    # text batches
    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_input_ids = pad_sequence(img_input_ids, batch_first=True,
                                 padding_value=0)
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_position_ids = torch.arange(0, img_input_ids.size(1),
                                    dtype=torch.long).unsqueeze(0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # the image side has exactly one token id per example, hence the
    # text length of 1 for every example and a max text length of 1
    gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, out_size)

    attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True,
                                      padding_value=0)
    gather_index_teacher = get_gather_index_uniter(
        txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'txt_labels': txt_labels,
        'teacher': {
            'txt_lens': txt_lens,
            'num_bbs': num_bbs,
            'bs': bs,
            'max_tl': max_tl,
            'out_size': out_size,
            'gather_index': gather_index_teacher,
            'attn_masks': attn_masks_teacher
        }
    }
    return batch
class BlindMlmDataset(Dataset):
    """Text-only MLM training dataset (no image features)."""

    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """
        Return:
        - input_ids  : (L, ), masked token ids as built by create_mlm_io
        - attn_masks : (L, ), all ones
        - txt_labels : (L, ), labels as built by create_mlm_io
        """
        id_ = self.ids[i]
        example = self.txt_db[id_]
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        return input_ids, attn_masks, txt_labels

    def create_mlm_io(self, input_ids):
        """Build masked inputs and labels exactly as ``MlmDataset`` does.

        ``__getitem__`` needs this method on the class itself. The logic of
        ``MlmDataset.create_mlm_io`` only reads ``self.txt_db``, so it is
        reused here instead of being duplicated.
        """
        return MlmDataset.create_mlm_io(self, input_ids)
def mlm_blind_collate(inputs):
    """
    Collate (input_ids, attn_masks, txt_labels) triples into a batch.

    Return:
    :input_ids    (n, max_L) padded with 0
    :position_ids (1, max_L)
    :attn_masks   (n, max_L) padded with 0
    :txt_labels   (n, max_L) padded with -1
    """
    input_ids, attn_masks, txt_labels = map(list, zip(*inputs))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    # a single row of positions, shared by every example in the batch
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'txt_labels': txt_labels}
    return batch
def eval_mask(len_, num_samples=7):
    """ build the mask for evaluating MLM
    circularly mask 1 word out of every x words

    :param len_: number of tokens in the sentence
    :param num_samples: maximum number of masked copies (rows) to build
    :return: bool tensor of shape (min(len_, num_samples), len_); row i is
        True at positions i, i + num_samples, ... so that every position is
        True in exactly one row
    :raises ValueError: if the sentence is longer than ``num_samples`` and
        ``num_samples`` is not positive
    """
    if len_ <= num_samples:
        # short sentence: one row per token, each masking only that token
        masks = torch.eye(len_).bool()
    else:
        if num_samples < 1:
            raise ValueError('num_samples must be positive')
        # position j belongs to row j % num_samples
        row_of_position = torch.arange(len_) % num_samples
        rows = torch.arange(num_samples).unsqueeze(1)
        masks = row_of_position.unsqueeze(0) == rows
    # sanity checks: every position is masked exactly once
    assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
    assert masks.sum().item() == len_
    return masks
def eval_gather_inds(len_, num_samples=7):
    """ get the gather indices

    :param len_: number of tokens in the sentence
    :param num_samples: length of the index cycle 0 .. num_samples - 1
    :return: long tensor of shape (len_, ) holding the cycle
        0, 1, ..., num_samples - 1 repeated and cut to ``len_`` entries
    """
    cycle = torch.arange(num_samples, dtype=torch.long)
    # enough whole cycles to cover len_ entries
    n_cycles = math.ceil(len_ / num_samples)
    tiled = cycle.repeat(n_cycles)
    return tiled[:len_]
def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
    """N x [B_i, T, ...]

    Concatenate tensors along dim 0 while padding dim 1 to a common length.

    :param tensors: list of tensors shaped (B_i, T_i, ...); trailing dims and
        dtype are taken from the first tensor
    :param lens: length along dim 1 for each tensor (default: ``t.size(1)``)
    :param ns: number of rows for each tensor (default: ``t.size(0)``)
    :param pad: fill value for the padded area
    :return: tensor of shape (sum(ns), max(lens), ...)
    """
    if ns is None:
        ns = [t.size(0) for t in tensors]
    if lens is None:
        lens = [t.size(1) for t in tensors]
    max_len = max(lens)
    bs = sum(ns)
    hid_dims = tensors[0].size()[2:]
    dtype = tensors[0].dtype
    output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
    if pad:
        # zeros already cover pad == 0; only a non-zero pad needs a fill
        output.fill_(pad)
    row = 0
    for tensor, length, n in zip(tensors, lens, ns):
        # detach so the copy never records autograd history on the output
        output[row:row + n, :length, ...] = tensor.detach()
        row += n
    return output
def expand_tensors(tensors, ns):
    """Give each tensor a new leading dimension of the matching size.

    :param tensors: list of tensors
    :param ns: size of the new leading dimension for each tensor
    :return: list of expanded views shaped (n, *t.size())
    """
    expanded = []
    for tensor, count in zip(tensors, ns):
        # -1 leaves every existing dimension at its current size
        keep_dims = (-1,) * tensor.dim()
        expanded.append(tensor.unsqueeze(0).expand(count, *keep_dims))
    return expanded
class MlmEvalDataset(DetectFeatTxtTokDataset):
    """ For evaluating MLM training task """

    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (n_mask, L), one masked copy of the sentence per row
        - img_feat     : image features from _get_img_feat
        - img_pos_feat : image position features from _get_img_feat
        - attn_masks   : (L + num_bb, ), all ones
        - txt_labels   : the unmasked token ids, without cls / sep
        - gather_inds  : for every token, the index produced by
                         eval_gather_inds
        """
        example = super().__getitem__(i)

        # text input
        (input_ids, txt_labels, gather_inds
         ) = self.create_mlm_eval_io(example['input_ids'])

        # img input
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        # input_ids is (n_mask, L), so the text length is dim 1
        attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat, attn_masks,
                txt_labels, gather_inds)

    def create_mlm_eval_io(self, input_ids):
        """Build the masked copies of one sentence for evaluation.

        :param input_ids: list of int, token ids without special tokens
        :return: (input_ids, txt_labels, gather_inds) where input_ids is
            (n_mask, len + 2) with cls / sep added and one mask pattern of
            ``eval_mask`` applied per row
        """
        txt_labels = torch.tensor(input_ids)
        masks = eval_mask(len(input_ids))
        n_mask = masks.size(0)
        # cls and sep sit at both ends and are never masked
        never_masked = torch.zeros(n_mask, 1).bool()
        masks = torch.cat([never_masked, masks, never_masked], dim=1)
        sentence = torch.tensor([self.txt_db.cls_]
                                + input_ids
                                + [self.txt_db.sep])
        # one identical row per mask pattern
        input_ids = sentence.unsqueeze(0).repeat(n_mask, 1)
        input_ids.masked_fill_(masks, self.txt_db.mask)
        gather_inds = eval_gather_inds(len(txt_labels))
        return input_ids, txt_labels, gather_inds
def _batch_gather_tgt(gather_inds, n_masks):
gather_tgts = []
offset = 0
for g, n in zip(gather_inds, n_masks):
gather_tgts.append(g + offset)
offset += n
gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
return gather_tgt
def mlm_eval_collate(inputs):
    """
    Collate ``MlmEvalDataset`` items; every example contributes one batch
    row per masked copy, so the batch has sum(n_masks) rows.

    Return:
    :input_ids    (sum(n_masks), max_L) padded with 0
    :position_ids (1, max_L)
    :img_feat     image features, repeated once per masked copy
    :img_pos_feat image position features, repeated once per masked copy
    :attn_masks   (sum(n_masks), max_{L + num_bb}) padded with 0
    :gather_index built from the per-row text lengths and box counts
    :gather_tgt   (n, max_len) batch-wide row index for every token
    :txt_labels   (n, max_len) padded with -1
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
     ) = map(list, zip(*inputs))

    # sizes: every input_ids entry is (n_mask, txt_len)
    n_masks = [ids.size(0) for ids in input_ids]
    txt_lens = [ids.size(1) for ids in input_ids]

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
                                 num_bbs, n_masks)
    img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
                                     num_bbs, n_masks)

    bs, max_tl = input_ids.size()
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    out_size = attn_masks.size(1)

    # repeat txt_lens, num_bbs so there is one entry per batch row
    txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
    num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch
class BlindMlmEvalDataset(Dataset):
    """Text-only dataset for evaluating the MLM task (no image features)."""

    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """
        Return:
        - input_ids   : (n_mask, L), masked copies of the sentence
        - attn_masks  : (L, ), all ones
        - txt_labels  : labels returned by create_mlm_eval_io
        - gather_inds : gather indices returned by create_mlm_eval_io
        """
        id_ = self.ids[i]
        example = self.txt_db[id_]

        # text input
        # NOTE(review): create_mlm_eval_io is looked up on txt_db; the only
        # definition visible in this file is
        # MlmEvalDataset.create_mlm_eval_io -- confirm TxtTokLmdb provides it.
        (input_ids, txt_labels, gather_inds
         ) = self.txt_db.create_mlm_eval_io(example['input_ids'])

        # input_ids is 2-D (the collate function reads both of its sizes), so
        # the mask must follow the sequence length in dim 1, not the number
        # of masked copies in dim 0
        attn_masks = torch.ones(input_ids.size(1), dtype=torch.long)

        return input_ids, attn_masks, txt_labels, gather_inds
def mlm_blind_eval_collate(inputs):
    """
    Collate ``BlindMlmEvalDataset`` items; every example contributes one
    batch row per masked copy.

    Each item is (input_ids, attn_masks, txt_labels, gather_inds). Items that
    carry an extra position_ids field in second place are accepted too; that
    field is ignored because position ids are rebuilt here.

    Return:
    :input_ids    (sum(n_masks), max_L) padded with 0
    :position_ids (1, max_L)
    :attn_masks   (sum(n_masks), max_len) padded with 0
    :gather_tgt   (n, max_len) batch-wide row index for every token
    :txt_labels   (n, max_len) padded with -1
    """
    columns = [list(column) for column in zip(*inputs)]
    if len(columns) == 5:
        # 5-field layout: drop the per-item position_ids
        del columns[1]
    input_ids, attn_masks, txt_labels, gather_inds = columns

    # sizes: every input_ids entry is (n_mask, txt_len)
    n_masks = [ids.size(0) for ids in input_ids]
    txt_lens = [ids.size(1) for ids in input_ids]

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch