本篇博客会解读一下本人最近发表的 《Mousse: Rectifying the Geometry of Muon with Curvature-Aware
Preconditioning
》 这篇工作,并分享一些最近在优化器方面工作的一些心得和理解。内容包含很多主观认识,也很欢迎大家与我交流。

该博客仍在施工中;This is still a Work in Progress (WIP)

Introduction

上学期我的大部分时间都在思考一些有关预训练的问题。虽然在我看来,学界对于预训练的热情正在逐渐褪去(当然只限于 scaling 方面,各种架构反倒是愈加层出不穷,但也从侧面反映了当前模型正在逼近其边界),不过由于我天然的松弛感,我依然在无忧无虑地阅读着各类文献。其中,苏神的博客意外成为了我最喜欢阅读的内容之一,也间接成为了我继续写博客的契机,同时我也自然而然地开始关注到 attention、优化器等领域。特别地,苏神有关 muon 的几篇文章我觉得讲的是很明白的,有很多实际训练的经验以及理论上的认识,我也自己记录了一篇以作拙劣的模仿。感兴趣的读者可以去阅读一下苏神的文章,并在建立对 muon 的初步认识后,再来一起讨论下 Mousse 所做的内容~

Why Muon is Good?

Muon 优化器本质上和 AdamW 等优化器非常不一样。它既非自适应、也完全称不上二阶,仅仅只是在特征方向上将动量一刀切,做谱归一化;硬要说的话,它反而更像 SignSGD 与 Lion 这种强制性的归一化策略。早在我们在学校上优化课时就知道,牛顿法在优化领域属于是“宇宙万法的源头”;而 Muon 意料之外的优秀表现,恰恰代表了大模型优化过程存在两种截然不同的哲学理念:由牛顿法引导的曲率派 与 以归一化主导的归一派。在我看来,后者得以异军突起的根本原因在于,我们对于曲率的预估是有误差的。从 AdamW 到 Shampoo 的升级,可以看成是一种把对 Hessian 的预估从 Diagnol 提升到了 Layer-wise Kronecker Product 。然而即便如此,不论是 EMA 也好,还是 micro batch size 也罢,都注定了这种估计是有误差的,是有损的。当这种误差大到,引入曲率没有提升甚至负提升时,那就还不如直接做一次简单的一刀切,仅仅做简单的正则化,人为设定一种强大的先验来保持训练的稳定。

优化器的公式是简单的,但是其背后的,LLM 真实的 loss landscape 却是无比复杂的。再好的数学性质,也需要与大模型的实际情况相匹配,此即所谓的归纳偏置(inductive bias),适合 LLM 的才是最好的。从结果来看,优化器的归一派无疑是 LLM 乐于接受的先验。也许是因为参数空间存在巨大的冗余,模型既可以在精准导航下找到一处崎岖的山谷,也可以在归一化的先验下,找到一个所谓各向同性的位置。它们可能在 loss 层面上表现出来一样好,但背后却落入了完全不同的损失曲面。也许这也是为什么,用 Muon 优化器训练的模型用 AdamW 续训,或是反过来,效果都不好的原因了: 两者最后收敛邻域的空间结构完全不同。AdamW 也许能停在一处悬崖,如果它的二阶矩能够识别到该方向曲率巨大而自动放缓步长,但此时 Muon 不管不顾地迈出的一步单位步长却可能摧毁模型的表现;而用 Muon 训练的模型,也许天然地就更亲睐一种谱各向同性的流形,自然地往这一子空间去收敛。

当然,除了归一化之外,Muon 最大的进步就是选择在谱空间执行归一化操作(并且引入 NS 迭代降低了计算复杂度),使其成为了一个 矩阵级 的优化器。矩阵级的优化器强于 element-wise 的优化器目前来看应该是一种共识,本质原因可能是因为矩阵算子背后代表着更好的流形约束。

而 Mousse 所做的内容就是,尝试结合 牛顿派(Shampoo) 的曲率调节 与 归一派(Muon) 的谱约束的优点,看看能不能取得更好的效果。

Preliminary

我们先来思考一个简单的问题。我们都知道,梯度下降的更新公式是:

我们用一种容易理解的方式,不妨假设 $\theta$ 的量纲是 秒。而梯度本身的物理意义是,参数空间在各方向上移动一点点,对于 Loss 的影响。我们假设 Loss 的量纲是 米,那么梯度的量纲就是 米/秒

