lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
53 lines
1.8 KiB
53 lines
1.8 KiB
""" Image Text Retrieval evaluation helper """
|
|
import torch
|
|
|
|
|
|
@torch.no_grad()
|
|
def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts):
|
|
# image retrieval
|
|
img2j = {i: j for j, i in enumerate(img_ids)}
|
|
_, rank_txt = score_matrix.topk(10, dim=1)
|
|
gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]]
|
|
for txt_id in txt_ids],
|
|
).to(rank_txt.device
|
|
).unsqueeze(1).expand_as(rank_txt)
|
|
rank = (rank_txt == gt_img_j).nonzero()
|
|
if rank.numel():
|
|
ir_r1 = (rank < 1).sum().item() / len(txt_ids)
|
|
ir_r5 = (rank < 5).sum().item() / len(txt_ids)
|
|
ir_r10 = (rank < 10).sum().item() / len(txt_ids)
|
|
else:
|
|
ir_r1, ir_r5, ir_r10 = 0, 0, 0
|
|
|
|
# text retrieval
|
|
txt2i = {t: i for i, t in enumerate(txt_ids)}
|
|
_, rank_img = score_matrix.topk(10, dim=0)
|
|
tr_r1, tr_r5, tr_r10 = 0, 0, 0
|
|
for j, img_id in enumerate(img_ids):
|
|
gt_is = [txt2i[t] for t in img2txts[img_id]]
|
|
ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is]
|
|
rank = min([10] + [r.item() for r in ranks if r.numel()])
|
|
if rank < 1:
|
|
tr_r1 += 1
|
|
if rank < 5:
|
|
tr_r5 += 1
|
|
if rank < 10:
|
|
tr_r10 += 1
|
|
tr_r1 /= len(img_ids)
|
|
tr_r5 /= len(img_ids)
|
|
tr_r10 /= len(img_ids)
|
|
|
|
tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3
|
|
ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3
|
|
r_mean = (tr_mean + ir_mean) / 2
|
|
|
|
eval_log = {'txt_r1': tr_r1,
|
|
'txt_r5': tr_r5,
|
|
'txt_r10': tr_r10,
|
|
'txt_r_mean': tr_mean,
|
|
'img_r1': ir_r1,
|
|
'img_r5': ir_r5,
|
|
'img_r10': ir_r10,
|
|
'img_r_mean': ir_mean,
|
|
'r_mean': r_mean}
|
|
return eval_log
|