clip-caption-reward

import torch


def split_tensors(n, x):
    # Split a batch whose leading dimension is a multiple of n into n
    # interleaved groups (group j gets rows j, n+j, 2n+j, ...); recurses
    # into lists/tuples and broadcasts None into a list of n Nones.
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x
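
# Illustrative usage sketch (not part of the original file): splitting a
# batch of 6 rows into 3 interleaved groups of 2. Note the split is
# strided, not contiguous: group 0 gets rows 0 and 3, and so on.
#
#   x = torch.arange(12).reshape(6, 2)
#   a, b, c = split_tensors(3, x)
#   assert a.shape == (2, 2)   # a holds rows 0 and 3 of x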


# Input: seq, an N*D array with elements in 0..vocab_size; 0 is the END token.
def decode_sequence(ix_to_word, seq, remove_bad_endings=True):
    N, D = seq.shape
    # Words that should not end a caption; a trailing run of these is trimmed.
    bad_endings = ['a', 'an', 'the', 'in', 'for', 'at', 'of', 'with',
                   'before', 'after', 'on', 'upon', 'near', 'to',
                   'is', 'are', 'am']
    out = []
    for i in range(N):
        txt = ''
        for j in range(D):
            ix = seq[i, j]
            if ix > 0:
                if j >= 1:
                    txt = txt + ' '
                txt = txt + ix_to_word[str(ix.item())]
            else:
                break
        if remove_bad_endings:
            flag = 0
            words = txt.split(' ')
            # Walk backwards to find the last word that is not a bad ending.
            for j in range(len(words)):
                if words[-j - 1] not in bad_endings:
                    flag = -j
                    break
            txt = ' '.join(words[0:len(words) + flag])
        # Undo BPE-style '@@ ' joins.
        out.append(txt.replace('@@ ', ''))
    return out
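
# Illustrative usage sketch (not part of the original file), with a toy
# hypothetical vocabulary. Index 0 acts as the END token, and the trailing
# preposition 'on' is stripped because it is listed in bad_endings.
#
#   ix_to_word = {'1': 'a', '2': 'cat', '3': 'sits', '4': 'on'}
#   seq = torch.tensor([[2, 3, 4, 0, 0]])
#   decode_sequence(ix_to_word, seq)   # -> ['cat sits']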


def penalty_builder(penalty_config):
    # Build a length-penalty function from a config string such as
    # 'wu_0.7' or 'avg_1.0'; an empty string means no penalty.
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    raise ValueError('Unknown penalty type: %s' % pen_type)


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """
    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average log-probability per token of the sequence;
    alpha is accepted for interface compatibility but unused.
    """
    return logprobs / length
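

# Illustrative usage sketch (not part of the original file): applying both
# penalties to a hypothetical summed log-probability of -6.0 for a
# 10-token sequence.
#
#   wu = penalty_builder('wu_0.7')    # Wu et al. length normalization
#   avg = penalty_builder('avg_1.0')  # plain per-token average
#   wu(10, -6.0)    # -6.0 / ((15 ** 0.7) / (6 ** 0.7)) ~= -3.16
#   avg(10, -6.0)   # -6.0 / 10 = -0.6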