lightningdot
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
366 lines
16 KiB
366 lines
16 KiB
2 years ago
|
import torch
|
||
|
import numpy as np
|
||
|
import itertools
|
||
|
|
||
|
from torch.nn.utils.rnn import pad_sequence
|
||
|
from uniter_model.data.itm import DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, get_ids_and_lens
|
||
|
from uniter_model.data.data import get_gather_index
|
||
|
from toolz.sandbox import unzip
|
||
|
from cytoolz import concat
|
||
|
from GLOBAL_VARIABLES import N_EXAMPLES_TEACHER
|
||
|
|
||
|
|
||
|
def pad_tensors(tensors, lens=None, pad=0):
    """Stack variable-length tensors into one padded batch: B x [T, ...].

    Args:
        tensors: non-empty sequence of tensors that share a dtype and agree on
            every dimension except the first one.
        lens: number of leading rows written for each tensor. Defaults to
            each tensor's size along dim 0.
        pad: value stored in the positions that no tensor covers.

    Returns:
        A new tensor of shape ``(len(tensors), max(lens), *tensors[0].shape[1:])``
        with the dtype of ``tensors[0]``.

    Raises:
        ValueError: if ``tensors`` (and therefore ``lens``) is empty.
    """
    if lens is None:
        lens = [t.size(0) for t in tensors]
    max_len = max(lens)
    first = tensors[0]
    # Keep every trailing dimension instead of assuming exactly one hidden
    # dimension, so the function really handles "B x [T, ...]".
    out_shape = (len(tensors), max_len) + tuple(first.shape[1:])
    output = torch.zeros(out_shape, dtype=first.dtype)
    if pad:
        output.fill_(pad)
    for row, (t, length) in enumerate(zip(tensors, lens)):
        # detach() keeps the copy out of any autograd graph.
        output[row, :length, ...] = t.detach()
    return output
