You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Files and versions

66 lines
3.0 KiB

import random
import logging
import collections
import json
import os
import itertools
import numpy as np
from collections import ChainMap
from dvl.trainer import build_dataloader, _save_checkpoint, eval_model_on_dataloader, load_dataset
# Root logger: getLogger() is called without a name. Handler/level setup is
# presumably done by the training entry point — not visible here, confirm.
logger = logging.getLogger()
def random_hard_neg(fname2id, num_hard_negatives, id2set, set2id, max_tries=100):
    """Sample random same-set negatives for every key of ``fname2id``.

    For each key, ``num_hard_negatives`` items are drawn with replacement
    (``random.choices``) from the pool ``set2id[id2set[key]]``. A draw is
    rejected while it contains the key's positive ``fname2id[key]``.

    After ``max_tries`` rejected rounds, the function samples directly from the
    pool with the positive filtered out. That gives the same distribution but
    cannot loop forever.

    Args:
        fname2id: mapping key -> positive item that must not appear among its
            negatives.
        num_hard_negatives: number of negatives to draw per key.
        id2set: mapping key -> set id.
        set2id: mapping set id -> sequence of candidate items.
        max_tries: rejection-sampling rounds per key before falling back to
            the filtered pool.

    Returns:
        dict mapping each key to a list of ``num_hard_negatives`` sampled
        negatives. The list may contain duplicates because sampling is with
        replacement.

    Raises:
        ValueError: if ``num_hard_negatives > 0`` and the pool holds no item
            other than the positive.
        IndexError: propagated from ``random.choices`` when the pool is empty
            and ``num_hard_negatives > 0``.
    """
    hard_negs = {}
    for key, positive in fname2id.items():
        pool = set2id[id2set[key]]
        for _ in range(max_tries):
            hard_neg = random.choices(pool, k=num_hard_negatives)
            if positive not in hard_neg:
                break  # accepted draw; stop resampling this key
        else:
            # Rejection sampling keeps hitting the positive (small or degenerate
            # pool). Sample from the non-positive candidates instead, which is
            # the same conditional distribution.
            candidates = [c for c in pool if c != positive]
            if num_hard_negatives > 0 and not candidates:
                raise ValueError(
                    f'no negative candidate for {key!r}: every item in its set '
                    f'equals the positive {positive!r}')
            hard_neg = random.choices(candidates, k=num_hard_negatives)
        hard_negs[key] = hard_neg
    return hard_negs
def get_img_txt_mappings(train_txt_dbs):
    """Build image/text/set lookup tables from each DB folder's ``img2txts.json``.

    Each folder in ``train_txt_dbs`` acts as one "set". Its ``img2txts.json``
    maps an image id to a list of text ids. When the same image id appears in
    several folders, the entry from the earliest folder wins, following
    ``ChainMap`` lookup order.

    Args:
        train_txt_dbs: iterable of DB folder paths.

    Returns:
        Tuple ``(img2txt, txt2img, img2set, txt2set, set2img, set2txt)``:

        - img2txt: image id -> list of text ids.
        - txt2img: text id -> image id. If a text id is listed under several
          images, the last one iterated wins.
        - img2set: image id -> DB folder it was taken from.
        - txt2set: text id -> DB folder it was taken from.
        - set2img: DB folder -> list of image ids (``defaultdict(list)``).
        - set2txt: DB folder -> list of text ids (``defaultdict(list)``).
    """
    db_folders = list(train_txt_dbs)  # iterated twice below; allow generators

    # Read each mapping file exactly once; ``with`` closes the handle.
    # JSON text is UTF-8 per RFC 8259.
    per_db_img2txt = []
    for db_folder in db_folders:
        path = os.path.join(db_folder, 'img2txts.json')
        with open(path, encoding='utf-8') as f:
            per_db_img2txt.append(json.load(f))

    train_img2txt = dict(ChainMap(*per_db_img2txt))
    train_txt2img = {
        txt_id: img_id
        for img_id, txt_ids in train_img2txt.items()
        for txt_id in txt_ids
    }
    train_img2set = dict(ChainMap(*[
        dict.fromkeys(img2txt, db_folder)
        for img2txt, db_folder in zip(per_db_img2txt, db_folders)
    ]))
    train_txt2set = {txt_id: train_img2set[img_id] for txt_id, img_id in train_txt2img.items()}

    # Build both per-set indexes in a single pass over the image -> set map.
    train_set2img, train_set2txt = collections.defaultdict(list), collections.defaultdict(list)
    for img_id, set_id in train_img2set.items():
        train_set2img[set_id].append(img_id)
        train_set2txt[set_id] += train_img2txt[img_id]
    return train_img2txt, train_txt2img, train_img2set, train_txt2set, train_set2img, train_set2txt
