"""Itm (image-text matching) dataset."""
from collections import defaultdict
import copy
import json
import random
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from toolz.sandbox import unzip
from cytoolz import concat
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
pad_tensors, get_gather_index, get_ids_and_lens)
from uniter_model.data.sampler import TokenBucketSampler
class TokenBucketSamplerForItm(TokenBucketSampler):
    """Token-bucket sampler that re-samples ITM negatives every epoch.

    ``ItmDataset.new_epoch()`` re-draws labels / paired images, which changes
    the per-example lengths.  This sampler triggers that re-sampling each
    time a new epoch iterator is requested and hands the refreshed lengths
    to the bucketing logic of ``TokenBucketSampler``.
    """

    def __init__(self, dset, *args, **kwargs):
        """
        Args:
            dset: dataset exposing ``lens`` and ``new_epoch()``
                (e.g. ``ItmDataset``).
            *args, **kwargs: forwarded to ``TokenBucketSampler``.
        """
        super().__init__(dset.lens, *args, **kwargs)
        self.dset = dset

    def __iter__(self):
        # Re-sample first so the lengths used to build this epoch's buckets
        # are the lengths of the examples the dataset will actually return.
        self.dset.new_epoch()
        self._lens = self.dset.lens
        return super().__iter__()
def _has_overlap(la, lb):
if len(la) < len(lb):
la, lb = lb, la
s = set(la)
return any(b in s for b in lb)
def _sample_negative_rand(sample_pool, ground_truths, num_sample):
""" random and retry """
outputs = ground_truths[:1]
while _has_overlap(outputs, ground_truths):
outputs = random.sample(sample_pool, num_sample)
return outputs
def _sample_negative_extra(sample_pool, ground_truths, num_sample):
""" sample extra then remove """
tot_size = len(ground_truths) + num_sample
outputs = set(random.sample(sample_pool, tot_size))
for gt in ground_truths:
outputs = list(outputs)[:num_sample]
return outputs
# Negative sampler used by the datasets in this module.  Switch between the
# two implementations: retry-based ``_sample_negative_rand`` or
# over-sample-and-filter ``_sample_negative_extra``.
sample_negative = _sample_negative_rand
class ItmDataset(DetectFeatTxtTokDataset):
    """Image-text matching (ITM) dataset with per-epoch negative sampling.

    NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling).

    Each text is paired with its own image (label 1) or, with probability
    ``neg_sample_p``, with a different randomly drawn image (label 0).
    Labels and paired images are re-drawn by ``new_epoch()``.
    """

    def __init__(self, txt_db, img_db, neg_sample_p=0.0):
        """
        Args:
            txt_db: ``TxtTokLmdb`` with the tokenized texts.
            img_db: ``DetectFeatLmdb`` with the image features.
            neg_sample_p: probability that an example becomes a negative
                (label 0).  With the default 0.0 every example is positive.
        """
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)
        # The parent __init__ is not called: ids / text lengths are computed
        # here and ``lens`` is rebuilt by new_epoch().
        self.txt_db = txt_db
        self.img_db = img_db
        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
        self.neg_sample_p = neg_sample_p
        # Initialise labels / train_imgs / lens so the dataset (and samplers
        # that read ``dset.lens`` at construction time) work immediately.
        self.new_epoch()

    def new_epoch(self):
        """Re-draw labels and paired images; should be called every epoch.

        Sets ``self.labels`` (1 = matching pair, 0 = negative),
        ``self.train_imgs`` (image fname paired with each text) and
        ``self.lens`` (text length + number of boxes of the paired image).
        """
        self.labels = np.random.choice(
            [0, 1], size=len(self.ids),
            p=[self.neg_sample_p, 1 - self.neg_sample_p])
        self.lens = []
        self.train_imgs = []
        for i, (_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            if self.labels[i] == 0:
                img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
            self.train_imgs.append(img_fname)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        """Return the i-th example for ``itm_collate``.

        Returns:
            tuple ``(input_ids, attn_masks, img_input_ids, img_feat,
            img_pos_feat, attn_masks_img, target)``.
        """
        example = super().__getitem__(i)
        # labels and negative images are sampled every epoch by new_epoch()
        ground_truth_label = self.labels[i]
        img_fname = self.train_imgs[i]

        # single special token for the image stream
        # (101 is presumably the BERT [CLS] id -- TODO confirm)
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        # num_bb + 1: presumably one extra slot for img_input_ids' token
        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        # torch.Tensor(1) is uninitialised memory; build the target from the
        # label sampled for this epoch.
        target = torch.tensor([int(ground_truth_label)], dtype=torch.long)
        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat,
                attn_masks_img, target)