|
||
|
|
||
|
|
||
|
# for ITM task
|
||
|
class ItmFastDataset(DetectFeatTxtTokDataset):
    """Image-text matching dataset with optional hard negatives.

    NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling)

    ``new_epoch`` must be called before the dataset is indexed: it fills the
    per-example image / negative lists that ``__getitem__`` reads.
    """

    def __init__(self, txt_db, img_db, num_hard_negatives=0, img_meta=None, tokenizer=None):
        """Store the databases and build the text-id lookup tables.

        Args:
            txt_db: a ``TxtTokLmdb`` holding the tokenized texts.
            img_db: a ``DetectFeatLmdb`` holding the image features.
            num_hard_negatives: how many hard negatives to keep per example.
            img_meta: optional mapping ``img_fname -> {'caption_multiple': [...]}``;
                when given, captions are tokenized and returned as well.
            tokenizer: tokenizer used for the captions (needed with ``img_meta``).
        """
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)

        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.ids_2_idx = {idx: i for i, idx in enumerate(self.ids)}
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))

        self.num_hard_negatives = num_hard_negatives
        self.img_meta = img_meta
        self.tokenizer = tokenizer
        # Filled by new_epoch(); None means "not prepared yet".
        self.train_imgs = None
        self.neg_imgs = None
        self.train_txts = None
        self.neg_txts = None

    def new_epoch(self, hard_negatives_img=None, hard_negatives_txt=None):
        """ should be called every epoch for more randomness

        Args:
            hard_negatives_img: optional mapping ``txt_id -> [img_fname, ...]``.
            hard_negatives_txt: optional mapping ``img_fname -> [txt_id, ...]``;
                must be supplied whenever ``hard_negatives_img`` is used.
        """
        use_hard_negs = hard_negatives_img is not None and self.num_hard_negatives > 0
        self.lens = []
        self.train_imgs, self.neg_imgs = [], []
        self.train_txts, self.neg_txts = [], []
        for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            self.train_imgs.append(img_fname)
            self.train_txts.append(id_)
            if use_hard_negs:
                self.neg_imgs.append(hard_negatives_img[id_][:self.num_hard_negatives])
                self.neg_txts.append(hard_negatives_txt[img_fname][:self.num_hard_negatives])
            else:
                self.neg_imgs.append(None)
                self.neg_txts.append(None)
            # Combined length: text tokens plus the image's box count.
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def _encode_captions(self, img_fname, like):
        """Tokenize all captions of ``img_fname`` into one id sequence.

        The sequence is ``[cls] cap_1 [sep] cap_2 [sep] ...``; the returned
        ids use the dtype and device of ``like``. Returns ``(ids, attention_mask)``.
        """
        sep = self.tokenizer.sep_token_id
        pieces = [self.tokenizer.encode(caption, add_special_tokens=False) + [sep]
                  for caption in self.img_meta[img_fname]['caption_multiple']]
        flat = [self.tokenizer.cls_token_id] + list(itertools.chain.from_iterable(pieces))
        ids = torch.tensor(flat, dtype=like.dtype, device=like.device)
        return ids, torch.ones(len(ids), dtype=torch.long)

    def __getitem__(self, i):
        """Return one training example as a 12-tuple.

        The tuple is ``(input_ids, img_feat, img_pos_feat, img_input_ids,
        attn_masks, attn_masks_img, txt_id, img_fname, neg_imgs, neg_txts,
        caption_ids, attn_masks_captions)``. ``neg_imgs`` / ``neg_txts`` are
        dicts of lists, or ``None`` when no hard negatives were set for this
        example; the caption entries are ``None`` when ``img_meta`` is not set.

        Raises:
            RuntimeError: if ``new_epoch`` has not been called yet.
        """
        if self.train_imgs is None or self.neg_imgs is None or self.neg_txts is None:
            raise RuntimeError('ItmFastDataset.new_epoch() must be called before indexing')

        example = super().__getitem__(i)
        # labels and negative images should be sampled every epoch
        img_fname, hard_neg_imgs = self.train_imgs[i], self.neg_imgs[i]
        hard_neg_txts = self.neg_txts[i]

        # Single hard-coded token id for the image stream
        # (presumably the BERT [CLS] id - TODO confirm).
        img_input_ids = torch.tensor([101], dtype=torch.long)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        # One mask slot for the id above plus one per box.
        attn_masks_img = torch.ones(num_bb + 1, dtype=torch.long)

        # text input
        input_ids = self.txt_db.combine_inputs(example['input_ids'])
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        if hard_neg_imgs is not None:
            neg_imgs = {'img_input_ids': [], 'img_feat': [], 'img_pos_feat': [],
                        'num_bb': [], 'attn_masks_img': [],
                        'caption_ids': [], 'attn_masks_captions': []}
            for neg_id in hard_neg_imgs:
                neg_feat, neg_pos_feat, neg_num_bb = self._get_img_feat(neg_id)
                neg_imgs['img_input_ids'].append(torch.tensor([101], dtype=torch.long))
                neg_imgs['img_feat'].append(neg_feat)
                neg_imgs['img_pos_feat'].append(neg_pos_feat)
                neg_imgs['num_bb'].append(neg_num_bb)
                neg_imgs['attn_masks_img'].append(torch.ones(neg_num_bb + 1, dtype=torch.long))
                if self.img_meta is not None:
                    neg_caption_ids, neg_caption_mask = self._encode_captions(neg_id, input_ids)
                    neg_imgs['caption_ids'].append(neg_caption_ids)
                    neg_imgs['attn_masks_captions'].append(neg_caption_mask)

            # 'position_ids' is kept (empty) so the dict layout is unchanged.
            neg_txts = {'input_ids': [], 'position_ids': [], 'attention_mask': []}
            for neg_id in hard_neg_txts:
                neg_example = super().__getitem__(self.ids_2_idx[neg_id])
                neg_input_ids = self.txt_db.combine_inputs(neg_example['input_ids'])
                neg_txts['input_ids'].append(neg_input_ids)
                neg_txts['attention_mask'].append(torch.ones(len(neg_input_ids), dtype=torch.long))
        else:
            neg_imgs = None
            neg_txts = None

        if self.img_meta is not None:
            caption_ids, attn_masks_captions = self._encode_captions(img_fname, input_ids)
        else:
            caption_ids = None
            attn_masks_captions = None

        return (input_ids, img_feat, img_pos_feat, img_input_ids, attn_masks,
                attn_masks_img, self.ids[i], img_fname, neg_imgs, neg_txts,
                caption_ids, attn_masks_captions)
