lightningdot
"""
|
|
Misc lr helper
|
|
"""
|
|
from torch.optim import Adam, Adamax
|
|
|
|
from .adamw import AdamW
|
|
|
|
|
|
def build_optimizer(model, opts):
|
|
param_optimizer = list(model.named_parameters())
|
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
|
optimizer_grouped_parameters = [
|
|
{'params': [p for n, p in param_optimizer
|
|
if not any(nd in n for nd in no_decay)],
|
|
'weight_decay': opts.weight_decay},
|
|
{'params': [p for n, p in param_optimizer
|
|
if any(nd in n for nd in no_decay)],
|
|
'weight_decay': 0.0}
|
|
]
|
|
|
|
# currently Adam only
|
|
if opts.optim == 'adam':
|
|
OptimCls = Adam
|
|
elif opts.optim == 'adamax':
|
|
OptimCls = Adamax
|
|
elif opts.optim == 'adamw':
|
|
OptimCls = AdamW
|
|
else:
|
|
raise ValueError('invalid optimizer')
|
|
optimizer = OptimCls(optimizer_grouped_parameters,
|
|
lr=opts.learning_rate, betas=opts.betas)
|
|
return optimizer
|
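For reference, a minimal usage sketch, assuming an argparse-style opts object that carries the fields the function reads (optim, weight_decay, learning_rate, betas). The stand-in model and hyperparameter values below are illustrative only, not taken from the repository.

# Minimal usage sketch (illustrative values, not from the repository).
from argparse import Namespace

import torch.nn as nn

model = nn.Linear(768, 2)          # stand-in for the real model
opts = Namespace(optim='adamw',    # one of: 'adam', 'adamax', 'adamw'
                 weight_decay=0.01,
                 learning_rate=5e-5,
                 betas=(0.9, 0.98))

optimizer = build_optimizer(model, opts)
# the Linear layer's 'weight' goes in the decayed group, its 'bias' in the non-decayed group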