def itm_collate(inputs):
    """Collate ``ItmDataset`` examples into separate text / image batches.

    Args:
        inputs: list of 7-tuples as returned by ``ItmDataset.__getitem__``.

    Returns:
        dict with ``'txts'`` and ``'imgs'`` sub-batches plus
        ``'pos_ctx_indices'``, ``'neg_ctx_indices'`` and ``'targets'``.
    """
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats,
     attn_masks_img, targets) = map(list, unzip(inputs))

    # text stream
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    # image stream
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_input_ids = pad_sequence(img_input_ids, batch_first=True,
                                 padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1),
                                    dtype=torch.long).unsqueeze(0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)

    targets = torch.cat(targets, dim=0)

    bs = input_ids.size(0)
    out_size = attn_masks_img.size(1)
    # The image stream's "text" part is the single image token, so every
    # example uses txt_len=1 and max_txt_len=1.
    gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, out_size)

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None,
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index,
        },
        'pos_ctx_indices': list(range(bs)),
        # NOTE(review): len(num_bbs) == bs here, so this is always empty
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'targets': targets,
    }
    return batch
def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
ot_scatter = torch.arange(0, joint_len, dtype=torch.long
).unsqueeze(0).repeat(len(txt_lens), 1)
for i, tl in enumerate(txt_lens):
max_ind = max_txt_len + (joint_len-tl)
ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
return ot_scatter
def _compute_pad(lens, max_len):
pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
for i, l in enumerate(lens):
pad.data[i, l:].fill_(1)
return pad
def itm_ot_collate(inputs):
    """Collate ITM examples into a joint batch plus optimal-transport inputs.

    Args:
        inputs: list of 5-tuples ``(input_ids, img_feat, img_pos_feat,
            attn_masks, target)``.  NOTE(review): this differs from the
            7-tuple returned by ``ItmDataset.__getitem__`` in this module;
            confirm which dataset feeds this function.

    Returns:
        dict of padded tensors with an ``'ot_inputs'`` sub-dict
        (``ot_scatter``, ``scatter_max``, ``txt_pad``, ``img_pad``).
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)

    # max_tl is the padded text length, i.e. max(txt_lens)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    # OT inputs
    max_nbb = max(num_bbs)
    ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
    txt_pad = _compute_pad(txt_lens, max_tl)
    img_pad = _compute_pad(num_bbs, max_nbb)
    ot_inputs = {'ot_scatter': ot_scatter,
                 'scatter_max': ot_scatter.max().item(),
                 'txt_pad': txt_pad,
                 'img_pad': img_pad}

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets,
             'ot_inputs': ot_inputs}
    return batch
class ItmRankDataset(DetectFeatTxtTokDataset):
    """ITM ranking dataset: one positive pair plus random negative pairs.

    ``__getitem__`` returns ``1 + 2 * neg_sample_size`` inputs, in this order:
    the ground-truth (text, image) pair, then the text with
    ``neg_sample_size`` other images, then the image with
    ``neg_sample_size`` other texts.
    """

    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDataset need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        # inverse mapping image -> texts, restricted to the texts in self.ids
        self.img2txts = defaultdict(list)
        for id_, img in self.txt2img.items():
            self.img2txts[img].append(id_)
        self.img_name_list = list(self.img2txts.keys())

        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample negatives: other images for this text, and other texts
        # (excluding every text of the gt image) for this image
        neg_sample_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        neg_sample_txt_ids = sample_negative(
            self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
        id_pairs.extend([(gt_txt_id, neg_img_id)
                         for neg_img_id in neg_sample_img_ids] +
                        [(neg_txt_id, gt_img_fname)
                         for neg_txt_id in neg_sample_txt_ids])
        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1 + 2 * self.neg_sample_size)
        return inputs

    def _collect_inputs(self, id_pairs):
        """Build model inputs for each ``(txt_id, img_id)`` pair.

        Returns:
            list of ``(input_ids, img_feat, img_pos_feat, attn_masks_text,
            attn_masks_img)`` tuples in the order of ``id_pairs``.
        """
        inputs = []
        for txt_id, img_id in id_pairs:
            example = self.txt_db[txt_id]
            # text input
            input_ids = example['input_ids']
            input_ids = self.txt_db.combine_inputs(input_ids)
            # img input
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
            # separate all-ones masks for the text and image streams
            attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
            attn_masks_img = torch.ones(num_bb, dtype=torch.long)
            inputs.append((input_ids, img_feat, img_pos_feat,
                           attn_masks_text, attn_masks_img))
        return inputs
class ItmRankDatasetHardNeg(ItmRankDataset):
    """``ItmRankDataset`` that additionally uses pre-mined hard negatives.

    Each item holds ``1 + 2 * hard_neg_size + 2 * neg_sample_size`` inputs:
    the positive pair, hard-negative images / texts, then random negative
    images / texts.  ``reload_hard_negs()`` must be called before items are
    fetched.
    """

    def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
        assert hard_neg_size > 0, \
            "ItmRankDatasetHardNeg need at least 1 hard negative sample"
        # ItmRankDataset.__init__ is skipped: img2txts comes from txt_db here
        DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size
        self.hard_neg_size = hard_neg_size

    def reload_hard_negs(self, hard_neg_dir,
                         txt2hardimgs_file='txt2hardimgs.json',
                         img2hardtxts_file='img2hardtxts.json'):
        """(Re)load the hard-negative lookup tables from JSON files.

        Args:
            hard_neg_dir: directory containing the JSON files.
            txt2hardimgs_file: JSON mapping text id -> list of hard-negative
                image ids.
            img2hardtxts_file: JSON mapping image id -> list of
                hard-negative text ids.

        NOTE(review): the original file names passed to ``json.load`` were
        lost in this copy; the defaults are inferred from the attribute
        names -- confirm against the job that writes these files.
        """
        with open(f'{hard_neg_dir}/{txt2hardimgs_file}') as f:
            self.txt2hardimgs = json.load(f)
        with open(f'{hard_neg_dir}/{img2hardtxts_file}') as f:
            self.img2hardtxts = json.load(f)

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample hard negatives (requires reload_hard_negs() to have run)
        if self.hard_neg_size > 0:
            hard_neg_img_samples = random.sample(
                self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
            hard_neg_txt_samples = random.sample(
                self.img2hardtxts[gt_img_fname], self.hard_neg_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in hard_neg_img_samples] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in hard_neg_txt_samples])
        # sample normal negatives
        if self.neg_sample_size > 0:
            neg_sample_img_ids = sample_negative(
                self.img_name_list, [gt_img_fname], self.neg_sample_size)
            neg_sample_txt_ids = sample_negative(
                self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in neg_sample_img_ids] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in neg_sample_txt_ids])

        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1
                               + 2 * self.neg_sample_size
                               + 2 * self.hard_neg_size)
        return inputs
def itm_rank_collate(inputs):
    """Flatten and collate ranking samples into one padded batch.

    Args:
        inputs: list of samples, each the list returned by
            ``ItmRankDataset.__getitem__`` (all the same length), whose items
            are ``(input_ids, img_feat, img_pos_feat, attn_masks_text,
            attn_masks_img)`` tuples.

    Returns:
        dict of padded tensors.  ``sample_size`` is the number of pairs per
        sample, so rows ``[k * sample_size, (k + 1) * sample_size)`` belong
        to sample ``k``.
    """
    (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
     ) = map(list, unzip(concat(inputs)))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True,
                                   padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)

    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    # text and image masks are kept separate, so no joint gather index
    gather_index = None

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks_text': attn_masks_text,
             'attn_masks_img': attn_masks_img,
             'gather_index': gather_index,
             'sample_size': sample_size}
    return batch
class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
    """Rank one text against its ground-truth image and random negatives.

    Each item is already a complete mini-batch: the single text is shared
    and image row 0 is always the ground-truth image.
    """

    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromText need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        input_ids = self.txt_db[gt_txt_id]['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        neg_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        img_ids = [gt_img_fname] + neg_img_ids
        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        # joint mask: the text plus each image's valid boxes
        tl = input_ids.size(1)
        attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
        for row, nbb in enumerate(num_bbs):
            attn_masks[row, :tl + nbb] = 1
        out_size = attn_masks.size(1)
        gather_index = get_gather_index([tl] * len(img_ids), num_bbs,
                                        len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch
class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
    """Rank one image against its ground-truth text and random negatives.

    Each item is already a complete mini-batch: the single image is shared
    and text row 0 is always the ground-truth text.
    """

    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromImage need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.txt_name_list = list(self.txt2img.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        gt_txt_ids = self.img2txts[gt_img_id]

        # process image features (shared by every text row)
        img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        # sample negative texts, excluding every text of the gt image
        neg_txt_ids = sample_negative(
            self.txt_name_list, gt_txt_ids, self.neg_sample_size)
        txt_ids = [gt_txt_id] + neg_txt_ids

        # process text inputs
        all_inputs = []
        txt_lens = []
        for txt_id in txt_ids:
            input_ids = self.txt_db.combine_inputs(
                self.txt_db[txt_id]['input_ids'])
            all_inputs.append(input_ids)
            txt_lens.append(len(input_ids))
        input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        max_tl = input_ids.size(1)
        attn_masks = torch.zeros(len(txt_ids), max_tl + nbb).long()
        for row, tl in enumerate(txt_lens):
            attn_masks[row, :tl + nbb] = 1
        out_size = attn_masks.size(1)
        # Pass the padded text length, as the other get_gather_index callers
        # in this module do; the previous code passed the leftover loop
        # variable (the last text's length), which is wrong whenever that
        # text is not the longest.
        gather_index = get_gather_index(txt_lens, [nbb] * len(txt_ids),
                                        len(txt_ids), max_tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch
def itm_rank_hnv2_collate(inputs):
    """Unwrap a DataLoader batch holding exactly one pre-built mini-batch.

    The HardNegFrom{Text,Image} datasets already return whole mini-batches
    per item, so the loader batch size must be 1.
    """
    assert len(inputs) == 1
    only_batch = inputs[0]
    return only_batch
class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task

    Each item scores one text against its ground-truth image followed by
    ``mini_batch_size - 1`` fixed negative images.
    """

    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        # per-example lengths from the parent are not used here
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        """Return ``(gt_img_id, neg_img_ids)`` for text ``i``.

        The negatives are the ``bs - 1`` images that follow the gt image in
        ``all_img_ids``, wrapping around at the end, so they are fixed for
        each gt image.
        """
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        gt_idx = self.all_img_ids.index(gt_img_id)
        neg_st = gt_idx + 1
        neg_end = neg_st + self.bs - 1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        """Pair text ``i`` with every image in ``img_ids`` (in that order).

        Returns:
            dict with the text repeated once per image, padded image
            features and separate text / image attention masks.
        """
        example = super().__getitem__(i)

        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features, in the order of img_ids
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for row, nbb in enumerate(num_bbs):
            attn_masks_img[row, :nbb] = 1

        # text and image masks are kept separate, so no joint gather index
        gather_index = None

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch
def itm_val_collate(inputs):
    """Unwrap a DataLoader batch that must contain exactly one element.

    The val / eval datasets already build whole mini-batches per item, so
    the loader batch size has to be 1.
    """
    assert len(inputs) == 1, "input batch size > 1"
    single = inputs[0]
    return single
