""" Itm dataset """
from collections import defaultdict
import copy
import json
import random
import sys

import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from toolz.sandbox import unzip
from cytoolz import concat

from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb,
                                    DetectFeatLmdb, pad_tensors,
                                    get_gather_index, get_ids_and_lens)
from uniter_model.data.sampler import TokenBucketSampler


class TokenBucketSamplerForItm(TokenBucketSampler):
    """ TokenBucketSampler that asks its dataset to re-sample every epoch.

    ``dset`` must expose ``lens`` and ``new_epoch()`` (e.g. ``ItmDataset``).
    """
    def __init__(self, dset, *args, **kwargs):
        super().__init__(dset.lens, *args, **kwargs)
        self.dset = dset

    def __iter__(self):
        it = super().__iter__()
        # re-draw labels / negative images and refresh the lengths used by
        # the sampler. NOTE(review): this runs *after* super().__iter__();
        # if TokenBucketSampler.__iter__ builds its batches eagerly, the
        # current epoch is bucketed with the previous epoch's lengths —
        # confirm this ordering is intended.
        self.dset.new_epoch()
        self._lens = self.dset.lens
        return it


def _has_overlap(la, lb):
    """ Return True if any element of ``lb`` also appears in ``la``. """
    # hash the longer sequence and scan the shorter one
    if len(la) < len(lb):
        la, lb = lb, la
    s = set(la)
    return any(b in s for b in lb)


def _sample_negative_rand(sample_pool, ground_truths, num_sample):
    """ random and retry

    Draw ``num_sample`` distinct items from ``sample_pool`` and retry until
    none of them is in ``ground_truths``. Always samples at least once, so an
    empty ``ground_truths`` still yields ``num_sample`` items.
    NOTE: loops forever if the pool cannot supply enough non-ground-truth
    items.
    """
    while True:
        outputs = random.sample(sample_pool, num_sample)
        if not _has_overlap(outputs, ground_truths):
            return outputs


def _sample_negative_extra(sample_pool, ground_truths, num_sample):
    """ sample extra then remove

    Draw ``len(ground_truths) + num_sample`` items, drop any ground truths,
    and keep at most ``num_sample`` of the rest. The result order follows set
    iteration order.
    """
    tot_size = len(ground_truths) + num_sample
    outputs = set(random.sample(sample_pool, tot_size))
    for gt in ground_truths:
        outputs.discard(gt)
    outputs = list(outputs)[:num_sample]
    return outputs


sample_negative = _sample_negative_rand  # switch between 2 implementations


def _get_rank():
    """ Best-effort global rank of the current process (0 if not distributed).

    Uses Horovod only if the training script already imported it (this module
    does not depend on Horovod), then an initialized ``torch.distributed``
    process group, otherwise 0.
    """
    hvd = sys.modules.get('horovod.torch')
    if hvd is not None:
        try:
            return hvd.rank()
        except ValueError:
            # horovod imported but hvd.init() not called yet
            pass
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0


class ItmDataset(DetectFeatTxtTokDataset):
    """ Image-text matching dataset with per-epoch random negatives.

    On every ``new_epoch()`` each text gets label 1 (paired with its own
    image) or, with probability ``neg_sample_p``, label 0 (paired with a
    different, randomly sampled image).

    NOTE(review): a previous docstring said this Dataset "handles distributed
    training itself (for more efficient negative sampling)"; no sharding is
    visible in this class — confirm before relying on it.
    """
    def __init__(self, txt_db, img_db, neg_sample_p=0.0):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)
        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        # sorted so the negative-sampling pool has a stable order across runs
        # (set iteration order of strings depends on hash randomization)
        self.all_imgs = sorted(set(txt_db[id_]['img_fname']
                                   for id_ in self.ids))

        self.neg_sample_p = neg_sample_p
        self.new_epoch()

    def new_epoch(self):
        """ should be called every epoch for more randomness

        Re-draws ``self.labels`` (0 = negative with prob. ``neg_sample_p``,
        1 = positive) and ``self.train_imgs``, and recomputes ``self.lens``
        as text length + number of boxes of the chosen image.
        """
        self.labels = np.random.choice(
            [0, 1], size=len(self.ids),
            p=[self.neg_sample_p, 1-self.neg_sample_p])

        self.lens = []
        self.train_imgs = []
        for i, (_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            if self.labels[i] == 0:
                img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
            self.train_imgs.append(img_fname)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        """ Return ``(input_ids, attn_masks, img_input_ids, img_feat,
        img_pos_feat, attn_masks_img, target)`` for ``itm_collate``.
        """
        example = super().__getitem__(i)
        # labels and negative images are drawn once per epoch in new_epoch()
        ground_truth_label = self.labels[i]
        img_fname = self.train_imgs[i]

        # the image stream is prefixed by one token id 101 (presumably the
        # BERT [CLS] id — confirm), hence num_bb + 1 mask entries
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)

        # text input
        input_ids = self.txt_db.combine_inputs(example['input_ids'])
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        target = torch.tensor([int(ground_truth_label)], dtype=torch.long)

        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat,
                attn_masks_img, target)


