"""
|
|
MRM Datasets (contrastive learning version)
|
|
"""
|
|
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import curry

from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset,
                   pad_tensors, get_gather_index)
from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target
from .itm import sample_negative


# FIXME diff implementation from mrfr, mrc
def _mask_img_feat(img_feat, img_masks, neg_feats,
                   noop_prob=0.1, change_prob=0.1):
    rand = torch.rand(*img_masks.size())
    noop_mask = rand < noop_prob
    change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks
    img_masks_in = img_masks & ~noop_mask & ~change_mask

    img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat)
    img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)

    n_neg = change_mask.sum().item()
    feat_dim = neg_feats.size(-1)
    index = torch.arange(0, change_mask.numel(), dtype=torch.long
                         ).masked_select(change_mask.view(-1))
    index = index.unsqueeze(-1).expand(-1, feat_dim)
    img_feat_out = img_feat_masked.view(-1, feat_dim).scatter(
        dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size())

    return img_feat_out, img_masks_in
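# Illustrative note (a sketch of the behaviour above, with made-up shapes):
# for img_feat of shape (bs, max_bb, d) and img_masks of shape (bs, max_bb),
# each region flagged in img_masks is, under the default probabilities, kept
# unchanged ~10% of the time, replaced by one of the sampled negative
# features ~10% of the time, and zeroed out otherwise; img_masks_in marks
# only the zeroed regions. This assumes neg_feats holds at least as many
# rows as there are regions selected for replacement.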


class MrmNceDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        # image input features
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])
        img_mask = _get_img_mask(self.mask_prob, num_bb)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                attn_masks, img_mask, img_mask_tgt,
                example['img_fname'])


class NegativeImageSampler(object):
    def __init__(self, img_dbs, neg_size, size_mul=8):
        # NOTE: size_mul is accepted but currently unused; the multiple of 8
        # used in sample_negative_feats is hard-coded
        if not isinstance(img_dbs, list):
            assert isinstance(img_dbs, DetectFeatLmdb)
            img_dbs = [img_dbs]
        self.neg_size = neg_size
        self.img_db = JoinedDetectFeatLmdb(img_dbs)
        all_imgs = []
        for db in img_dbs:
            all_imgs.extend(db.name2nbb.keys())
        self.all_imgs = all_imgs

    def sample_negative_feats(self, pos_imgs):
        neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size)
        all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids],
                                  dim=0)
        # only use multiples of 8 for tensorcores
        n_cut = all_neg_feats.size(0) % 8
        if n_cut != 0:
            return all_neg_feats[:-n_cut]
        else:
            return all_neg_feats


class JoinedDetectFeatLmdb(object):
    def __init__(self, img_dbs):
        assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs)
        self.img_dbs = img_dbs

    def __getitem__(self, file_name):
        for db in self.img_dbs:
            if file_name in db:
                return db[file_name]
        raise ValueError("image does not exist")


@curry
def mrm_nce_collate(neg_sampler, inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
     positive_imgs) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    neg_feats = neg_sampler.sample_negative_feats(positive_imgs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_masks_in': img_masks_in,
             'img_mask_tgt': img_mask_tgt,
             'neg_feats': neg_feats}
    return batch
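

# --- Usage sketch (illustrative only): one way the pieces above could be
# wired together, assuming DetectFeatTxtTokDataset takes an already-opened
# text-token LMDB and a DetectFeatLmdb; `txt_db`, `img_db` and the
# hyper-parameters below are placeholders, not values from this repo.
#
#   neg_sampler = NegativeImageSampler(img_db, neg_size=4096)
#   dataset = MrmNceDataset(0.15, txt_db, img_db)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=32, shuffle=True,
#       collate_fn=mrm_nce_collate(neg_sampler))  # curried collate
#   batch = next(iter(loader))  # dict with 'input_ids', 'img_feat',
#                               # 'feat_targets', 'neg_feats', ...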