logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

34 lines
1.0 KiB

import torch
import torch.nn as nn
class TripletLoss(nn.Module):
def __init__(self, gamma=1.0, similarity=True):
super(TripletLoss, self).__init__()
self.gamma = gamma
self.similarity = similarity
def forward(self, sim_pos, sim_neg):
if self.similarity:
loss = torch.clamp(sim_neg - sim_pos + self.gamma, min=0.)
else:
loss = torch.clamp(sim_pos - sim_neg + self.gamma, min=0.)
return loss.mean()
class SimilarityRegularizationLoss(nn.Module):
def __init__(self, min_val=-1., max_val=1.):
super(SimilarityRegularizationLoss, self).__init__()
self.min_val = min_val
self.max_val = max_val
def forward(self, sim):
loss = torch.sum(torch.abs(torch.clamp(sim - self.min_val, max=0.)))
loss += torch.sum(torch.abs(torch.clamp(sim - self.max_val, min=0.)))
return loss
def __repr__(self,):
return '{}(min_val={}, max_val={})'.format(self.__class__.__name__, self.min_val, self.max_val)