Is PyTorch’s Nesterov Momentum Implementation Flawed?
by Jason Vega · September 2023

If you look carefully at PyTorch’s documentation of SGD, you will see that its implementation of Nesterov momentum differs in a few ways from the formulation found in the original paper. Most notably, PyTorch’s implementation evaluates the gradient at the current parameters, whereas the whole point of Nesterov momentum is to evaluate the gradient at shifted parameters. Unfortunately, discussion of these discrepancies on the internet is scarce. In this post, we will examine and explain the differences between PyTorch’s implementation and the original formulation of Nesterov momentum. Ultimately, we will see that PyTorch’s implementation is not wrong, but rather an approximation, and speculate about the benefit of their implementation.

The original paper describes Nesterov momentum using the following update rules:
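In the notation used throughout this post, with f the objective being minimized, the rules from Sutskever et al. (2013) read:

v_{t+1} = μ v_t − ε ∇f(θ_t + μ v_t)    (1)
θ_{t+1} = θ_t + v_{t+1}    (2)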

where v_{t+1} and θ_{t+1} are the velocity vector and model parameters, respectively, at time t + 1, μ is the momentum factor, and ε is the learning rate. The note in PyTorch’s SGD documentation states that they use the following update rules:
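In the post’s symbols, the rules stated in the note are:

v_{t+1} = μ v_t + g_{t+1}
θ_{t+1} = θ_t − ε v_{t+1}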

where g_{t+1} represents the gradient used to compute v_{t+1}. We can expand the update rule for θ_{t+1} to get:
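Substituting the rule for v_{t+1} into the rule for θ_{t+1}:

θ_{t+1} = θ_t − ε (μ v_t + g_{t+1}) = θ_t − ε μ v_t − ε g_{t+1}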

From this we can infer that:
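For this to be a Nesterov step, the gradient must be evaluated at the point reached by the momentum portion of the update, θ_t − ε μ v_t, so:

g_{t+1} = ∇f(θ_t − ε μ v_t)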

and the update rules become:
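These are the forms referred to as equations 3 and 4 in the rest of the post:

v_{t+1} = μ v_t + ∇f(θ_t − ε μ v_t)    (3)
θ_{t+1} = θ_t − ε v_{t+1}    (4)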

These are the update rules that PyTorch uses in theory. I mentioned earlier that PyTorch actually evaluates the gradient at the current parameters instead of the shifted parameters. This can be seen by looking at the algorithm description in the PyTorch SGD documentation. We will examine this further later on.

Note that for both the original (1, 2) and PyTorch (3, 4) formulations, if v_0 = 0, then the first update to θ becomes:
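At the first step, both (2) and (4) reduce to:

θ_1 = θ_0 − ε ∇f(θ_0)    (5)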

Although the PyTorch SGD documentation note states that the algorithm initializes the momentum buffer to the gradient at the first step, we will later show that this implies v_0 = 0.

There are two immediate differences when going from the original (1, 2) to the PyTorch (3, 4) formulation:

  1. The learning rate is moved outside of v_{t+1}.

  2. In the update rule for v_{t+1}, the term involving the gradient is added instead of subtracted, and in the update rule for θ_{t+1}, the term involving the velocity vector is subtracted instead of added. The difference in sign inside the gradient term is simply a consequence of this, as shown in the previous section.

To understand these differences, let’s first expand the update rules. As hinted at here, the effect of the first difference is more apparent if we consider learning rate schedules. So, we consider a generalization of the update rules where ε is no longer fixed but can vary over time, and denote ε_t as the learning rate at time step t. For brevity, let:
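g_{i+1} denote the gradient used to compute v_{i+1}; that is, g_{i+1} = ∇f(θ_i + μ v_i) in the original formulation and g_{i+1} = ∇f(θ_i − ε_i μ v_i) in the PyTorch formulation (the two formulations evaluate this gradient at different points, so g_{i+1} refers to whichever applies).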

Assuming v_0 = 0, the original formulation becomes:
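Unrolling the velocity recursion from v_0 = 0 and substituting into the parameter update gives:

θ_{t+1} = θ_t − Σ_{i=0}^{t} μ^{t−i} ε_i g_{i+1}    (6)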

and the PyTorch formulation becomes:
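Unrolling in the same way:

θ_{t+1} = θ_t − ε_t Σ_{i=0}^{t} μ^{t−i} g_{i+1}    (7)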

In the original formulation (6), if the learning rate were to change at time t, then only the magnitude of the term at i = t in the summation would be affected, and the magnitudes of all the other terms would remain the same. Consequently, the immediate impact of the learning rate change is quite limited, and we would have to wait for the change to “trickle” down over subsequent time steps to have a stronger impact on the overall step size. In contrast, in the PyTorch formulation (7), if the learning rate were to change at time t, then the magnitude of the entire step would be affected immediately.

For v_0 = 0, it is clear from the expanded rules that the second difference ultimately has no effect; in either formulation, the step works out to a discounted sum of gradients that is subtracted from the current parameters.

Ignoring weight decay and dampening, by examining the SGD algorithm in PyTorch’s documentation, we can see that the implemented update rules are:
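With the momentum buffer playing the role of v, and writing g_{t+1} for the gradient computed at step t + 1, the Nesterov branch of the algorithm amounts to:

v_{t+1} = μ v_t + g_{t+1}    (8)
θ’_{t+1} = θ’_t − ε (g_{t+1} + μ v_{t+1})    (9)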

where θ’_t denotes the model parameters at time t and g_{t+1} = ∇f(θ’_t), the gradient evaluated at those current parameters.

We will refer to equations 3 and 4 as the PyTorch “note” formulation, and equations 8 and 9 as the PyTorch “implemented” formulation. We make a distinction between θ and θ’ for a reason that will become apparent soon. The most glaring difference from the note formulation is that the gradient is evaluated at the current parameters rather than the shifted parameters. From this alone, it may appear that the update rules the algorithm implements are not a proper implementation of Nesterov momentum.

We will now examine how the PyTorch algorithm ultimately approximates Nesterov momentum. Derivations for an older version of PyTorch can be found here from Ivo Danihelka, referenced in this GitHub issue. Derivations for the current version of PyTorch can be found here, which are a relatively straightforward adjustment of the earlier derivations. We provide a LaTeX rendering of these (re-derived) derivations here for the reader’s convenience. The implemented formulation is derived by a simple change of variables. Specifically, we let:
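θ’_t = θ_t − ε μ v_t

In other words, θ’_t is exactly the shifted point at which the note formulation (3) evaluates its gradient.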

It immediately becomes clear that the note update rule for v_{t+1} (3) becomes equivalent to the implemented update rule for v_{t+1} (8) after the change of variables. We now want to derive an update rule for θ’_{t+1} in terms of θ’_t:
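Applying the change of variables at step t + 1, substituting (4), and using (8) to eliminate μ v_t:

θ’_{t+1} = θ_{t+1} − ε μ v_{t+1}
        = (θ_t − ε v_{t+1}) − ε μ v_{t+1}
        = (θ’_t + ε μ v_t) − ε v_{t+1} − ε μ v_{t+1}
        = θ’_t + ε (v_{t+1} − ∇f(θ’_t)) − ε v_{t+1} − ε μ v_{t+1}
        = θ’_t − ε (∇f(θ’_t) + μ v_{t+1})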

This is exactly the update rule we saw implemented in PyTorch (9). At a high level, the PyTorch implementation assumes the current parameters θ’_t are already the shifted version of the “actual” parameters θ_t. Hence, at each time step, the “actual” parameters θ_t are related to the current parameters θ’_t by:
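Rearranging the change of variables:

θ_t = θ’_t + ε μ v_t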

However, it appears from the source code that the PyTorch SGD implementation does not make any correction at the end of the algorithm to retrieve the final “actual” parameters, so the final output is technically an approximation of the “actual” parameters.

Finally, we now show that v_0 must be 0:
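The documentation note says the momentum buffer is initialized to the gradient at the first step, i.e. v_1 = g_1. Comparing this with (8) at the first step:

v_1 = μ v_0 + g_1 = g_1   ⟹   μ v_0 = 0   ⟹   v_0 = 0 (for μ ≠ 0)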

Moreover, we can confirm that the first update to the “actual” parameters is the same first update made in the original formulation when v_0 = 0:
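Since v_0 = 0 implies θ’_0 = θ_0 and g_1 = ∇f(θ’_0) = ∇f(θ_0), applying (9) and the relation θ_t = θ’_t + ε μ v_t gives:

θ_1 = θ’_1 + ε μ v_1
    = θ’_0 − ε (g_1 + μ v_1) + ε μ v_1
    = θ_0 − ε ∇f(θ_0)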

We can see that this is equivalent to equation 5.
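As a quick numerical check of these derivations, here is a minimal sketch (not from the original post; the toy objective and settings are assumptions made for illustration). It runs torch.optim.SGD with nesterov=True on a simple quadratic, steps the note formulation (3, 4) by hand, and verifies that adding ε μ times the momentum buffer to PyTorch’s parameters recovers the “actual” parameters at every step:

```python
import torch

# Minimal sketch: compare PyTorch's implemented formulation (8, 9) against a
# manual implementation of the note formulation (3, 4) on the toy objective
# f(x) = sum(x**2), whose gradient is 2x.

torch.manual_seed(0)
mu, eps = 0.9, 0.1                       # momentum factor μ and learning rate ε

# PyTorch maintains the shifted parameters θ' and updates them in place.
theta_prime = torch.randn(3, requires_grad=True)
opt = torch.optim.SGD([theta_prime], lr=eps, momentum=mu, nesterov=True)

# Manual note formulation (3, 4), tracking the "actual" parameters θ.
theta = theta_prime.detach().clone()
v = torch.zeros_like(theta)              # v_0 = 0

for t in range(5):
    # One PyTorch step: the gradient is evaluated at the current parameters θ'.
    opt.zero_grad()
    (theta_prime ** 2).sum().backward()
    opt.step()

    # One step of the note formulation: gradient evaluated at the shifted point.
    g = 2 * (theta - eps * mu * v)
    v = mu * v + g
    theta = theta - eps * v

    # Recover the "actual" parameters from PyTorch's state: θ_t = θ'_t + ε μ v_t.
    buf = opt.state[theta_prime]["momentum_buffer"]
    recovered = theta_prime.detach() + eps * mu * buf
    print(t, torch.allclose(recovered, theta, atol=1e-6))
```

Each comparison should print True up to floating-point error, reflecting that with a fixed learning rate the implemented formulation is an exact change of variables of the note formulation.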

Of course, the big remaining question is: why does PyTorch bother at all to reformulate Nesterov momentum from equations 3 and 4 to equations 8 and 9? One possible explanation is that the reformulation might provide some savings in the number of arithmetic operations required. To evaluate this possible explanation, let’s count the number of arithmetic operations. For the note formulation (3, 4), we have:
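Counting each scalar-vector multiplication and each vector addition or subtraction once (and treating the gradient evaluation itself as given), one way to arrive at the count is:

v_{t+1} = μ v_t + ∇f(θ_t − ε μ v_t): two multiplications and one subtraction inside the gradient argument, plus one multiplication and one addition outside (five operations)
θ_{t+1} = θ_t − ε v_{t+1}: one multiplication and one subtraction (two operations)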

Here, there are a total of seven operations. For the implemented formulation (8, 9), we have:
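With the same counting convention:

v_{t+1} = μ v_t + g_{t+1}: one multiplication and one addition (two operations)
θ’_{t+1} = θ’_t − ε (g_{t+1} + μ v_{t+1}): two multiplications, one addition, and one subtraction (four operations)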

Here, there are a total of six operations. The second occurrence of the gradient in the PyTorch implementation simply reuses the saved result of the first gradient computation, so only one gradient computation is performed at each time step. So, one apparent benefit is that the PyTorch implementation saves one multiplication operation at each step.

In summary:

  1. The update rules stated in PyTorch’s SGD documentation note (3, 4) place the learning rate differently than the original Nesterov momentum update rules (1, 2). This allows learning rate schedules to have an immediate effect on the overall step size, whereas under the original formulation the effect of a learning rate change would only “trickle” down over subsequent time steps.

  2. The update rules implemented in the PyTorch SGD algorithm (8, 9) are an approximation to the update rules stated in the documentation note (3, 4) after a simple change of variables. Although the “actual” parameters are easily recoverable from the current parameters at each time step, the PyTorch implementation does not make any such correction at the end of the algorithm, so the final parameters technically remain an approximation of the “actual” final parameters.

  3. An apparent benefit of the PyTorch implementation is that it avoids an additional multiplication operation at each time step.

  1. “SGD.” SGD — PyTorch 2.0 Documentation, pytorch.org/docs/stable/generated/torch.optim.SGD.html. Accessed 2 Sept. 2023.

  2. Sutskever, Ilya, et al. “On the importance of initialization and momentum in deep learning.” International Conference on Machine Learning. PMLR, 2013.

  3. Danihelka, Ivo. “Nesterov’s Momentum Made Simple.” 25 Aug. 2012.

  4. Chintala, Soumith. “nesterov momentum is wrong in sgd · Issue #27 · torch/optim.” GitHub, 13 Oct. 2014, github.com/torch/optim/issues/27.

  5. Gross, Sam. “Add a note in the docs about the momentum formulation used in optim · Issue #1099 · pytorch/pytorch.” GitHub, 25 Mar. 2017, github.com/pytorch/pytorch/issues/1099#issuecomment-289190614.

  6. Zhao, Yilong. “Fix Nesterov Momentum Bug · Pull Request #5920 · pytorch/pytorch.” GitHub, 21 Mar. 2018, https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908.