def itm_collate(inputs):
    """ Collate ``ItmDataset`` examples into separate text / image streams.

    Returns a dict with 'txts' and 'imgs' sub-batches (each with input_ids,
    position_ids, attention_mask, ...), the context index lists and the
    concatenated ``targets``.
    """
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats,
     attn_masks_img, targets) = map(list, unzip(inputs))

    # text stream
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    # image stream
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_input_ids = pad_sequence(img_input_ids, batch_first=True,
                                 padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1),
                                    dtype=torch.long).unsqueeze(0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)

    targets = torch.cat(targets, dim=0)
    bs = input_ids.size(0)
    out_size = attn_masks_img.size(1)
    # the image stream's token part is the single prefix token, so each
    # example has token length 1 and the padded token length is 1
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'pos_ctx_indices': list(range(bs)),
        # NOTE: len(num_bbs) == bs here, so this list is always empty
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'targets': targets
    }
    return batch


def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
    """ Build a (len(txt_lens), joint_len) index map for OT inputs.

    Row ``i`` maps positions ``[0, tl)`` to themselves and positions
    ``[tl, joint_len)`` to ``max_txt_len, max_txt_len + 1, ...``, i.e. the
    image part of each row is moved to start right after the padded text.
    """
    ot_scatter = torch.arange(0, joint_len, dtype=torch.long
                              ).unsqueeze(0).repeat(len(txt_lens), 1)
    for i, tl in enumerate(txt_lens):
        max_ind = max_txt_len + (joint_len-tl)
        ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
                                               dtype=torch.long).data
    return ot_scatter


def _compute_pad(lens, max_len):
    """ Bool mask of shape (len(lens), max_len); True marks padding. """
    pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
    for i, l in enumerate(lens):
        pad.data[i, l:].fill_(1)
    return pad