class ItmHardNegDataset(ItmValDataset):
    """Score every text against the same first ``bs`` images.

    Each batch also carries ``gt_txt_id`` and ``neg_img_ids`` so scores can
    be mapped back to ids (presumably for hard-negative mining).
    """

    def _get_batch_ids(self, i):
        """Return ``(gt_img_id, candidate_img_ids)`` for text ``i``.

        The candidates are the first ``self.bs`` entries of ``all_img_ids``;
        they are identical for every text and may include the gt image.
        """
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        # slicing already returns a new list; no deep copy needed
        neg_img_ids = self.all_img_ids[:self.bs]
        assert len(neg_img_ids) == (self.bs),\
            "Did not sample enough neg samples"
        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        _, neg_img_ids = self._get_batch_ids(i)
        batch = self.get_batch(i, neg_img_ids)
        batch['gt_txt_id'] = self.ids[i]
        batch['neg_img_ids'] = neg_img_ids
        return batch


# items are pre-built mini-batches, so the val collate function applies
itm_hn_collate = itm_val_collate
class ItmEvalDataset(ItmValDataset):
    """Score each text against *all* images, split into mini-batches."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Order images by number of boxes (presumably so each mini-batch has
        # similar sizes and less padding); sorted() already returns a new
        # list, so no deep copy is needed.
        self.all_img_ids = sorted(self.all_img_ids,
                                  key=lambda img_id: self.img_db.name2nbb[img_id])

    def __getitem__(self, i):
        """Return a list of mini-batches that together cover every image."""
        mini_batches = []
        for st in range(0, len(self.all_img_ids), self.bs):
            mini_batches.append(
                self.get_batch(i, self.all_img_ids[st:st + self.bs]))
        return mini_batches


# items are lists of pre-built mini-batches; the val collate function applies
itm_eval_collate = itm_val_collate