import torch


def compute_valid_token_num(valid_len_list):
    # Each sequence with n valid tokens contributes n * (n - 1) off-diagonal
    # token pairs to the contrastive loss.
    res = 0
    for one_len in valid_len_list:
        res += one_len * (one_len - 1)
    return res
def build_mask_matrix(seqlen, valid_len_list, prefix_len=0):
    '''
    prefix_len: the length of the prefix for which we do not want to compute the CL loss.

    (1) if a sequence of length 4 contains no padding tokens (i.e., the valid length is 4),
        then the loss mask matrix looks like
            [0., 1., 1., 1.],
            [1., 0., 1., 1.],
            [1., 1., 0., 1.],
            [1., 1., 1., 0.]

    (2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3),
        then the loss mask matrix looks like
            [0., 1., 1., 0.],
            [1., 0., 1., 0.],
            [1., 1., 0., 0.],
            [0., 0., 0., 0.]
    '''
    res_list = []
    # All-ones matrix with a zero diagonal: every token is contrasted against
    # every other token, but never against itself.
    base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen)
    base_mask = base_mask.type(torch.FloatTensor)
    bsz = len(valid_len_list)
    for i in range(bsz):
        one_base_mask = base_mask.clone()
        one_valid_len = valid_len_list[i]
        # Zero out the rows and columns that correspond to padding positions.
        one_base_mask[:, one_valid_len:] = 0.
        one_base_mask[one_valid_len:, :] = 0.
        if prefix_len > 0:
            # Exclude prefix-prefix token pairs from the loss.
            one_base_mask[:prefix_len, :prefix_len] = 0.
        res_list.append(one_base_mask)
    res_mask = torch.stack(res_list, dim=0)
    assert res_mask.size() == torch.Size([bsz, seqlen, seqlen])
    return res_mask
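
# The helper below is an illustrative sanity check, not part of the original
# file: it reproduces example (2) from the docstring above, assuming a batch
# of one sequence of length 4 whose valid length is 3.
def _demo_build_mask_matrix():
    mask = build_mask_matrix(seqlen=4, valid_len_list=[3])
    expected = torch.tensor([[[0., 1., 1., 0.],
                              [1., 0., 1., 0.],
                              [1., 1., 0., 0.],
                              [0., 0., 0., 0.]]])
    assert torch.equal(mask, expected)
    # The number of ones agrees with compute_valid_token_num: 3 * 2 = 6.
    assert int(mask.sum().item()) == compute_valid_token_num([3])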
def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0):
    '''
    margin: predefined margin by which the gold (diagonal) score should exceed the other similarity scores
    score_matrix: bsz x seqlen x seqlen
    input_ids: bsz x seqlen
    pad_token_id: id indicating which tokens are padding tokens
    '''
    bsz, seqlen, _ = score_matrix.size()
    # The diagonal holds each token's similarity with itself, i.e. the gold score.
    gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2)  # bsz x seqlen
    gold_score = torch.unsqueeze(gold_score, -1)
    assert gold_score.size() == torch.Size([bsz, seqlen, 1])
    difference_matrix = gold_score - score_matrix
    assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
    # Hinge loss: penalize token pairs whose gold-minus-other gap falls below the margin.
    loss_matrix = margin - difference_matrix  # bsz x seqlen x seqlen
    loss_matrix = torch.nn.functional.relu(loss_matrix)

    ### input mask: 1. for valid tokens, 0. for padding tokens
    input_mask = torch.ones_like(input_ids).type(torch.FloatTensor)
    if loss_matrix.is_cuda:
        input_mask = input_mask.cuda(loss_matrix.get_device())
    input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0)

    valid_len_list = torch.sum(input_mask, dim=-1).tolist()
    loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len)
    if score_matrix.is_cuda:
        loss_mask = loss_mask.cuda(score_matrix.get_device())
    masked_loss_matrix = loss_matrix * loss_mask

    # Sum over the "other token" dimension, then mask out padding positions.
    loss_matrix = torch.sum(masked_loss_matrix, dim=-1)
    assert loss_matrix.size() == input_ids.size()
    loss_matrix = loss_matrix * input_mask
    # Average over all valid token pairs.
    cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask)
    return cl_loss
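
# A minimal usage sketch (illustrative, not part of the original file). The
# score_matrix here is just random scores; in practice it would come from
# pairwise similarities of the model's hidden states. pad_token_id=0 is an
# assumption for this demo.
if __name__ == '__main__':
    _demo_build_mask_matrix()
    torch.manual_seed(0)
    bsz, seqlen = 2, 5
    score_matrix = torch.rand(bsz, seqlen, seqlen)
    # The second sequence ends with two padding tokens (id 0).
    input_ids = torch.tensor([[11, 12, 13, 14, 15],
                              [21, 22, 23, 0, 0]])
    loss = contrastive_loss(margin=0.5, score_matrix=score_matrix,
                            input_ids=input_ids, pad_token_id=0)
    print(loss.item())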