秒 := 秒 - 米/秒,量纲不同的物理量怎么能计算呢?所以里面必然存在一处,连接参数空间与梯度空间的桥梁。如果我们设参数空间为 $\mathcal{M}$,从泛函分析的角度来看,梯度吃进一个变化量 $\text{d} \theta$,然后吐出一个标量 $\text{d} L$,它可以被视为一个泛函,而梯度空间就是切空间 $T_\theta \mathcal{M}$ 的对偶空间 $T_\theta^\ast \mathcal{M}$;或者从微分几何的视角全局来看,梯度空间是参数空间的余切丛 $T^\ast \mathcal{M}$。优化器的核心任务,就是寻找一个映射,把居住在余切丛的梯度,搬运回切丛,变成更新量,即

这种映射自然存在其限制。直观点说就是,需要在 $G$ 的指导下,找到一个 $\Delta W$;而为了保证更新的稳定性,我们又会对 $\Delta W$ 施加范数约束。前者,我们需要考虑的是 $\min_{U \in \mathcal{T}_W\mathcal{W}} \langle G, U \rangle$;后者则是 $\Vert U\Vert \leq 1$。我们分开来讨论。

Natural Pairing

前者这里有一个符号上的概念需要澄清。$G$ 生活在梯度空间 $\mathcal{G}$,而 $U$ 生活在切空间 $\mathcal{T}_W \mathcal{W}$,它们的“内积”怎么定义呢?一种不那么准确的理解方式是,我们把 $\mathcal{G}$ 和 $\mathcal{T}_W \mathcal{W}$ 都嵌入 $\mathbb{R}^N$ 空间,然后再人为选一种内积。但这样会解释不通,比如说没法用内积去诱导 $\mathcal{T}_W \mathcal{W}$ 上的算子范数,很奇怪。另一种更标准的解释是,这里的 $\langle \cdot \,, \cdot \rangle$ 指的是自然配对(Natural Pairing),它是一个更底层、不需要任何几何结构的概念。

具体而言,这里的 $G$ 可以看作是切空间的对偶空间上的一个泛函,它吃进一个切空间的元素 $\text{d}\theta$,吐出来一个标量 $\text{d}L$。这种元素与对偶的泛函天然就可以配对计算:$\langle G, U \rangle := G(U)$。正如我们开头说的量纲的例子,$U$ 的单位是 秒,$G$ 的单位是 米/秒,它们的乘积天然具有物理意义。

而内积,则是在自然配对之上,叠加了人为添加的规则,使得同空间中的任意两个向量可以执行配对操作。比如说 $\mathbb{R}^2$ 空间里的两个元素 $(1,2)$ 和 $(2,3)$,我们将 $(1,2)$ 变成一个对偶向量,比如说转置操作即可,$(1, 2)^T$ 就成为了一个 $\mathbb{R}^2$ 的泛函。$\langle (1, 2), (2, 3) \rangle = (1, 2)^T(2, 3) = 8$。也就是说:

除了转置,还有拉伸、旋转等操作,都可以成为这里构造对偶向量的 Metric。总之,我们在这里引入一点点微分几何的严谨定义。对于矩阵计算而言,我们取 $\langle G, U\rangle = \text{Tr}(G^TU)$

Norm Constraint

后者就是一个范数约束。根据选择范数的不同,最终会导出不同的优化器。如果我们选择 Frobenius 范数,那么就会推出 SGD 的更新公式;如果我们选择 逐元素的 $l_\infty$ 范数,就会导出 Adam 的更新公式。这是一个统一的框架。

Muon 的更新之处在于,它选择了谱范数作为约束。谱范数的计算方式是矩阵的最大奇异值,也即最大拉伸强度,所以我在论文里都使用算子范数 $\Vert\cdot\Vert_{op}$ 来表示了。而谱范数有一个非常优良的性质是 $\Vert\cdot\Vert_{op} \leq \Vert\cdot\Vert_F$,因此它是一个更紧的约束。(至于为什么更紧的约束表现更好,我只有一些直觉的猜测:大抵是模型更新似乎非常存在木桶效应,限制的更紧,不会在“短板”上崩掉,剩下的所有木板的方向才可以一起增大,加快收敛,毕竟本身更新就是一种启发式的操作;这可能也是为什么,Muon prefer larger learning rate —— 当然用对齐 RMS norm 解释可能更准确,前面的只是我的一个猜测。也许 Muon 能忍受的 RMS norm 更大?)

