lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
53 lines
1.6 KiB
53 lines
1.6 KiB
2 years ago
|
"""
|
||
|
optimizer learning rate scheduling helpers
|
||
|
"""
|
||
|
from math import ceil
|
||
|
|
||
|
|
||
|
def noam_schedule(step, warmup_step=4000):
    """Noam (inverse square-root) LR multiplier.

    Ramps linearly from 0 to 1 over the first ``warmup_step`` steps,
    then decays proportionally to ``step ** -0.5``.

    Args:
        step: current training step (int).
        warmup_step: number of warmup steps before decay begins.

    Returns:
        A float multiplier to scale the base learning rate by.
    """
    if step > warmup_step:
        return (warmup_step ** 0.5) * (step ** -0.5)
    return step / warmup_step
|
||
|
|
||
|
|
||
|
def warmup_linear(step, warmup_step, tot_step):
    """Linear warmup followed by linear decay to zero.

    Ramps from 0 to 1 over ``warmup_step`` steps, then decays linearly
    to 0 at ``tot_step``; clamped at 0 for any step past ``tot_step``.

    Args:
        step: current training step.
        warmup_step: steps in the warmup phase.
        tot_step: total number of training steps.

    Returns:
        A float multiplier in [0, 1] for the base learning rate.
    """
    if step >= warmup_step:
        remaining = (tot_step - step) / (tot_step - warmup_step)
        return max(0, remaining)
    return step / warmup_step
|
||
|
|
||
|
|
||
|
def vqa_schedule(step, warmup_interval, decay_interval,
                 decay_start, decay_rate):
    """VQA schedule from MCAN.

    Staged warmup: the multiplier is 1/4, 2/4, 3/4 during the first
    three ``warmup_interval`` windows. After ``decay_start`` it decays
    by ``decay_rate`` once per elapsed ``decay_interval``; otherwise
    the multiplier is 1.

    Note: warmup stages take priority over decay — a step inside the
    first three warmup windows returns the warmup value even if it is
    past ``decay_start``, matching the original branch order.
    """
    # Staged warmup windows, checked in order.
    for stage in (1, 2, 3):
        if step < stage * warmup_interval:
            return stage / 4
    if step >= decay_start:
        num_decay = ceil((step - decay_start) / decay_interval)
        return decay_rate ** num_decay
    return 1
|
||
|
|
||
|
|
||
|
def get_lr_sched(global_step, opts):
    """Compute the learning rate for ``global_step`` from config ``opts``.

    Dispatches on ``opts.decay``:
      - 'linear':   ``warmup_linear`` (uses warmup_steps, num_train_steps)
      - 'invsqrt':  ``noam_schedule`` (uses warmup_steps)
      - 'constant': base learning rate unchanged
      - 'vqa':      ``vqa_schedule`` (uses warm_int, decay_int,
                    decay_st, decay_rate)

    Returns:
        The learning rate for this step, floored at 1e-8.

    Raises:
        ValueError: if ``opts.decay`` is not one of the schedules above.
    """
    # learning rate scheduling
    if opts.decay == 'linear':
        lr_this_step = opts.learning_rate * warmup_linear(
            global_step, opts.warmup_steps, opts.num_train_steps)
    elif opts.decay == 'invsqrt':
        lr_this_step = opts.learning_rate * noam_schedule(
            global_step, opts.warmup_steps)
    elif opts.decay == 'constant':
        lr_this_step = opts.learning_rate
    elif opts.decay == 'vqa':
        lr_this_step = opts.learning_rate * vqa_schedule(
            global_step, opts.warm_int, opts.decay_int,
            opts.decay_st, opts.decay_rate)
    else:
        # Bug fix: an unknown decay value previously fell through every
        # branch and crashed with NameError on the return; fail loudly
        # with a diagnostic instead.
        raise ValueError(f"unknown lr decay schedule: {opts.decay!r}")
    if lr_this_step <= 0:
        # safe guard for possible miscalculation of train steps
        lr_this_step = 1e-8
    return lr_this_step
|