Fixes for Nesterov momentum in ML


This article points out errors in everyone’s Nesterov momentum algorithms for ML. If you don’t use Nesterov momentum, it won’t affect you.

Normal momentum:

$$v_{t+1} = \mu_t v_t + \nabla f(\theta_t)$$
$$\theta_{t+1} = \theta_t - \gamma_t v_{t+1}$$

$v_t$ is a momentum buffer. $\mu_t \in [0, 1]$ is a decay scalar. $\theta_t$ are the model parameters. $\nabla f(\theta_t)$ is a gradient. This form is standard among practitioners. (Older papers move $\gamma_t$ elsewhere.)
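As code, this is the familiar momentum-SGD inner loop. A minimal sketch on raw torch tensors (the in-place ops and names are just one way to write it):

```python
def momentum_step(theta, grad, v, mu, gamma):
    """Standard momentum on torch tensors, matching the two equations above."""
    v.mul_(mu).add_(grad)   # v_{t+1} = mu_t * v_t + grad f(theta_t)
    theta.sub_(gamma * v)   # theta_{t+1} = theta_t - gamma_t * v_{t+1}
    return theta, v
```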

Sutskever introduced Nesterov momentum to Machine Learning, calculating the gradient at a lookahead position:

$$v_{t+1} = \mu_t v_t + \nabla f(\theta_t - \gamma_t \mu_t v_t)$$
$$\theta_{t+1} = \theta_t - \gamma_t v_{t+1}$$
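A sketch of the Sutskever form, assuming you can afford to evaluate the gradient at the shifted point (`grad_fn` is a hypothetical closure computing $\nabla f$):

```python
def sutskever_nesterov_step(theta, v, mu, gamma, grad_fn):
    """Nesterov momentum: the gradient is taken at the lookahead point."""
    lookahead = theta - gamma * mu * v   # theta_t - gamma_t * mu_t * v_t
    v = mu * v + grad_fn(lookahead)      # v_{t+1}
    theta = theta - gamma * v            # theta_{t+1}
    return theta, v
```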

Nesterov momentum is optimal in convex optimization in a narrow sense. Sometimes it helps in ML.

Bengio substituted $\phi_t = \theta_t - \gamma_t \mu_t v_t$, baking in the lookahead:

$$v_{t+1} = \mu_t v_t + \nabla f(\phi_{t})$$
$$\begin{aligned} \phi_{t+1} + \gamma_{t+1} \mu_{t+1} v_{t+1} &= \phi_{t} + \gamma_t \mu_t v_t - \gamma_t v_{t+1} \\ \phi_{t+1} &= \phi_{t} + \gamma_t (\mu_t v_t - v_{t+1}) - \gamma_{t+1} \mu_{t+1} v_{t+1} \end{aligned}$$
$$\boxed{\phi_{t+1} = \phi_{t} - \gamma_t \nabla f(\phi_{t}) - \gamma_{t+1} \mu_{t+1} v_{t+1}}$$

The lookahead momentum $v_{t+1}$ is multiplied by $\gamma_{t+1}$ and $\mu_{t+1}$ - the learning rate and momentum decay for the next step, not this step!
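In code, the boxed update is a direct transcription; a minimal sketch where `gamma_next` and `mu_next` are the next step's scheduled values (fetching them is the subject of the next paragraph):

```python
def bengio_nesterov_step(phi, grad, v, mu_t, gamma_t, mu_next, gamma_next):
    """Bengio-form Nesterov step on the lookahead parameters phi."""
    v = mu_t * v + grad                                     # v_{t+1} = mu_t * v_t + grad f(phi_t)
    phi = phi - gamma_t * grad - gamma_next * mu_next * v   # boxed update
    return phi, v
```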

You’ll need to retrieve $\gamma_{t+1}$ and $\mu_{t+1}$ from your schedulers, but PyTorch’s learning rate schedulers have an off-by-one issue that makes this unsolvable. The standard training step counterbalances it with another off-by-one error, reversing the order of the optimizer and scheduler calls. Loading the scheduler requires yet another off-by-one adjustment. You should roll your own random-access scheduler and not use torch’s scheduler API.
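A minimal sketch of such a scheduler: a pure function of the step index, so $\gamma_{t+1}$ is just `gamma(t + 1)` and there is no hidden state to step, save, or load. The warmup-plus-cosine shape and constants are placeholders for whatever you actually use:

```python
import math

def gamma(step: int, base_lr: float = 3e-4, warmup: int = 1000, total: int = 100_000) -> float:
    """Random-access learning rate: depends only on the step index."""
    if step < warmup:
        return base_lr * step / warmup
    progress = min((step - warmup) / max(1, total - warmup), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

def mu(step: int) -> float:
    """Random-access momentum decay; constant here, but any pure function works."""
    return 0.95

# In the training loop, the Bengio update can query both this step and the next:
# phi, v = bengio_nesterov_step(phi, grad, v, mu(t), gamma(t), mu(t + 1), gamma(t + 1))
```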

Adam

LaProp can use Nesterov momentum directly. I don’t know anyone using LaProp except Lucas Nestler.

Adam requires more work. Do not naively apply Bengio’s version inside the numerator, like NAdam does. It is wrong because it neglects the interaction with Adam’s denominator. When the denominator changes between steps, as with loss spikes or low $\beta_2$, the movements for lookahead and retraction do not cancel.

The non-Nesterov Adam update is $\theta_{t+1} - \theta_t = \text{Adam}_t(\nabla f(\theta_{t})) = -\gamma_{t} \frac{m_{t+1}}{\sqrt{n_{t+1}} + \epsilon}$.
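Written as a function that advances the moment buffers and returns the signed step, a minimal sketch (bias correction left out to match the formula above; `m` and `n` are torch tensors holding the first and second moments):

```python
def adam_update(grad, m, n, gamma, beta1=0.9, beta2=0.999, eps=1e-8):
    """Returns Adam_t(grad) and updates the moment buffers in place."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)            # m_{t+1}
    n.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # n_{t+1}
    return -gamma * m / (n.sqrt() + eps)
```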

For correct Nesterov acceleration, we calculate the gradient at a lookahead position: $\theta_{t+1} - \theta_t = \text{Adam}_t(\nabla f(\theta_t + \mu_t (\theta_t - \theta_{t-1})))$.

With lookahead parameters $\phi_t = \theta_t + \mu_t (\theta_t - \theta_{t-1}) = \theta_t + \mu_t \text{Adam}_{t-1}(\nabla f(\phi_{t-1}))$, the update becomes:

$$\begin{aligned} \theta_{t+1} &= \theta_t + \text{Adam}_t(\nabla f(\phi_t)) \\ \phi_{t+1} - \mu_{t+1} \text{Adam}_{t}(\nabla f(\phi_{t})) &= \phi_t - \mu_t \text{Adam}_{t-1}(\nabla f(\phi_{t-1})) + \text{Adam}_t(\nabla f(\phi_t)) \end{aligned}$$
$$\boxed{\phi_{t+1} = \phi_t + (1 + \mu_{t+1}) \text{Adam}_t(\nabla f(\phi_t)) - \mu_t \text{Adam}_{t-1}(\nabla f(\phi_{t-1}))}$$

You can recalculate $\text{Adam}_{t-1}(\nabla f(\phi_{t-1}))$ using the optimizer’s buffers before updating them with the gradient.
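Putting the pieces together, a sketch that replays $\text{Adam}_{t-1}(\nabla f(\phi_{t-1}))$ from the untouched buffers and then applies the boxed formula; it reuses the hypothetical `adam_update` above and assumes you kept $\gamma_{t-1}$:

```python
def nesterov_adam_step(phi, grad, m, n, gamma_prev, gamma_t, mu_t, mu_next,
                       beta1=0.9, beta2=0.999, eps=1e-8):
    """One corrected Nesterov-Adam step on the lookahead parameters phi."""
    # Adam_{t-1}(grad f(phi_{t-1})): recomputed from the buffers *before* the new gradient lands.
    prev_update = -gamma_prev * m / (n.sqrt() + eps)
    # Adam_t(grad f(phi_t)): this advances m and n in place.
    update = adam_update(grad, m, n, gamma_t, beta1, beta2, eps)
    # Boxed formula: phi_{t+1} = phi_t + (1 + mu_{t+1}) Adam_t - mu_t Adam_{t-1}
    phi.add_((1 + mu_next) * update - mu_t * prev_update)
    return phi
```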

This derivation treated Adam as a black box, without depending on its internal momentum behavior, so this formula can apply Nesterov acceleration to any optimizer. There is a gap in this theory - the optimal $\mu$ for Nesterov is not specified, and it lacks a connection to $\beta_1$ inside Adam.

This formula is Eq (7) in Bengio’s paper. What’s old is new again.

Muon

Muon has the same issue and solution, except last step’s update is too expensive to recalculate. Either store last step’s update (memory intensive!), give up on Nesterov momentum, or hope Muon’s existing Nesterov-like momentum is fine.

Weight decay

Apply weight decay to the original parameters, not the lookahead parameters.
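One way to do that in the Adam-section parameterization, a sketch assuming decoupled (AdamW-style) decay and that last step's update is still around (names are hypothetical):

```python
def decayed_lookahead(phi, last_update, mu_t, gamma_t, wd):
    """Decay theta_t = phi_t - mu_t * Adam_{t-1}(...), not the lookahead phi_t itself."""
    theta = phi - mu_t * last_update       # undo the lookahead
    theta = theta - gamma_t * wd * theta   # decoupled weight decay on the original parameters
    return theta + mu_t * last_update      # redo the lookahead
```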

Inference

Remember that the lookahead and normal parameters are different. You can run inference with the normal parameters by either explicitly undoing the lookahead or setting your schedules to make the two coincide. It’s unclear whether the lookahead or normal parameters are better.
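A sketch of the explicit option, again using the Adam-section definition $\phi_t = \theta_t + \mu_t \text{Adam}_{t-1}(\nabla f(\phi_{t-1}))$ and assuming the last update was stored:

```python
def inference_params(phi, last_update, mu_t):
    """Recover the normal parameters theta_t from the lookahead parameters phi_t."""
    return phi - mu_t * last_update
```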

Experiments

If this were a paper, this is where I would follow the time-honored tradition of sabotaging the baselines to show how important this correction is. I am confident I can insert a subtle optimizer flaw that won’t be spotted, even when fully described. Regular folks would be excited by the large improvement numbers and experts would ignore them.

The optimizer I use is very different from Adam and I don’t want to re-implement Adam just to produce some charts.

Nesterov acceleration is only sometimes positive for model training. You could empirically find the optimal Nesterov-ness by testing different lookahead distances, an obvious idea re-discovered by many papers. I haven’t figured out what causes Nesterov to hurt or help.

Credits

Thanks to Alex Birch for doing my work so that I can write blog posts instead.

To discuss this article: EleutherAI discord or kevin+3qm@anlatan.ai.