def itm_ot_collate(inputs):
    """ Collate joint text+image ITM examples and add optimal-transport inputs.

    Expects 5-tuples ``(input_ids, img_feat, img_pos_feat, attn_masks,
    target)``. NOTE(review): ``ItmDataset.__getitem__`` in this file returns
    7-tuples, so this collate presumably pairs with another dataset — confirm.
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    # max_tl is the padded text length (== max(txt_lens))
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    # OT inputs
    max_nbb = max(num_bbs)
    ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
    txt_pad = _compute_pad(txt_lens, max_tl)
    img_pad = _compute_pad(num_bbs, max_nbb)
    ot_inputs = {'ot_scatter': ot_scatter,
                 'scatter_max': ot_scatter.max().item(),
                 'txt_pad': txt_pad,
                 'img_pad': img_pad}

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets,
             'ot_inputs': ot_inputs}
    return batch


class ItmRankDataset(DetectFeatTxtTokDataset):
    """ Ranking dataset: each item is the gt pair plus random negatives.

    ``__getitem__`` returns ``1 + 2*neg_sample_size`` input tuples: the
    ground-truth (text, image) pair, then ``neg_sample_size`` (gt text,
    negative image) pairs, then ``neg_sample_size`` (negative text, gt image)
    pairs.
    """
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDataset need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        # inverse mapping: image -> all text ids (of this dataset) using it
        self.img2txts = defaultdict(list)
        for id_, img in self.txt2img.items():
            self.img2txts[img].append(id_)
        self.img_name_list = list(self.img2txts.keys())

        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample negatives (negative texts exclude every caption of gt image)
        neg_sample_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        neg_sample_txt_ids = sample_negative(
            self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
        id_pairs.extend([(gt_txt_id, neg_img_id)
                         for neg_img_id in neg_sample_img_ids] +
                        [(neg_txt_id, gt_img_fname)
                         for neg_txt_id in neg_sample_txt_ids])
        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1 + 2*self.neg_sample_size)
        return inputs

    def _collect_inputs(self, id_pairs):
        """ Build ``(input_ids, img_feat, img_pos_feat, attn_masks_text,
        attn_masks_img)`` for each ``(txt_id, img_id)`` pair.
        """
        inputs = []
        for txt_id, img_id in id_pairs:
            example = self.txt_db[txt_id]
            # text input
            input_ids = self.txt_db.combine_inputs(example['input_ids'])
            # img input
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
            # mask
            attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
            attn_masks_img = torch.ones(num_bb, dtype=torch.long)

            inputs.append((input_ids, img_feat, img_pos_feat,
                           attn_masks_text, attn_masks_img))

        return inputs


class ItmRankDatasetHardNeg(ItmRankDataset):
    """ ItmRankDataset that also adds pre-mined hard negatives.

    ``reload_hard_negs`` must be called before indexing. Each item holds
    ``1 + 2*neg_sample_size + 2*hard_neg_size`` input tuples.
    """
    def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
        assert hard_neg_size > 0, \
            "ItmRankDatasetHardNeg need at least 1 hard negative sample"
        # deliberately skip ItmRankDataset.__init__ and use the db's
        # img2txts mapping instead of rebuilding it
        DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size
        self.hard_neg_size = hard_neg_size

    def reload_hard_negs(self, hard_neg_dir, rank=None):
        """ Load hard negatives from JSON files in ``hard_neg_dir``.

        Reads ``txt2hardimgs_rank{rank}.json`` (text id -> hard image ids)
        and ``img2hardtxts.json`` (image id -> hard text ids).

        Args:
            hard_neg_dir: directory containing the JSON files.
            rank: rank selecting the per-rank text file. Defaults to
                ``_get_rank()`` (already-imported Horovod, else an
                initialized ``torch.distributed``, else 0). Pass
                ``hvd.rank()`` explicitly to be unambiguous.
        """
        if rank is None:
            rank = _get_rank()
        with open(f'{hard_neg_dir}/'
                  f'txt2hardimgs_rank{rank}.json') as f:
            self.txt2hardimgs = json.load(f)
        with open(f'{hard_neg_dir}/img2hardtxts.json') as f:
            self.img2hardtxts = json.load(f)

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample hard negatives
        if self.hard_neg_size > 0:
            hard_neg_img_samples = random.sample(
                self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
            hard_neg_txt_samples = random.sample(
                self.img2hardtxts[gt_img_fname], self.hard_neg_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in hard_neg_img_samples] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in hard_neg_txt_samples])
        # sample normal negatives
        if self.neg_sample_size > 0:
            neg_sample_img_ids = sample_negative(
                self.img_name_list, [gt_img_fname], self.neg_sample_size)
            neg_sample_txt_ids = sample_negative(
                self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in neg_sample_img_ids] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in neg_sample_txt_ids])

        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1
                               + 2*self.neg_sample_size
                               + 2*self.hard_neg_size)
        return inputs


def itm_rank_collate(inputs):
    """ Flatten per-item lists of pairs from the ItmRank datasets and pad.

    Every item must contain the same number of pairs; that count is returned
    as ``sample_size``.
    """
    (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
     ) = map(list, unzip(concat(inputs)))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True,
                                   padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True,
                                  padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    # text and image masks are returned separately; no joint gather index
    gather_index = None

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks_text': attn_masks_text,
             'attn_masks_img': attn_masks_img,
             'gather_index': gather_index,
             'sample_size': sample_size}
    return batch


class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
    """ One text vs. its gt image plus ``neg_sample_size`` random images.

    Each item is already a full batch (use ``itm_rank_hnv2_collate``); the
    gt image is always first.
    """
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromText need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        input_ids = self.txt_db[gt_txt_id]['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        neg_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        img_ids = [gt_img_fname] + neg_img_ids
        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
        for j, nbb in enumerate(num_bbs):
            attn_masks.data[j, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
                                        len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
    """ One image vs. its gt text plus ``neg_sample_size`` random texts.

    Each item is already a full batch (use ``itm_rank_hnv2_collate``); the
    gt text is always first. Negative texts exclude every caption of the
    gt image.
    """
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromImage need at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.txt_name_list = list(self.txt2img.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        gt_txt_ids = self.img2txts[gt_img_id]

        # process image features (single image, shared by all texts)
        img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        # sample negative
        neg_txt_ids = sample_negative(
            self.txt_name_list, gt_txt_ids, self.neg_sample_size)
        txt_ids = [gt_txt_id] + neg_txt_ids

        # process text inputs
        all_inputs = []
        txt_lens = []
        for txt_id in txt_ids:
            input_ids = self.txt_db.combine_inputs(
                self.txt_db[txt_id]['input_ids'])
            all_inputs.append(input_ids)
            txt_lens.append(len(input_ids))
        input_ids = pad_sequence(all_inputs, batch_first=True,
                                 padding_value=0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # padded text length (== max(txt_lens))
        max_tl = input_ids.size(1)
        attn_masks = torch.zeros(len(txt_ids), max_tl + nbb).long()
        for j, tl in enumerate(txt_lens):
            attn_masks.data[j, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        # like itm_ot_collate, pass the padded text length; the loop variable
        # `tl` only holds the last caption's length
        gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
                                        len(txt_ids), max_tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


def itm_rank_hnv2_collate(inputs):
    """ Identity collate: the HardNegFrom* datasets already build batches. """
    assert len(inputs) == 1
    return inputs[0]


class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task

    Each item scores one text against its gt image plus the next
    ``mini_batch_size - 1`` images in ``all_img_ids`` (wrapping around).
    """
    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        """ Return ``(gt_img_id, neg_img_ids)`` with ``bs - 1`` negatives. """
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # fixed negatives for each gt image: the images that follow it
        gt_idx = self.all_img_ids.index(gt_img_id)
        neg_st = gt_idx+1
        neg_end = neg_st+self.bs-1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        """ Pair text ``i`` with every image in ``img_ids`` (in order). """
        example = super().__getitem__(i)

        input_ids = self.txt_db.combine_inputs(example['input_ids'])
        # the same text is repeated once per image
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for j, nbb in enumerate(num_bbs):
            attn_masks_img.data[j, :nbb].fill_(1)

        # text and image masks are returned separately; no joint gather index
        gather_index = None

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch


def itm_val_collate(inputs):
    """ Identity collate: ItmValDataset items are already full batches. """
    assert len(inputs) == 1, "input batch size > 1"
    return inputs[0]


class ItmHardNegDataset(ItmValDataset):
    """ Score each text against ``bs`` random non-gt images (for mining). """
    def _get_batch_ids(self, i):
        """ Return ``(gt_img_id, neg_img_ids)`` with ``bs`` random negatives. """
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # random negatives: every image except the gt one, shuffled
        all_img_ids = list(self.all_img_ids)
        all_img_ids.remove(gt_img_id)
        random.shuffle(all_img_ids)
        neg_img_ids = all_img_ids[:self.bs]

        assert len(neg_img_ids) == (self.bs),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        _, neg_img_ids = self._get_batch_ids(i)
        batch = self.get_batch(i, neg_img_ids)
        batch['gt_txt_id'] = self.ids[i]
        batch['neg_img_ids'] = neg_img_ids
        return batch


itm_hn_collate = itm_val_collate


class ItmEvalDataset(ItmValDataset):
    """ Full retrieval evaluation: each text against all images. """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # order images by box count so each mini-batch holds images with
        # similar box counts (presumably to reduce padding)
        self.all_img_ids = sorted(self.all_img_ids,
                                  key=lambda i: self.img_db.name2nbb[i])

    def __getitem__(self, i):
        """ Return a list of mini-batches covering all images. """
        mini_batches = []
        for st in range(0, len(self.all_img_ids), self.bs):
            mini_batches.append(
                self.get_batch(i, self.all_img_ids[st:st+self.bs]))
        return mini_batches


itm_eval_collate = itm_val_collate