import torch

def compute_valid_token_num(valid_len_list):
    # Number of off-diagonal (i.e., non-self) token pairs summed over the
    # batch: a sequence with n valid tokens contributes n * (n - 1) pairs.
    res = 0
    for one_len in valid_len_list:
        res += one_len * (one_len - 1)
    return res

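# Quick check (a sketch, not part of the original file): with valid lengths
# [4, 3] each sequence contributes n * (n - 1) pairs, so the total is
# 4 * 3 + 3 * 2 == 18, which also equals the number of ones produced by
# build_mask_matrix(4, [4, 3]) below.
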
def build_mask_matrix(seqlen, valid_len_list, prefix_len=0):
    '''
    prefix_len: the length of the prefix for which we do not want to compute the CL loss.

    (1) if a sequence of length 4 contains zero padding tokens (i.e., the valid length is 4),
        then the loss mask matrix looks like
            [0., 1., 1., 1.],
            [1., 0., 1., 1.],
            [1., 1., 0., 1.],
            [1., 1., 1., 0.]

    (2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3),
        then the loss mask matrix looks like
            [0., 1., 1., 0.],
            [1., 0., 1., 0.],
            [1., 1., 0., 0.],
            [0., 0., 0., 0.]
    '''
    res_list = []
    # Start from an all-ones matrix with a zero diagonal: every token is
    # contrasted against every other token, but never against itself.
    base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen)
    base_mask = base_mask.type(torch.FloatTensor)
    bsz = len(valid_len_list)
    for i in range(bsz):
        one_base_mask = base_mask.clone()
        one_valid_len = valid_len_list[i]
        # Zero out the rows and columns that correspond to padding positions.
        one_base_mask[:, one_valid_len:] = 0.
        one_base_mask[one_valid_len:, :] = 0.
        if prefix_len > 0:
            # Do not compute the CL loss among prefix tokens.
            one_base_mask[:prefix_len, :prefix_len] = 0.
        res_list.append(one_base_mask)
    res_mask = torch.stack(res_list, dim=0)    # bsz x seqlen x seqlen
    assert res_mask.size() == torch.Size([bsz, seqlen, seqlen])
    return res_mask

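# A minimal sanity-check sketch (not part of the original file): this helper is
# hypothetical and simply reproduces the two docstring examples above, i.e. a
# batch of two length-4 sequences with valid lengths [4, 3].
def _demo_build_mask_matrix():
    mask = build_mask_matrix(seqlen=4, valid_len_list=[4, 3])
    print(mask[0])    # valid length 4: ones everywhere except the diagonal
    print(mask[1])    # valid length 3: last row and column zeroed out
    # the number of ones equals the pair count from compute_valid_token_num
    assert int(mask.sum().item()) == compute_valid_token_num([4, 3])
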
def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0):
    '''
    margin: predefined margin used to push the similarity scores apart
    score_matrix: bsz x seqlen x seqlen
    input_ids: bsz x seqlen
    pad_token_id: id of the padding token, used to mask out padding positions
    '''
    bsz, seqlen, _ = score_matrix.size()
    # The diagonal holds each token's similarity with itself (the "gold" score).
    gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2)    # bsz x seqlen
    gold_score = torch.unsqueeze(gold_score, -1)
    assert gold_score.size() == torch.Size([bsz, seqlen, 1])
    difference_matrix = gold_score - score_matrix
    assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
    # Hinge loss: penalize any pair whose gold-vs-other gap is smaller than the margin.
    loss_matrix = margin - difference_matrix    # bsz x seqlen x seqlen
    loss_matrix = torch.nn.functional.relu(loss_matrix)

    ### input mask: 1 for real tokens, 0 for padding tokens
    input_mask = torch.ones_like(input_ids).type(torch.FloatTensor)
    if loss_matrix.is_cuda:
        input_mask = input_mask.cuda(loss_matrix.get_device())
    input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0)

    valid_len_list = torch.sum(input_mask, dim=-1).tolist()
    loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len)
    if score_matrix.is_cuda:
        loss_mask = loss_mask.cuda(score_matrix.get_device())
    masked_loss_matrix = loss_matrix * loss_mask

    # Sum over the "other token" dimension, then zero out rows that belong to padding.
    loss_matrix = torch.sum(masked_loss_matrix, dim=-1)
    assert loss_matrix.size() == input_ids.size()
    loss_matrix = loss_matrix * input_mask
    # Average over the number of valid token pairs.
    cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask)
    return cl_loss

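# Usage sketch (not part of the original file). The score matrix is assumed to
# hold pairwise token similarities (e.g., cosine similarities between a model's
# token representations); a random matrix and toy ids are used here only to
# illustrate the expected shapes and the call.
if __name__ == '__main__':
    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 7, 8],
                              [5, 6, 7, pad_token_id]])    # second sequence has one padding token
    score_matrix = torch.rand(2, 4, 4)    # bsz x seqlen x seqlen
    loss = contrastive_loss(margin=0.5, score_matrix=score_matrix,
                            input_ids=input_ids, pad_token_id=pad_token_id)
    print(loss.item())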