Posted on October 15, 2023

Deep dive into the edge of stability

There is a fascinating phenomenon that occurs at parties: initially, when there are only a few guests, the volume of the room starts out low; over time, as more and more people arrive and more and more beers are emptied, people start to talk louder and louder, until the volume of the room stabilizes at a steady state corresponding to the loudest volume at which people can comfortably speak, and it tends to stay at this level until people get too drunk, tired, or hoarse and go home. This phenomenon has a number of intriguing properties: the steady-state volume is almost always a bit louder than anyone would prefer, and yet this inexorable rise in the ambient noise level occurs nonetheless. The phenomenon is also self-stabilizing, in that there is a maximum value above which any deviations are sharply penalized: barring nightclubs, most venues that require screaming in your conversation partner’s ear will see a great exodus of attendees, which has the effect of reducing the volume back to the maximum tolerable level.

In this blog post, I will make the argument that the noisy cocktail party is a strikingly accurate metaphor for gradient descent in neural networks. We will see how the sharpness of the local loss landscape evolves in networks trained with gradient descent, exhibiting a similar trend of increasing and then stabilizing at an analogous maximum tolerable value, determined by the step size used by the optimization algorithm. This phenomenon, termed the edge of stability by Cohen et al., has the potential to give us insight into generalization and optimization dynamics in neural networks, and to explain a few surprising prior observations connecting learning rates and flat minima. This has unsurprisingly made it a trendy topic among theorists who want to prove something useful for deep learning practitioners, so there’s a lot of exciting work to cover.

Let’s start with the basics: what exactly is the “edge of stability”?

The basics

The edge of stability is a regime of neural network training in which the curvature of the local loss landscape sits at roughly two divided by the step size. In the case of gradient descent with learning rate \(\eta\), this means that \(\lambda_{\max}(\mathcal{H}) \approx \frac{2}{\eta}\), where \(\mathcal{H}\) denotes the Hessian of the loss \(\ell\) with respect to the parameters \(\theta\). The “edge of stability phenomenon” refers to a two-step process observed in neural networks where

  1. the sharpness \(\lambda_{\max}(\mathcal{H})\) progressively increases until
  2. it stabilizes at a value at or just above \(\frac{2}{\eta}\), at which point the loss continues to decrease over optimizer steps, but non-monotonically.

In the case of a quadratic loss function, a curvature smaller than \(\frac{2}{\eta}\) results in stable, consistent progress on the objective; a curvature greater than \(\frac{2}{\eta}\) would cause gradient descent to diverge (and indeed, as Cohen et al. show, it quickly does). The fact that loss functions are not quadratic with respect to neural network parameters means that, rather than diverging, gradient descent just follows a bumpy downward trajectory in practice.
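To make the threshold concrete, here is a minimal Python sketch (my own illustration, not code from the paper) of gradient descent on a one-dimensional quadratic \(\ell(\theta) = \frac{1}{2}\lambda\theta^2\), whose curvature is exactly \(\lambda\): the iterates contract whenever \(\lambda < \frac{2}{\eta}\) and blow up whenever \(\lambda > \frac{2}{\eta}\).

```python
def gd_on_quadratic(curvature, lr, steps=50, theta0=1.0):
    """Gradient descent on l(theta) = 0.5 * curvature * theta**2."""
    theta = theta0
    for _ in range(steps):
        grad = curvature * theta       # exact gradient of the quadratic
        theta = theta - lr * grad      # each step multiplies theta by (1 - lr * curvature)
    return theta

lr = 0.1  # stability threshold is 2 / lr = 20
for curvature in [5.0, 19.0, 21.0]:
    print(f"curvature {curvature:4.1f} -> final theta {gd_on_quadratic(curvature, lr):.3e}")
# curvature < 20 converges toward 0 (oscillating once curvature > 10);
# curvature > 20 diverges geometrically.
```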

Cohen et al. weren’t the first deep learning researchers to notice that the early training dynamics of neural networks look very different from, and can heavily influence, later dynamics: the “break-even point” noted by Jastrzebski and a collection of co-authors with equally unpronounceable names was the closely connected observation that the learning rate used early in training determines the curvature of the region of the loss landscape that gradient descent converges to; similarly, Lewkowycz et al. noted that if the curvature of the loss landscape is too large relative to the step size, gradient descent will sometimes “catapult” to a region of lower curvature after initially being destabilized, rather than completely diverging.

Self-stabilization and implicit regularization

I stumbled upon the edge of stability paper after having read two papers studying the implicit regularization effect of finite-step-size (stochastic) gradient descent on the curvature of the loss landscape. After reading these papers, I was under the impression that increasing the learning rate (or decreasing the batch size, in the case of stochastic GD) led to increased flatness of the local minimum to which optimization converged, because of an implicit regularization effect. Both papers applied a similar approach, starting by approximating finite-step-size gradient descent

\[ \theta_{t+1} = \theta_t - \eta \nabla_\theta \ell(\theta_t)\]

with a gradient flow

\[ \partial_t \tilde{\theta}_t = -\nabla_\theta \ell(\tilde{\theta}_t)\]

then approximating the difference between the gradient flow solution and the exact finite-step-size updates with a second-order Taylor expansion to get

\[ \theta_t - \tilde{\theta}_t \approx -\frac{t^2}{2} \nabla^2_\theta \ell(\theta_0)\, \nabla_\theta \ell(\theta_0) = -\frac{t^2}{4} \nabla_\theta \| \nabla_\theta \ell(\theta_0) \|^2\]

and finally showing that there is another gradient flow \(\hat{\theta}_t\) defined by

\[ \partial_t \hat{\theta}_t = -\nabla_\theta \left( \ell(\hat{\theta}_t) + \lambda \| \nabla_\theta \ell(\hat{\theta}_t)\|^2\right)\] which zeros out the error in this second-order approximation. It turns out that the coefficient \(\lambda\) on the gradient penalty is proportional to the step size of gradient descent.
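As a sanity check on this picture, here is a small numpy sketch (my own, with an arbitrary quadratic loss, and with the penalty coefficient set to \(\lambda = \eta/4\), which is my assumption for the constant of proportionality) comparing finite-step-size gradient descent against numerical integrations of the plain and modified gradient flows; on a quadratic, both the loss gradient and the gradient-norm penalty have closed-form gradients.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
Q = A @ A.T / 10                     # random symmetric PSD matrix; loss = 0.5 * theta^T Q theta
theta0 = rng.standard_normal(5)

eta, n_steps = 0.1, 50               # finite-step-size gradient descent settings
lam = eta / 4                        # assumed penalty coefficient, proportional to the step size

def grad(theta):
    """Gradient of the quadratic loss 0.5 * theta^T Q theta."""
    return Q @ theta

def grad_modified(theta):
    """Gradient of loss + lam * ||grad loss||^2 (closed form for the quadratic)."""
    return Q @ theta + 2 * lam * Q @ (Q @ theta)

# Finite-step gradient descent.
theta_gd = theta0.copy()
for _ in range(n_steps):
    theta_gd = theta_gd - eta * grad(theta_gd)

# Integrate the plain and modified gradient flows with a much smaller step.
def integrate(g, total_time, dt=1e-3):
    theta = theta0.copy()
    for _ in range(int(total_time / dt)):
        theta = theta - dt * g(theta)
    return theta

print("distance to plain flow:   ", np.linalg.norm(theta_gd - integrate(grad, eta * n_steps)))
print("distance to modified flow:", np.linalg.norm(theta_gd - integrate(grad_modified, eta * n_steps)))
# The discrete iterates should track the gradient-penalized flow more closely.
```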

So we basically have one set of papers which say that there’s a natural tendency for finite-step-size gradient descent to regularize the gradient norm, and a second set of papers which say that there’s a tendency for gradient descent to increase the maximum eigenvalue of the Hessian up to some value proportional to the inverse of the step size. Two different notions of sharpness, two different analysis methodologies, but two results that both say the sharpness should be monotone decreasing in the step size. What gives?

I don’t know if there’s necessarily a deep and profound connection between the results described above, but there is a straightforward one. The gradient norm at a local minimum is zero by definition, but if the gradient is changing very quickly near this minimum, then its norm will probably be a lot larger at nearby points. As a result, a large maximal Hessian eigenvalue will end up being correlated with a large gradient norm in a lot of optimization problems. Provided the optimizer step size is too big to converge to the sharp local minimum, it will end up just bouncing around in the region of higher curvature and this correlation will stay in place. So it shouldn’t be surprising that similar results concerning the gradient norm and maximal Hessian eigenvalue exist: they are subtly different but correlated aspects of the optimization process.
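To spell out that intuition with a one-line calculation (my own, not taken from any of the papers above): near a local minimum \(\theta^*\), a first-order expansion of the gradient gives

\[ \nabla_\theta \ell(\theta^* + \delta) \approx \mathcal{H}(\theta^*)\,\delta, \qquad \text{so} \qquad \|\nabla_\theta \ell(\theta^* + c\, u_{\max})\| \approx \lambda_{\max}(\mathcal{H})\,|c|, \]

where \(u_{\max}\) is the top eigenvector of the Hessian. At a fixed distance from the minimum, the gradient norm an optimizer sees along that direction scales linearly with the sharpness.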

Why does sharpness increase in the first place?

This is something that I don’t think we yet have a great answer to. Most theoretical analyses I’ve seen of this phenomenon focus principally on the behaviour of optimization after the sharpness has already increased, and make some assumption that the gradients early in training have an inherent bias towards increasing the sharpness of the Hessian. In some cases, weaker assumptions are able to produce a similar end result: for example, Wang et al. assume that the Hessian sharpness is proportional to the norm of the network’s output weights, and then show that the output weight norm increases, so the Hessian sharpness must increase as well. This seems consistent with similar observations I and others have made that the parameter norm tends to increase over time without explicit regularization, which would, all else being equal, increase the maximum eigenvalue of the Hessian. It might be that the increase in the spectral norm of the Hessian is actually most easily explained by an entropy maximization argument: for example, it might simply be that the volume of parameter space with a given sharpness increases with the sharpness value, and so even a random walk would be expected to increase the Hessian norm over time.
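A rough back-of-the-envelope way to see why the output weight norm is a plausible proxy for sharpness (my own sketch, not the precise assumption Wang et al. work with): for a scalar network \(f(x) = v^\top \phi(Wx)\) with squared-error loss, the Gauss–Newton part of the Hessian with respect to the hidden weights is built from outer products of

\[ \frac{\partial f}{\partial W_{ij}} = v_i\, \phi'(w_i^\top x)\, x_j, \]

so every entry of that block scales with the output weights \(v\), and its top eigenvalue grows roughly like \(\|v\|^2\). If \(\|v\|\) drifts upward during training, so does this contribution to the sharpness.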

Agarwala et al. provide the most compelling explanation for progressive sharpening that I’ve been able to find. This paper explicitly uses the connection between the Hessian and the gradient norm, and also uses the trick of only looking at every other step of gradient descent so that they can ignore oscillations in the direction of maximal sharpness. They look at just about the simplest nonlinear learning problem imaginable: fitting a quadratic form to a scalar target, \(\ell(\theta) = \frac{1}{2}\left(\theta^\top Q \theta - E\right)^2\), where \(Q \in \mathbb{R}^{P \times P}\) is symmetric (so it only has real eigenvalues). They were able to show that edge of stability dynamics can appear in this model if the geometry of the loss, i.e. the eigenvalues of \(Q\), satisfies certain properties and the learning rate lies in a particular range. Depending on the eigenvalue distribution, they could either obtain edge of stability dynamics or recover the catapult mechanism, which I thought was interesting. However, all of this comes with the massive caveat that this is a tiny, tiny system that is missing a lot of the nonlinear complexities of neural networks.

To get closer to the real world, they apply a limiting argument to a higher-dimensional version of this model to show that, under some assumptions, for sufficiently high-dimensional, randomly generated quadratic problems, the expected value of the second time derivative of the NTK’s spectral norm is positive and equal to the variance of the loss, while its first derivative is zero. In other words: at least at initialization, the sharpness of the loss landscape increases under gradient flow dynamics. The paper also looks at quadratic approximations of neural network training, but these looked like they produced very different dynamics from the real optimization trajectories. So overall, it’s interesting that the dynamics of gradient descent seem to push up the sharpness, but I don’t think we can decisively say that this increase happens for the exact same reason in quadratic models and in neural networks.

Why doesn’t sharpness increase indefinitely?

Although I haven’t found an entirely satisfying explanation for why sharpness in neural networks increases in the first place, if we assume that it does then there are a number of really interesting theoretical results trying to explain why it levels off at some large but finite value rather than diverging. Damian et al. identify a self-stabilizing property of gradient descent dynamics which, although it involves a pretty hairy derivation using a third-order Taylor expansion of the optimization dynamics that takes immense mental fortitude to follow, provides what I think are some useful insights into gradient descent dynamics. The key identity in this analysis is that the gradient of the sharpness, \(\nabla S(\theta)\), is exactly the third-order term you get from Taylor-expanding the gradient along the principal eigenvector of the Hessian. So what you get in the end is a self-stabilizing process: if the parameters diverge too far in the direction of the top eigenvector of the Hessian, then eventually the third-order term will outweigh the lower-order terms in the expansion and drive the sharpness down.
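Roughly speaking (and glossing over the careful bookkeeping in the actual paper), the mechanism can be sketched in one line. Writing \(u\) for the top eigenvector of the Hessian and displacing the parameters by \(c\,u\), a third-order expansion gives a gradient of approximately

\[ \nabla_\theta \ell(\theta + c\,u) \approx \nabla_\theta \ell(\theta) + c\,\lambda_{\max}(\mathcal{H})\, u + \frac{c^2}{2} \nabla^3 \ell(\theta)[u, u] = \nabla_\theta \ell(\theta) + c\,\lambda_{\max}(\mathcal{H})\, u + \frac{c^2}{2} \nabla S(\theta), \]

so once the oscillation amplitude \(c\) along \(u\) becomes large, each gradient descent step picks up a correction of size \(\frac{\eta c^2}{2}\) pointing down the gradient of the sharpness, i.e. a push toward flatter regions.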

This analysis results in a concise 4-phase characterization of learning, similar to those proposed by Li et al. in their work that uses the output weight norm as a proxy for sharpness. In phase 1, sharpness monotonically increases (this is the least satisfying piece of the analysis – they just assume that whatever gradient descent does normally biases towards increasing sharpness) until it reaches \(\frac{2}{\eta}\). In phase 2, sharpness passes the magical \(\frac{2}{\eta}\) threshold and chaos ensues, with the parameters diverging in the direction of the top eigenvector of the Hessian. In phase 3, this seemingly-uncontrolled growth becomes large enough that the third-order term in the Taylor expansion of the gradient descent dynamics starts to kick in, and this drives the sharpness back down. Finally, phase 4 denotes the return to stability, as sharpness drops below \(\frac{2}{\eta}\) and the now-stabilized gradient updates have the effect of reducing the norm of the parameters.

This is a simplified analysis that obviously has to make a lot of assumptions about the structure of the loss landscape and the effect of higher-order terms in the learning dynamics. However, the authors do compare their predicted dynamics to what is actually obtained by not-quite-trivial neural networks (3-layer CNNs + MLPs and a 2-layer transformer) and get reasonable concordance between prediction and observation. I assume this choice of depth was partially due to the observation that third-order effects drive self-stabilization, and you need this many layers to observe nontrivial higher-order effects. This assumption is borne out by Zhu et al., who consider what is probably the simplest learning problem in which anyone will ever manage to induce edge of stability dynamics: \(\ell(x, y) = \frac{1}{2}(1 - x^2y^2)^2\). I definitely recommend giving the paper a look – it has great visualizations and the theoretical results involve hilariously large constants. If you’re busy, the TL;DR is that they are able to give a precise characterization of the dynamics of this learning problem, and show that it exhibits edge-of-stability behaviour under suitable initializations and step sizes.
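For the curious, here is a quick numpy sketch (mine, not the authors’ code) that runs gradient descent on this two-parameter loss and tracks the top Hessian eigenvalue alongside the \(\frac{2}{\eta}\) threshold. The particular step size and initialization below are my own guess at a regime where the sharpness climbs towards \(\frac{2}{\eta}\) and then hovers around it, so expect to fiddle with them.

```python
import numpy as np

def loss(p):
    x, y = p
    return 0.5 * (1 - x**2 * y**2)**2

def grad(p):
    x, y = p
    r = 1 - x**2 * y**2                          # residual
    return np.array([-2 * x * y**2 * r, -2 * x**2 * y * r])

def sharpness(p, eps=1e-4):
    """Top Hessian eigenvalue, via central differences of the analytic gradient."""
    H = np.zeros((2, 2))
    for i in range(2):
        e = np.zeros(2)
        e[i] = eps
        H[:, i] = (grad(p + e) - grad(p - e)) / (2 * eps)
    H = 0.5 * (H + H.T)                          # symmetrize away finite-difference noise
    return np.linalg.eigvalsh(H)[-1]

eta = 0.05                                        # step size, so 2 / eta = 40
p = np.array([3.3, 0.01])                         # imbalanced initialization
for step in range(401):
    if step % 50 == 0:
        print(f"step {step:3d}  loss {loss(p):.4f}  "
              f"sharpness {sharpness(p):6.1f}  (2/eta = {2 / eta:.0f})")
    p = p - eta * grad(p)
```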

Extensions

Obviously, understanding the behaviour of gradient descent is relevant for deep learning researchers, since all of our algorithms are basically variants of SGD. However, few people use vanilla, fixed-step-size SGD as it is normally taught in an intro to ML course – and we definitely don’t use full-batch GD. So a big question that I, and I think a lot of other people, had after reading the original edge of stability paper was whether something analogous would happen with adaptive-step-size optimizers. Lo and behold, in 2022 Cohen (with a different et al. this time) showed that you can get an analogous result for Adam. The spirit of the finding is the same: sharpness increases until it breaks your optimizer, and then it stabilizes. However, analyzing it requires a bit more work because of all of the bells and whistles that get thrown into adaptive optimizers.
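To give a flavour of the extra bookkeeping involved (this is my own rough sketch of the kind of quantity that matters, not necessarily the exact definition used in the paper): for a preconditioned update of the form \(\theta_{t+1} = \theta_t - \eta P_t^{-1} m_t\), the natural analogue of the sharpness is the top eigenvalue of the preconditioned Hessian \(P_t^{-1}\mathcal{H}\), which for a diagonal Adam-style preconditioner can be estimated like this.

```python
import numpy as np

def preconditioned_sharpness(hessian, v_hat, eps=1e-8):
    """Top eigenvalue of P^{-1} H, where P = diag(sqrt(v_hat) + eps) is
    (roughly) Adam's diagonal preconditioner built from the second-moment
    estimate v_hat. A sketch of one sensible 'adaptive sharpness', not
    necessarily the exact quantity tracked in the paper."""
    p_inv = 1.0 / (np.sqrt(v_hat) + eps)
    # P^{-1} H is not symmetric, but it is similar to the symmetric matrix
    # P^{-1/2} H P^{-1/2}, which has the same (real) eigenvalues.
    sym = np.sqrt(p_inv)[:, None] * hessian * np.sqrt(p_inv)[None, :]
    return np.linalg.eigvalsh(sym)[-1]

# Toy usage with a random stand-in Hessian and second-moment estimate:
rng = np.random.default_rng(0)
A = rng.standard_normal((10, 10))
H = A @ A.T / 10                     # stand-in for a loss Hessian
v_hat = rng.uniform(0.01, 1.0, 10)   # stand-in for Adam's v_hat
print(preconditioned_sharpness(H, v_hat))
```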