logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

25 lines
613 B

import torch
import torch.nn.functional as F
class L2Constrain:
    """Rescale a module's ``weight`` to unit L2 norm along one dimension.

    Instances are callables meant to be passed to ``nn.Module.apply`` (or
    called directly on a single module). Modules whose ``weight`` is missing
    or is not a tensor (e.g. registered as ``None``) are left untouched.

    Args:
        axis: Dimension of ``weight`` along which the L2 norm is computed.
            For an ``nn.Linear`` weight of shape
            ``(out_features, in_features)`` the default ``-1`` gives every
            row unit norm.
        eps: Passed to ``F.normalize``, which divides by
            ``max(norm, eps)``; avoids division by zero for all-zero slices.
    """

    def __init__(self, axis: int = -1, eps: float = 1e-6) -> None:
        self.axis = axis
        self.eps = eps

    def __call__(self, module: torch.nn.Module) -> None:
        """Normalize ``module.weight`` in place; no-op if it has no tensor weight."""
        # ``hasattr`` is not enough: modules such as
        # ``LayerNorm(elementwise_affine=False)`` register ``weight`` as None.
        weight = getattr(module, "weight", None)
        if not isinstance(weight, torch.Tensor):
            return
        # Write into the existing storage (so views of the Parameter stay
        # valid) and keep the update out of the autograd graph.
        with torch.no_grad():
            weight.copy_(F.normalize(weight, p=2, dim=self.axis, eps=self.eps))
class NonNegConstrain:
    """Clamp a module's ``weight`` so every entry is at least ``eps``.

    Instances are callables meant to be passed to ``nn.Module.apply`` (or
    called directly on a single module). Modules whose ``weight`` is missing
    or is not a tensor (e.g. registered as ``None``) are left untouched.

    Args:
        eps: Minimum allowed weight value. With the default ``1e-3`` the
            weights become strictly positive; pass ``0.0`` for plain
            non-negativity.
    """

    def __init__(self, eps: float = 1e-3) -> None:
        self.eps = eps

    def __call__(self, module: torch.nn.Module) -> None:
        """Clamp ``module.weight`` in place; no-op if it has no tensor weight."""
        # ``hasattr`` is not enough: modules such as
        # ``BatchNorm1d(affine=False)`` register ``weight`` as None.
        weight = getattr(module, "weight", None)
        if not isinstance(weight, torch.Tensor):
            return
        # In-place clamp keeps the Parameter's storage; no_grad keeps the
        # update out of the autograd graph.
        with torch.no_grad():
            weight.clamp_(min=self.eps)