import torch


def split_tensors(n, x):
    """Split the leading batch dimension of x into n interleaved chunks.

    A tensor of shape (N*n, ...) becomes a tuple of n tensors of shape
    (N, ...); lists/tuples are split recursively, and None is broadcast
    to [None] * n.
    """
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif isinstance(x, (list, tuple)):
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


# Words that should not end a decoded caption; a trailing run of these is
# trimmed when remove_bad_endings is set.
bad_endings = ['a', 'an', 'the', 'in', 'for', 'at', 'of', 'with', 'before',
               'after', 'on', 'upon', 'near', 'to', 'is', 'are', 'am']


def decode_sequence(ix_to_word, seq, remove_bad_endings=True):
    """Decode word indices into strings.

    Input: seq, an N x D tensor with elements in 0..vocab_size, where
    0 is the END token. Returns a list of N decoded strings.
    """
    N, D = seq.shape
    out = []
    for i in range(N):
        txt = ''
        for j in range(D):
            ix = seq[i, j]
            if ix > 0:
                if j >= 1:
                    txt = txt + ' '
                txt = txt + ix_to_word[str(ix.item())]
            else:
                break
        if remove_bad_endings:
            # Drop any trailing run of bad-ending words.
            flag = 0
            words = txt.split(' ')
            for j in range(len(words)):
                if words[-j - 1] not in bad_endings:
                    flag = -j
                    break
            txt = ' '.join(words[0:len(words) + flag])
        # Merge BPE sub-word pieces marked with '@@ '.
        out.append(txt.replace('@@ ', ''))
    return out


def penalty_builder(penalty_config):
    """Build a length-penalty function from a config string such as
    'wu_0.9' or 'avg_1.0'; the empty string disables the penalty."""
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    raise ValueError('Unknown penalty type: %s' % pen_type)


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """
    modifier = ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average log-probability per token; alpha is unused and is
    kept only so the signature matches length_wu.
    """
    return logprobs / length
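

# A minimal usage sketch (not part of the original module) exercising
# split_tensors, decode_sequence, and penalty_builder. The toy vocabulary,
# tensor shapes, and penalty config below are illustrative assumptions.
if __name__ == '__main__':
    # split_tensors: a (4, 3) tensor split into 2 interleaved (2, 3) tensors.
    t = torch.arange(12).reshape(4, 3)
    a, b = split_tensors(2, t)
    print(a.shape, b.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

    # decode_sequence: a toy vocabulary; index 0 is the END token.
    ix_to_word = {'1': 'a', '2': 'dog', '3': 'runs'}
    seq = torch.tensor([[2, 3, 0, 0],   # -> "dog runs"
                        [3, 1, 0, 0]])  # -> "runs a", trimmed to "runs"
    print(decode_sequence(ix_to_word, seq))  # ['dog runs', 'runs']

    # penalty_builder: rescore a beam's summed log-prob by its length
    # using the Wu et al. penalty with alpha = 0.9.
    penalty = penalty_builder('wu_0.9')
    print(penalty(4, -2.0))  # summed logprob -2.0 rescored for length 4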