# Handles linear scaling rule, gradual warmup, and LR decay.
# lr_warmup_init is the starting learning rate (as a multiplier of the full
# rate); the learning rate is scaled up linearly to the full rate over
# `lr_warmup_step` steps before the schedule starts decaying.
linear_warmup = [(lr_warmup_init + float(step) / lr_warmup_step *
                  (1.0 - lr_warmup_init), step)
                 for step in range(lr_warmup_step)]
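# As a concrete illustration (values assumed for the example): with
# lr_warmup_init=0.1 and lr_warmup_step=4, the comprehension yields
# [(0.1, 0), (0.325, 1), (0.55, 2), (0.775, 3)] -- a linear ramp from
# lr_warmup_init toward the full multiplier of 1.0, which the first entry
# of the decay schedule below then reaches exactly at `lr_warmup_step`.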
lr_schedule = linear_warmup + [(1.0, lr_warmup_step),
                               (0.1, first_lr_drop_step),
                               (0.01, second_lr_drop_step)]
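
# A minimal sketch (not part of the original code) of how the schedule above
# might be consumed. `learning_rate_for_step` and `base_lr` are assumed names:
# the (multiplier, step) pairs are taken to be sorted by step, with each
# multiplier applying from its step onward and scaling an absolute base
# learning rate.
def learning_rate_for_step(global_step, schedule, base_lr):
    # Start from the first entry's multiplier, then advance through every
    # entry whose start step has already been reached; stop at the first
    # entry that lies in the future (the list is sorted by step).
    multiplier = schedule[0][0]
    for mult, start_step in schedule:
        if global_step >= start_step:
            multiplier = mult
        else:
            break
    return base_lr * multiplier

# Example usage (assumed values): with base_lr = 0.08 the rate ramps from
# 0.08 * lr_warmup_init at step 0 up to 0.08 at `lr_warmup_step`, drops to
# 0.008 at `first_lr_drop_step`, and to 0.0008 at `second_lr_drop_step`.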