towhee
/
distill-and-select
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
25 lines
613 B
25 lines
613 B
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class L2Constrain:
    """Rescale a module's ``weight`` to unit L2 norm along one axis.

    Instances are callables taking a single module, a signature that fits
    ``nn.Module.apply``. Modules that have no ``weight`` tensor (including
    layers that register ``weight`` as ``None``) are left untouched.

    Args:
        axis: Dimension of ``weight`` along which the L2 norm is computed.
            Defaults to the last dimension.
        eps: Lower bound on the norm used as the divisor (passed to
            ``F.normalize``), so an all-zero slice does not divide by zero.
    """

    def __init__(self, axis=-1, eps=1e-6):
        self.axis = axis
        self.eps = eps

    def __call__(self, module):
        """Normalize ``module.weight`` in place if the module has a weight tensor.

        Args:
            module: Any object; only its ``weight`` attribute is read/written.
        """
        # ``hasattr`` alone is not sufficient: layers such as
        # ``BatchNorm1d(affine=False)`` or ``LayerNorm(elementwise_affine=False)``
        # have a ``weight`` attribute that is ``None``, and touching ``.data`` on
        # it would raise AttributeError while walking a model with ``apply``.
        weight = getattr(module, 'weight', None)
        if not isinstance(weight, torch.Tensor):
            return
        # Writing through ``.data`` keeps this projection out of autograd.
        weight.data = F.normalize(weight.data, p=2, dim=self.axis, eps=self.eps)
|
|
|
|
|
|
class NonNegConstrain:
    """Clamp a module's ``weight`` from below at ``eps``.

    Instances are callables taking a single module, a signature that fits
    ``nn.Module.apply``. Every weight entry smaller than ``eps`` (negatives
    and zeros included) is raised to ``eps``, so the result is strictly
    positive when ``eps > 0``. Modules without a ``weight`` tensor (including
    layers that register ``weight`` as ``None``) are left untouched.

    Args:
        eps: Minimum value allowed for any weight entry.
    """

    def __init__(self, eps=1e-3):
        self.eps = eps

    def __call__(self, module):
        """Clamp ``module.weight`` in place if the module has a weight tensor.

        Args:
            module: Any object; only its ``weight`` attribute is read/written.
        """
        # A ``weight`` attribute may exist but be ``None`` (e.g. affine-free
        # normalization layers); skip those instead of crashing on ``.data``.
        weight = getattr(module, 'weight', None)
        if not isinstance(weight, torch.Tensor):
            return
        # Writing through ``.data`` keeps this projection out of autograd.
        weight.data = torch.clamp(weight.data, min=self.eps)
|