"""
|
|
Itm dataset
|
|
"""
|
|
from collections import defaultdict
|
|
import copy
|
|
import json
|
|
import random
|
|
|
|
import torch
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
import numpy as np
|
|
from toolz.sandbox import unzip
|
|
from cytoolz import concat
|
|
|
|
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
|
|
pad_tensors, get_gather_index, get_ids_and_lens)
|
|
from uniter_model.data.sampler import TokenBucketSampler
|
|
|
|
|


class TokenBucketSamplerForItm(TokenBucketSampler):
    """ TokenBucketSampler that re-samples ItmDataset negatives every epoch """
    def __init__(self, dset, *args, **kwargs):
        super().__init__(dset.lens, *args, **kwargs)
        self.dset = dset

    def __iter__(self):
        it = super().__iter__()
        # refresh negatives and example lengths before the new epoch is iterated
        self.dset.new_epoch()
        self._lens = self.dset.lens
        return it


def _has_overlap(la, lb):
    if len(la) < len(lb):
        la, lb = lb, la
    s = set(la)
    return any(b in s for b in lb)


def _sample_negative_rand(sample_pool, ground_truths, num_sample):
    """ random and retry """
    outputs = ground_truths[:1]
    while _has_overlap(outputs, ground_truths):
        outputs = random.sample(sample_pool, num_sample)
    return outputs


def _sample_negative_extra(sample_pool, ground_truths, num_sample):
    """ sample extra then remove """
    tot_size = len(ground_truths) + num_sample
    outputs = set(random.sample(sample_pool, tot_size))
    for gt in ground_truths:
        outputs.discard(gt)
    outputs = list(outputs)[:num_sample]
    return outputs


sample_negative = _sample_negative_rand  # switch between the 2 implementations


class ItmDataset(DetectFeatTxtTokDataset):
    """ NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling) """
    def __init__(self, txt_db, img_db, neg_sample_p=0.0):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)

        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))

        self.neg_sample_p = neg_sample_p
        self.new_epoch()

    def new_epoch(self):
        """ should be called every epoch for more randomness """
        self.labels = np.random.choice(
            [0, 1], size=len(self.ids),
            p=[self.neg_sample_p, 1-self.neg_sample_p])

        self.lens = []
        self.train_imgs = []
        for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            if self.labels[i] == 0:
                img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
            self.train_imgs.append(img_fname)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # labels and negative images should be sampled every epoch
        ground_truth_label = self.labels[i]
        img_fname = self.train_imgs[i]
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        target = torch.Tensor(1).long()
        target.data.fill_(ground_truth_label)

        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat,
                attn_masks_img, target)


def itm_collate(inputs):
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats,
     attn_masks_img, targets) = map(list, unzip(inputs))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    img_input_ids = pad_sequence(img_input_ids, batch_first=True,
                                 padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'targets': targets
    }
    return batch
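
# Example (illustrative sketch only; the paths, bucket size and worker count are
# placeholder assumptions, not values taken from this repo): wiring ItmDataset,
# TokenBucketSamplerForItm and itm_collate together for training.
#
#   from torch.utils.data import DataLoader
#   txt_db = TxtTokLmdb('/path/to/txt_db', max_txt_len=60)
#   img_db = DetectFeatLmdb('/path/to/img_db')
#   dset = ItmDataset(txt_db, img_db, neg_sample_p=0.5)
#   sampler = TokenBucketSamplerForItm(dset, bucket_size=8192,
#                                      batch_size=8192, droplast=True)
#   loader = DataLoader(dset, batch_sampler=sampler, num_workers=4,
#                       collate_fn=itm_collate)
#   for batch in loader:
#       ...  # batch['txts'], batch['imgs'], batch['targets']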


def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
    # scatter indices that map the joint (text + image) sequence onto a fixed
    # layout: text tokens keep their positions 0..tl-1, the remaining positions
    # are shifted to start at max_txt_len
    ot_scatter = torch.arange(0, joint_len, dtype=torch.long
                              ).unsqueeze(0).repeat(len(txt_lens), 1)
    for i, tl in enumerate(txt_lens):
        max_ind = max_txt_len + (joint_len-tl)
        ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
                                               dtype=torch.long).data
    return ot_scatter


def _compute_pad(lens, max_len):
    # boolean padding mask: True marks positions beyond each sequence length
    pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
    for i, l in enumerate(lens):
        pad.data[i, l:].fill_(1)
    return pad
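
# Illustration of _compute_pad (worked example, not taken from the repo):
#   _compute_pad([2, 3], 4) ->
#       tensor([[False, False,  True,  True],
#               [False, False, False,  True]])
# i.e. positions at or beyond each sequence length are flagged as padding.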


def itm_ot_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    # OT inputs
    max_tl = max(txt_lens)
    max_nbb = max(num_bbs)
    ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
    txt_pad = _compute_pad(txt_lens, max_tl)
    img_pad = _compute_pad(num_bbs, max_nbb)
    ot_inputs = {'ot_scatter': ot_scatter,
                 'scatter_max': ot_scatter.max().item(),
                 'txt_pad': txt_pad,
                 'img_pad': img_pad}

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets,
             'ot_inputs': ot_inputs}
    return batch


class ItmRankDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDataset needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        # images partitioned by rank
        self.img2txts = defaultdict(list)
        for id_, img in self.txt2img.items():
            self.img2txts[img].append(id_)
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample negatives
        neg_sample_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        neg_sample_txt_ids = sample_negative(
            self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
        id_pairs.extend([(gt_txt_id, neg_img_id)
                         for neg_img_id in neg_sample_img_ids] +
                        [(neg_txt_id, gt_img_fname)
                         for neg_txt_id in neg_sample_txt_ids])
        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1 + 2*self.neg_sample_size)
        return inputs

    def _collect_inputs(self, id_pairs):
        # create input features
        inputs = []
        for txt_id, img_id in id_pairs:
            example = self.txt_db[txt_id]
            # text input
            input_ids = example['input_ids']
            input_ids = self.txt_db.combine_inputs(input_ids)
            # img input
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
            # mask
            attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
            attn_masks_img = torch.ones(num_bb, dtype=torch.long)

            inputs.append((input_ids, img_feat, img_pos_feat,
                           attn_masks_text, attn_masks_img))

        return inputs


class ItmRankDatasetHardNeg(ItmRankDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
        assert hard_neg_size > 0, \
            "ItmRankDatasetHardNeg needs at least 1 hard negative sample"
        DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size
        self.hard_neg_size = hard_neg_size

    def reload_hard_negs(self, hard_neg_dir):
        self.txt2hardimgs = json.load(
            open(f'{hard_neg_dir}/'
                 f'txt2hardimgs_rank{hvd.rank()}.json'))
        self.img2hardtxts = json.load(
            open(f'{hard_neg_dir}/img2hardtxts.json'))

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample hard negatives
        if self.hard_neg_size > 0:
            hard_neg_img_samples = random.sample(
                self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
            hard_neg_txt_samples = random.sample(
                self.img2hardtxts[gt_img_fname], self.hard_neg_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in hard_neg_img_samples] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in hard_neg_txt_samples])
        # sample normal negatives
        if self.neg_sample_size > 0:
            neg_sample_img_ids = sample_negative(
                self.img_name_list, [gt_img_fname], self.neg_sample_size)
            neg_sample_txt_ids = sample_negative(
                self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in neg_sample_img_ids] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in neg_sample_txt_ids])

        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1
                               + 2*self.neg_sample_size
                               + 2*self.hard_neg_size)
        return inputs
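
# Note on reload_hard_negs (inferred from its usage above; the exact file format
# is an assumption, not documented in this file): it expects pre-mined hard
# negatives stored as JSON mappings of the form
#   txt2hardimgs_rank<k>.json : {txt_id: [hard_neg_img_fname, ...], ...}
#   img2hardtxts.json         : {img_fname: [hard_neg_txt_id, ...], ...}
# and each list must hold at least `hard_neg_size` entries for random.sample().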


def itm_rank_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
     ) = map(list, unzip(concat(i for i in inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True,
                                   padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    bs, max_tl = input_ids.size()
    # out_size = attn_masks.size(1)
    gather_index = None  # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks_text': attn_masks_text,
             'attn_masks_img': attn_masks_img,
             'gather_index': gather_index,
             'sample_size': sample_size}
    return batch
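
# Shape note for itm_rank_collate (derived from the code above): each dataset item
# is already a list of positive/negative pairs (1 + 2*neg_sample_size for
# ItmRankDataset), and concat(...) flattens them, so every tensor in the returned
# batch has a leading dimension of len(inputs) * sample_size.  'sample_size'
# records how many rows belong to each ground-truth group.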


class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromText needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        input_ids = self.txt_db[gt_txt_id]['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        neg_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        img_ids = [gt_img_fname] + neg_img_ids
        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
        for i, nbb in enumerate(num_bbs):
            attn_masks.data[i, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
                                        len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromImage needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.txt_name_list = list(self.txt2img.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        gt_txt_ids = self.img2txts[gt_img_id]

        # process image features (gt always first)
        img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        # sample negative
        neg_txt_ids = sample_negative(
            self.txt_name_list, gt_txt_ids, self.neg_sample_size)
        txt_ids = [gt_txt_id] + neg_txt_ids

        # process text inputs
        all_inputs = []
        txt_lens = []
        for txt_id in txt_ids:
            input_ids = self.txt_db.combine_inputs(
                self.txt_db[txt_id]['input_ids'])
            all_inputs.append(input_ids)
            txt_lens.append(len(input_ids))
        input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
        for i, tl in enumerate(txt_lens):
            attn_masks.data[i, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
                                        len(txt_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


def itm_rank_hnv2_collate(inputs):
    assert len(inputs) == 1
    return inputs[0]
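
# Usage note (a sketch under assumed arguments, not code from this repo):
# ItmRankDatasetHardNegFromText / ItmRankDatasetHardNegFromImage build a complete
# mini-batch inside __getitem__ (one ground truth plus neg_sample_size negatives),
# so they are meant to be used with batch_size=1 and itm_rank_hnv2_collate, which
# simply unwraps that single element, e.g.
#
#   from torch.utils.data import DataLoader
#   dset = ItmRankDatasetHardNegFromText(txt_db, img_db, neg_sample_size=31)
#   loader = DataLoader(dset, batch_size=1, shuffle=True,
#                       collate_fn=itm_rank_hnv2_collate)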


class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating the Image-Text Retrieval task """
    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        neg_st = i+1
        neg_end = neg_st+self.bs-1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns a list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        example = super().__getitem__(i)

        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        # attn_masks_text = torch.ones(1, tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for i, nbb in enumerate(num_bbs):
            attn_masks_img.data[i, :nbb].fill_(1)

        # out_size = attn_masks.size(1)
        gather_index = None  # get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch


def itm_val_collate(inputs):
    assert len(inputs) == 1, "input batch size > 1"
    return inputs[0]


class ItmHardNegDataset(ItmValDataset):
    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        all_img_ids = copy.deepcopy(self.all_img_ids)
        all_img_ids.remove(gt_img_id)
        random.shuffle(all_img_ids)
        neg_img_ids = all_img_ids[:self.bs]

        assert len(neg_img_ids) == (self.bs),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns a list of mini-batches """
        _, neg_img_ids = self._get_batch_ids(i)
        batch = self.get_batch(i, neg_img_ids)
        batch['gt_txt_id'] = self.ids[i]
        batch['neg_img_ids'] = neg_img_ids
        return batch


itm_hn_collate = itm_val_collate


class ItmEvalDataset(ItmValDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
                                  key=lambda i: self.img_db.name2nbb[i])

    def __getitem__(self, i):
        mini_batches = []
        for st in range(0, len(self.all_img_ids), self.bs):
            mini_batches.append(
                self.get_batch(i, self.all_img_ids[st:st+self.bs]))
        return mini_batches


itm_eval_collate = itm_val_collate
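
# Evaluation sketch (illustrative only; the model scoring call is a hypothetical
# placeholder, not defined in this file). ItmEvalDataset yields, per text query,
# a list of mini-batches that together cover every image in the split (sorted by
# number of bounding boxes), and itm_eval_collate unwraps the single-sample batch.
#
#   from torch.utils.data import DataLoader
#   eval_dset = ItmEvalDataset(txt_db, img_db, mini_batch_size=400)
#   eval_loader = DataLoader(eval_dset, batch_size=1, shuffle=False,
#                            collate_fn=itm_eval_collate)
#   for mini_batches in eval_loader:
#       scores = [model(**mb) for mb in mini_batches]  # hypothetical model call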