def sampled_hard_negatives(all_img_dbs, args, collate_func, bi_encoder, train_img2txt, train_txt2img):
    """Mine model-based hard negatives over the training data, one sub-dataset at a time.

    For every dataset in ``load_dataset(...).datasets``, the bi-encoder is
    evaluated with ``eval_model_on_dataloader``, requesting ``num_hard_sampled``
    candidates. The ground-truth matches are then removed from those
    candidates, and ``args.num_hard_negatives`` negatives per key are drawn
    without replacement (``random.sample``) from what remains.

    Returns:
        Tuple ``(hard_negs_txt_all, hard_negs_img_all)``:

        - hard_negs_txt_all: keyed like ``train_img2txt`` (presumably image
          ids) -> sampled negative text ids.
        - hard_negs_img_all: keyed like ``train_txt2img`` (presumably text
          ids) -> sampled negative image ids.

        If a key occurs in several sub-datasets, the earliest sub-dataset wins
        (``ChainMap`` order).

    Raises:
        ValueError: from ``random.sample`` when fewer than
            ``args.num_hard_negatives`` candidates remain for some key.
    """
    train_dataset_eval = load_dataset(all_img_dbs, args.train_txt_dbs, args.train_img_dbs, args, True)
    num_hard_negatives = args.num_hard_negatives
    # Over-retrieve so enough candidates survive after the positives are removed.
    num_hard_sampled = min(max(num_hard_negatives * 2 + 10, 50), 1000)

    per_dset_txt_negs, per_dset_img_negs = [], []
    for dset in train_dataset_eval.datasets:
        train_dataloader_hn = build_dataloader(dset, collate_func, True, args, args.valid_batch_size)
        logger.info(f'eval for train dataloader len (for hn) = {len(train_dataloader_hn)}')
        # Only the mined candidates are used; loss / accuracy / indexer / recall are discarded.
        _loss, _correct_ratio, _indexer, _recall, (hard_neg_img, hard_neg_txt) = \
            eval_model_on_dataloader(bi_encoder, train_dataloader_hn, args, train_img2txt, num_hard_sampled)

        # Remove the ground-truth matches from the candidates: they are the labels, not negatives.
        img_negs = {}
        for txt_id, candidates in hard_neg_img.items():
            positive = train_txt2img[txt_id]
            img_negs[txt_id] = [c for c in candidates if c != positive]

        txt_negs = {}
        for img_id, candidates in hard_neg_txt.items():
            positives = set(train_img2txt[img_id])
            # Order-preserving de-duplication. Iterating a set here would make
            # random.sample irreproducible under a fixed seed when ids are str,
            # because str hashing is randomized per process.
            txt_negs[img_id] = list(dict.fromkeys(c for c in candidates if c not in positives))

        per_dset_txt_negs.append({k: random.sample(v, num_hard_negatives) for k, v in txt_negs.items()})
        per_dset_img_negs.append({k: random.sample(v, num_hard_negatives) for k, v in img_negs.items()})

    hard_negs_txt_all = dict(collections.ChainMap(*per_dset_txt_negs))
    hard_negs_img_all = dict(collections.ChainMap(*per_dset_img_negs))
    return hard_negs_txt_all, hard_negs_img_all