towhee
/
distill-and-select
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
33 lines
1.0 KiB
33 lines
1.0 KiB
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class TripletLoss(nn.Module):
|
|
|
|
def __init__(self, gamma=1.0, similarity=True):
|
|
super(TripletLoss, self).__init__()
|
|
self.gamma = gamma
|
|
self.similarity = similarity
|
|
|
|
def forward(self, sim_pos, sim_neg):
|
|
if self.similarity:
|
|
loss = torch.clamp(sim_neg - sim_pos + self.gamma, min=0.)
|
|
else:
|
|
loss = torch.clamp(sim_pos - sim_neg + self.gamma, min=0.)
|
|
return loss.mean()
|
|
|
|
|
|
class SimilarityRegularizationLoss(nn.Module):
|
|
|
|
def __init__(self, min_val=-1., max_val=1.):
|
|
super(SimilarityRegularizationLoss, self).__init__()
|
|
self.min_val = min_val
|
|
self.max_val = max_val
|
|
|
|
def forward(self, sim):
|
|
loss = torch.sum(torch.abs(torch.clamp(sim - self.min_val, max=0.)))
|
|
loss += torch.sum(torch.abs(torch.clamp(sim - self.max_val, min=0.)))
|
|
return loss
|
|
|
|
def __repr__(self,):
|
|
return '{}(min_val={}, max_val={})'.format(self.__class__.__name__, self.min_val, self.max_val)
|