切空间和梯度空间的范数理应也是对偶的。谱范数的对偶是核范数,即所有特征值之和;而 F 范数比较特别,它是自对偶的。在我看来,核范数应该会更倾向于一种低秩的结构,它没办法把很多能量分给噪声方向;而 F 范数的约束就会弱一点,小方向平方后也不剩什么了,主要还是大方向主导。因此,我理解 Muon 倾向的梯度也是更加低秩的。

Muon

总之,上述介绍了一种看待优化器的统一视角,即将优化器视为一种映射操作。Muon 则可以表达为:

标准 Muon 算法

上述问题的解为 $\Delta W_{\text{Muon}} = -\text{msign}(G) = -UV^T$,其中 $G = U\Sigma V^T$ 是梯度的谱分解。它强行把所有的奇异值归一化到了1,这一极其强大的先验虽然加速了模型收敛,但把整个谱空间视为各向同性依然很奇怪。不同的奇异方向,它的崎岖程度显然是不同的,我们应当参考牛顿法,在崎岖的方向放缓步长,而在平坦的方向增大步长。在原作者的博客中,作者称这一步谱正交化的有效性可以归结为“神的怜爱”。当然,我理解这是一种调侃的说法,不过这确实是一种,强大但有效的先验。

“为何选择谱正交化 —— divine benevolence” (https://kellerjordan.github.io/posts/muon/)

Mousse

Mousse 就是针对上一问题的改进。Muon 的约束可以被视为 Stiefel 流形($\Delta W_{\text{Muon}}^T \Delta W_{\text{Muon}} = VU^TUV^T = I$,而我们要做的就是引入曲率:

这个目标可以被视为一个启发式的规则 —— 如果曲率比较大,对应 Hessian 的元素也会比较大,那么相应的 $\Delta W$ 就应该变小;反之亦然。从几何的角度来看,我们是对坐标轴进行了一次坐标变换。既然 Muon 理论上应该应用在各向同性的谱空间上,那我们就预先做一次拉伸,将“椭球”拉成“圆球”后再执行 NS 迭代,最后反向拉伸回去即可。

上述约束我们可以将其视为 $H^{\frac{1}{2}}\text{vec}(\Delta W)$ 在做内积,是一个被 $H^{\frac{1}{2}}$ 白化的空间上的 Stiefel 流形。对 Hessian 的近似,我们选用 Shampoo-style 的层级Kronecker积近似 ($H \approx (R \otimes L)^{\frac{1}{2}}$)。根据夹心公式,我们有:

令 $P = L^{\frac{1}{4}}$, $Q=R^{\frac{1}{4}}$,则现在的优化约束变为:

令 $Y = P\Delta W Q$, 那么 $\Delta W = P^{-1} Y Q^{-1}$,代回原式:

令 $\tilde{G} = P^{-1}GQ^{-1}$,则原优化问题变为:

通过与 muon 的式子做比较,我们发现它们的形式是完全一样的,直接套用结果:

所以我们只是在 Muon 的基础上进行了一次投影与逆投影,在一个更符合 Muon 假设的空间上执行谱归一化。

算法实现对比

Ablation

TODO,可以先看原文的内容~

补充

NS5 of Muon

现在的muon,基本上采用的都是 5 次NS迭代,其中每次都有不同的参数,在 Dion 库中的参数是

1
2
3
4
5
6
7
ns_consts = [
(4.0848, -6.8946, 2.9270),
(3.9505, -6.3029, 2.6377),
(3.7418, -5.5913, 2.3037),
(2.8769, -3.1427, 1.2046),
(2.8366, -3.0525, 1.2012),
]

绘制一下它的缩放性能,会发现实际的 NS 迭代并没有忠实地还原 muon 本身的设计思路。对于过小的奇异值,NS迭代并没有能力将其拉起来。也就是说,muon 优化器实际上还有一个隐含的效果,就是选择性地忽略奇异值较小的“噪声”方向,这是由于工程实践带来的。

有效范围大概在 1e-2 ~ 1

用 dion 库训的一个非 embedding 参数量为 80M 的模型。在实际训练过程中,Mean Singular Value 可能就是 1e-2 ~ 1e-4 左右的水平

Newton-Schulz 迭代前的动量的奇异值的均值

经过NS5迭代后,确实可以看到有些方向是显著偏小的。

Newton-Schulz 迭代后的动量的最小奇异值,有些方向并没有拉到1。不过,其实已经放大一点了,原来可能只有 1e-7。实际上执行 NS 迭代前会进行缩放,所以实际能不能拉到1只和 Condition Number 有关。

总而言之,感觉 NS5 忽略噪声方向也挺合理的,这可能也可以解释为什么ns系数雕花以及用单精度SVD直接算没有什么收益。