lightningdot
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
62 lines
1.8 KiB
62 lines
1.8 KiB
import numpy as np
|
|
|
|
|
|
def i2t(sims, npts=None, return_ranks=False):
    """
    Images->Text (Image Annotation)

    Rank all captions for each image and compute retrieval metrics.

    sims: (N, 5N) matrix of similarity im-cap. Row ``k`` holds the scores of
        image ``k`` against every caption. Captions ``5k .. 5k+4`` are treated
        as the ground-truth captions of image ``k``.
    npts: number of images (leading rows of ``sims``) to evaluate. Defaults to
        ``sims.shape[0]``, i.e. every row.
    return_ranks: if True, also return the per-image ranks and top-1 indices.

    Returns ``(r1, r5, r10, medr, meanr)``:
        - ``r1``, ``r5`` and ``r10`` are recall@1/5/10, given as percentages.
        - ``medr`` and ``meanr`` are the median and mean 1-based rank of the
          best-placed ground-truth caption.

    With ``return_ranks=True`` the result is instead
    ``((r1, r5, r10, medr, meanr), (ranks, top1))``:
        - ``ranks`` holds the 0-based rank per image.
        - ``top1`` holds the index of each image's highest-scoring caption.
        - Both are float arrays of length ``npts``.

    Raises ValueError in any of these cases:
        - ``sims`` is not 2-D.
        - ``npts`` is not in ``1..sims.shape[0]``.
        - ``sims`` has fewer than ``5 * npts`` caption columns.
    """
    sims = np.asarray(sims)
    if sims.ndim != 2:
        raise ValueError(
            f"sims must be a 2-D (images x captions) matrix, got shape {sims.shape}")
    n_images, n_captions = sims.shape
    if npts is None:
        npts = n_images
    if not 0 < npts <= n_images:
        raise ValueError(f"npts must be between 1 and {n_images}, got {npts}")
    if n_captions < 5 * npts:
        raise ValueError(
            f"sims needs at least {5 * npts} caption columns for {npts} images, "
            f"got {n_captions}")

    ranks = np.zeros(npts)
    top1 = np.zeros(npts)
    for index in range(npts):
        # Caption indices ordered by decreasing similarity.
        inds = np.argsort(sims[index])[::-1]
        # Rank = position of the best-placed ground-truth caption, i.e. the
        # first position whose caption index falls in [5*index, 5*index + 5).
        is_gt = (inds >= 5 * index) & (inds < 5 * index + 5)
        ranks[index] = np.flatnonzero(is_gt)[0]
        top1[index] = inds[0]

    # Compute metrics (ranks are 0-based; medr/meanr are reported 1-based).
    r1 = 100.0 * np.count_nonzero(ranks < 1) / len(ranks)
    r5 = 100.0 * np.count_nonzero(ranks < 5) / len(ranks)
    r10 = 100.0 * np.count_nonzero(ranks < 10) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    return (r1, r5, r10, medr, meanr)
|
|
|
|
|
|
def t2i(sims, npts=None, return_ranks=False):
    """
    Text->Images (Image Search)

    Rank all images for each caption and compute retrieval metrics.

    sims: (N, 5N) matrix of similarity im-cap. Caption column ``5k + i``
        (``i`` in ``0..4``) is treated as describing image ``k``.
    npts: number of images whose captions are evaluated. Defaults to
        ``sims.shape[0]``. Each evaluated caption is still ranked against
        every image row of ``sims``.
    return_ranks: if True, also return the per-caption ranks and top-1 indices.

    Returns ``(r1, r5, r10, medr, meanr)``:
        - ``r1``, ``r5`` and ``r10`` are recall@1/5/10, given as percentages.
        - ``medr`` and ``meanr`` are the median and mean 1-based rank of the
          ground-truth image.

    With ``return_ranks=True`` the result is instead
    ``((r1, r5, r10, medr, meanr), (ranks, top1))``:
        - ``ranks`` holds the 0-based rank per caption.
        - ``top1`` holds the index of each caption's highest-scoring image.
        - Both are float arrays of length ``5 * npts``.

    Raises ValueError in any of these cases:
        - ``sims`` is not 2-D.
        - ``npts`` is not in ``1..sims.shape[0]``.
        - ``sims`` has fewer than ``5 * npts`` caption columns.
    """
    sims = np.asarray(sims)
    if sims.ndim != 2:
        raise ValueError(
            f"sims must be a 2-D (images x captions) matrix, got shape {sims.shape}")
    n_images, n_captions = sims.shape
    if npts is None:
        npts = n_images
    if not 0 < npts <= n_images:
        raise ValueError(f"npts must be between 1 and {n_images}, got {npts}")
    if n_captions < 5 * npts:
        raise ValueError(
            f"sims needs at least {5 * npts} caption columns for {npts} images, "
            f"got {n_captions}")

    ranks = np.zeros(5 * npts)
    top1 = np.zeros(5 * npts)

    # --> (5N(caption), N(image))
    sims = sims.T

    for index in range(npts):
        for i in range(5):
            row = 5 * index + i
            # Image indices ordered by decreasing similarity to this caption.
            inds = np.argsort(sims[row])[::-1]
            # 0-based position of the ground-truth image in that ordering.
            ranks[row] = np.flatnonzero(inds == index)[0]
            top1[row] = inds[0]

    # Compute metrics (ranks are 0-based; medr/meanr are reported 1-based).
    r1 = 100.0 * np.count_nonzero(ranks < 1) / len(ranks)
    r5 = 100.0 * np.count_nonzero(ranks < 5) / len(ranks)
    r10 = 100.0 * np.count_nonzero(ranks < 10) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    return (r1, r5, r10, medr, meanr)
|