lightningdot
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
33 lines
967 B
33 lines
967 B
2 years ago
|
"""
|
||
|
Misc lr helper
|
||
|
"""
|
||
|
from torch.optim import Adam, Adamax
|
||
|
|
||
|
from .adamw import AdamW
|
||
|
|
||
|
|
||
|
def build_optimizer(model, opts):
    """Build an optimizer for ``model`` with decoupled weight-decay groups.

    Parameters are split into two groups by *substring* match on their names
    from ``model.named_parameters()``:

    * names containing none of ``'bias'``, ``'LayerNorm.bias'``,
      ``'LayerNorm.weight'`` -> ``weight_decay=opts.weight_decay``
    * names containing any of them -> ``weight_decay=0.0``

    The decay group is always first, and parameter order within each group
    follows ``named_parameters()``. Parameters with ``requires_grad=False``
    are not filtered out; they are placed in a group like any other.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).
        opts: config object; reads ``optim`` (one of ``'adam'``,
            ``'adamax'``, ``'adamw'``), ``learning_rate``, ``betas`` and
            ``weight_decay``.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``opts.optim`` is not a supported optimizer name.
    """
    # Substring patterns: note 'bias' alone already matches 'LayerNorm.bias';
    # the explicit entry is kept for readability/parity with common recipes.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Single pass partition (order within each group is preserved).
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': opts.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    # Supported optimizers; all accept ``lr`` and ``betas`` keyword args.
    optimizers = {
        'adam': Adam,
        'adamax': Adamax,
        'adamw': AdamW,
    }
    try:
        optim_cls = optimizers[opts.optim]
    except KeyError:
        raise ValueError(f'invalid optimizer: {opts.optim!r}') from None

    return optim_cls(optimizer_grouped_parameters,
                     lr=opts.learning_rate, betas=opts.betas)
|