|
||
|
|
||
|
|
||
|
def itm_fast_collate_kd(inputs):
    """Collate ITM examples into a batch that also carries teacher inputs.

    Args:
        inputs: list of per-example tuples with 11 fields, in this order:
            ``(input_ids, img_feat, img_pos_feat, img_input_ids, attn_masks_text,
            attn_masks_img, txt_id, img_fname, negs, caption_ids,
            attn_masks_captions)``.
            NOTE(review): ``ItmFastDataset.__getitem__`` in this file returns
            12 fields (separate image / text negatives); confirm which dataset
            feeds this function.

    Returns:
        dict of padded tensors and bookkeeping lists. Image tensors hold the
        positives first, followed by the flattened negatives (if every example
        has them). The ``*_teacher`` entries pair every text with the first
        ``N_EXAMPLES_TEACHER`` images of the batch: row ``t * N + j`` combines
        text ``t`` with image ``j``.
    """
    (input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text,
     attn_masks_img, idx, img_fname, negs, caption_ids,
     attn_masks_captions) = map(list, unzip(inputs))

    has_captions = caption_ids[0] is not None

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    captions_ids = (pad_sequence(caption_ids, batch_first=True, padding_value=0)
                    if has_captions else None)

    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
    position_ids_captions = (
        torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0)
        if has_captions else None)

    if None not in negs:
        # Flatten the per-example negative lists into one list per field.
        num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in negs]))
        img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in negs]))
        img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in negs]))
        img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in negs]))
        attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in negs]))
    else:
        num_bbs_neg = []
        img_feats_neg = []
        img_input_ids_neg = []
        img_pos_feat_neg = []
        attn_masks_img_neg = []

    num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
    img_feat = pad_tensors(img_feats + img_feats_neg, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats + img_pos_feat_neg, num_bbs)

    img_input_ids = pad_sequence(img_input_ids + img_input_ids_neg, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
    attn_masks_captions = (
        pad_sequence(attn_masks_captions, batch_first=True, padding_value=0)
        if attn_masks_captions[0] is not None else None)
    attn_masks_img = pad_sequence(attn_masks_img + attn_masks_img_neg, batch_first=True, padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    bs = input_ids.size(0)
    out_size = attn_masks_img.size(1)
    gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, out_size)

    # Teacher inputs: the first N_EXAMPLES_TEACHER images tiled once per text.
    # NOTE(review): this needs at least N_EXAMPLES_TEACHER images in the batch,
    # otherwise the concatenation below fails on mismatched row counts.
    img_feat_teacher = img_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
    img_pos_feat_teacher = img_pos_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
    # Drop the first mask column (the slot of the single image token id).
    attn_masks_img_teacher = attn_masks_img[:N_EXAMPLES_TEACHER].repeat(bs, 1)[:, 1:]

    # Each text repeated N_EXAMPLES_TEACHER times in a row. The repeat count
    # must equal the factor used in view(); it was previously hard-coded to 10.
    input_ids_teacher = (input_ids.unsqueeze(1)
                         .repeat(1, N_EXAMPLES_TEACHER, 1)
                         .view(bs * N_EXAMPLES_TEACHER, -1))
    position_ids_teacher = position_ids
    attn_masks_text_teacher = (attn_masks_text.unsqueeze(1)
                               .repeat(1, N_EXAMPLES_TEACHER, 1)
                               .view(bs * N_EXAMPLES_TEACHER, -1))

    attn_masks_teacher = torch.cat([attn_masks_text_teacher, attn_masks_img_teacher], dim=1)

    batch = {
        'txt_ids': input_ids,
        'img_ids': img_feat,
        'caption_ids': captions_ids,
        'txt_pos_ids': position_ids,
        'img_pos_ids': img_pos_feat,
        'caption_pos_ids': position_ids_captions,
        'txt_attn_masks': attn_masks_text,
        'img_attn_masks': attn_masks_img,
        'caption_attn_masks': attn_masks_captions,
        'img_txt_ids': img_input_ids,
        'img_txt_pos_ids': img_position_ids,
        'gather_index': gather_index,
        'sample_size': sample_size,
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'txt_index': idx,
        'img_fname': img_fname,

        'img_feat_teacher': img_feat_teacher,
        'img_pos_feat_teacher': img_pos_feat_teacher,
        'input_ids_teacher': input_ids_teacher,
        'position_ids_teacher': position_ids_teacher,
        'attn_masks_teacher': attn_masks_teacher
    }
    return batch
|
||
|
|
||
|
|
||
|
def itm_fast_collate(inputs):
    """Collate 12-field ITM examples into nested text / image / caption inputs.

    Args:
        inputs: list of tuples ``(input_ids, img_feat, img_pos_feat,
            img_input_ids, attn_masks_text, attn_masks_img, txt_id, img_fname,
            neg_imgs, neg_txts, caption_ids, attn_masks_captions)``.

    Returns:
        dict with three sub-dicts (``'txts'``, ``'imgs'``, ``'caps'``) that share
        one key layout, plus ``sample_size``, ``pos_ctx_indices``,
        ``neg_ctx_indices``, ``txt_index`` and ``img_fname``. In every padded
        tensor the first ``bs`` rows are the positives and the remaining rows
        are the flattened hard negatives; negatives are only used when every
        example in the batch provides them.
    """
    (input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text,
     attn_masks_img, idx, img_fname, neg_imgs, neg_txts, caption_ids,
     attn_masks_captions) = map(list, unzip(inputs))
    bs = len(input_ids)

    def _flatten(group, key):
        # Concatenate the per-example lists stored under `key`.
        return list(itertools.chain.from_iterable([sample[key] for sample in group]))

    if None in neg_imgs:
        num_bbs_neg, img_feats_neg, img_input_ids_neg = [], [], []
        img_pos_feat_neg, attn_masks_img_neg = [], []
        caption_ids_neg, attn_masks_captions_neg = [], []
        input_ids_neg, attn_masks_text_neg = [], []
    else:
        num_bbs_neg = _flatten(neg_imgs, 'num_bb')
        img_feats_neg = _flatten(neg_imgs, 'img_feat')
        img_input_ids_neg = _flatten(neg_imgs, 'img_input_ids')
        img_pos_feat_neg = _flatten(neg_imgs, 'img_pos_feat')
        attn_masks_img_neg = _flatten(neg_imgs, 'attn_masks_img')
        caption_ids_neg = _flatten(neg_imgs, 'caption_ids')
        attn_masks_captions_neg = _flatten(neg_imgs, 'attn_masks_captions')

        input_ids_neg = _flatten(neg_txts, 'input_ids')
        attn_masks_text_neg = _flatten(neg_txts, 'attention_mask')

    # Text stream.
    input_ids = pad_sequence(input_ids + input_ids_neg, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    # Caption stream (absent when the dataset has no caption metadata).
    if caption_ids[0] is not None:
        captions_ids = pad_sequence(caption_ids + caption_ids_neg, batch_first=True, padding_value=0)
        position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0)
    else:
        captions_ids = None
        position_ids_captions = None

    # Image stream.
    num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
    img_feat = pad_tensors(img_feats + img_feats_neg, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats + img_pos_feat_neg, num_bbs)

    img_input_ids = pad_sequence(img_input_ids + img_input_ids_neg, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    # Attention masks.
    attn_masks_text = pad_sequence(attn_masks_text + attn_masks_text_neg, batch_first=True, padding_value=0)
    if attn_masks_captions[0] is not None:
        attn_masks_captions = pad_sequence(attn_masks_captions + attn_masks_captions_neg,
                                           batch_first=True, padding_value=0)
    else:
        attn_masks_captions = None
    attn_masks_img = pad_sequence(attn_masks_img + attn_masks_img_neg, batch_first=True, padding_value=0)

    gather_index = get_gather_index([1] * bs, num_bbs, bs, 1, attn_masks_img.size(1))

    txts = {
        'input_ids': input_ids,
        'position_ids': position_ids,
        'attention_mask': attn_masks_text,
        'img_feat': None,
        'img_pos_feat': None,
        'img_masks': None,
        'gather_index': None
    }
    imgs = {
        'input_ids': img_input_ids,
        'position_ids': img_position_ids,
        'attention_mask': attn_masks_img,
        'img_feat': img_feat,
        'img_pos_feat': img_pos_feat,
        'img_masks': None,
        'gather_index': gather_index
    }
    caps = {
        'input_ids': captions_ids,
        'position_ids': position_ids_captions,
        'attention_mask': attn_masks_captions,
        'img_feat': None,
        'img_pos_feat': None,
        'img_masks': None,
        'gather_index': None
    }
    return {
        'txts': txts,
        'imgs': imgs,
        'caps': caps,
        'sample_size': bs,
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'txt_index': idx,
        'img_fname': img_fname
    }
|
||
|
|
||
|
|
||
|
class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task """

    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        """Set up the text/image lookups and the number of images per item.

        Args:
            db_dir: forwarded to the parent constructor.
            img_dir: forwarded to the parent constructor.
            mini_batch_size: images scored per text (1 ground truth plus
                ``mini_batch_size - 1`` negatives); must be positive and no
                larger than the number of images.
        """
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        """Return ``(gt_img_id, neg_img_ids)`` for the ``i``-th text.

        The negatives are the ``bs - 1`` images that follow the ground-truth
        image in ``all_img_ids``, wrapping to the start of the list if needed.
        """
        gt_img_id = self.txt2img[self.ids[i]]

        # sample fixed negatives for each gt image
        pool = self.all_img_ids
        start = pool.index(gt_img_id) + 1
        stop = start + self.bs - 1
        if stop <= len(pool):
            neg_img_ids = pool[start:stop]
        else:
            # wrap around
            neg_img_ids = pool[start:] + pool[:stop - len(pool)]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        return self.get_batch(i, [gt_img_id] + neg_img_ids)

    def get_batch(self, i, img_ids):
        """Pair the ``i``-th text with every image in ``img_ids``.

        Returns a dict with the text ids repeated once per image, padded image
        features, and the matching attention masks; ``gather_index`` is ``None``.
        """
        example = super().__getitem__(i)
        n_imgs = len(img_ids)

        text = self.txt_db.combine_inputs(example['input_ids'])
        input_ids = text.unsqueeze(0).expand(n_imgs, -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        attn_masks_text = torch.ones(n_imgs, input_ids.size(1)).long()
        attn_masks_img = torch.zeros(n_imgs, max(num_bbs)).long()
        for row, nbb in enumerate(num_bbs):
            # Mark the real boxes of each image; the padded tail stays 0.
            attn_masks_img[row, :nbb] = 1

        return {'input_ids': input_ids,
                'position_ids': position_ids,
                'img_feat': img_feat,
                'img_pos_feat': img_pos_feat,
                'attn_masks_text': attn_masks_text,
                'attn_masks_img': attn_masks_img,
                'gather_index': None}
|
||
|
|
||
|
|
||
|
# for VQA
|