Muon 与 AdamW 的对比

AdamW

AdamW 我们都很熟悉

若我们仅考虑第一步,则 $\theta_1 = \theta_0 -\eta_t (\frac{G_t}{|G_t| + \epsilon} + \lambda_t \theta_0) \approx \theta_0 -\eta_t [\text{sign}(G_t) + \lambda_t \theta_0]$,只不过后续随着动量累积,行为逐渐变得复杂。

AdamW 的自适应体现在,它用二阶矩估计来自适应地调整更新步长。它不关心梯度的方向,只是对于梯度较大的步骤,它会强制减小更新幅度;反之它会鼓励较小的梯度更新走的更长。虽然有点玄妙,但已经经过千锤百炼证明稳定性了。

Lion

Lion 感觉像是 Alpha Evolve 的雏形。给定 AdamW 的过程,指定几种可能的操作(删除、修改、增加),然后遗传算法开搜。

当然这不是重点。重要的是 Lion 指出,不需要额外的二阶矩估计来自适应更新幅度,只需要sign归一化即可。

原论文搜出来的结果是 $\beta_1 = 0.9, \beta_2 = 0.99$,在NLP任务上是 $\beta_1 = 0.95, \beta_2 = 0.98$

经过sign归一化后,$u$ 每个分量的更新绝对值都是1。为了和AdamW计算的尺度对齐(通常比AdamW的更新值大10倍),因此学习率要缩小10倍以上;为了保持权重衰减的幅度不变,权重衰减就要放大相应的倍数。

AdamW 上,预训练1B模型时的学习率可以取到 $5 \times 10^{-3}$ 左右,在 Lion 上衰减一下就到 $10^{-4}$ 级别;而 weight decay 也需要相应放大10倍。

Tiger

苏神同时做了 $\beta_1 = \beta_2$ 的实验,虽然效果不如Lion,但是尚在可接受范围内。同时发现可以通过简单改动,节省下梯度累积时需要存储历史梯度的开销。感觉比较trivial。

理解 sign

Lion通过sign 操作平等地对待了每一个分量,使得模型充分地发挥了每一个分量的作用,从而有更好的泛化性能。如果是SGD,那么更新的大小正比于它的梯度,然而有些分量梯度小,可能仅仅是因为它没初始化好,而并非它不重要,所以 Lion 的sign 操作算是为每个参数都提供了“恢复活力”甚至“再创辉煌”的机会。

这里苏神指出,虽然在训练开始阶段考虑泛化比较合理,但如果一个参数的梯度长期较小,那似乎确实可能说明这个参数作用不大。可能可以用别的策略来更多地鼓励重要参数的更新,适当增加其更新幅度。

Muon

我读下来,感觉 Muon 最大的特点是,它的归一化不再是element wise的。这样,Lion中遇到的“躺平”参数问题就可以自然而然地解决了,天然地会把更新量分给别的重要参数。

对于矩阵参数 $W \in \mathbb{R}^{n \times m}$,更新公式如下

其中,我们设动量矩阵的SVD分解为 $\text{SVD}(M) = U, \Sigma, V^T$,则 $\text{msign}(M) = U_{[:, :r]}V_{[:, :r]}^T$

易知 $U \in \mathbb{R}^{n \times n}, \Sigma \in \mathbb{R}^{n \times m}, V \in \mathbb{R}^{m \times m}$,$r$ 是矩阵的秩

也就是说,我们是对动量矩阵 $M_t$ 做奇异值分解,用sign函数来归一化 $M_t$ 的奇异值,在矩阵层面做 sign 归一化。

Muon的计算

Newton-schulz迭代

SVD太复杂,我们需要推导其它的 Muon 的计算形式

计算复杂度偏高的一步就是求矩阵指数 $(M^TM)^{-\frac{1}{2}}$,我们考虑在 $M^TM = I$ 处泰勒展开 $(M^TM)^{-\frac{1}{2}}$。考虑标量函数 $t^{-\frac{1}{2}}$

我们保留到二阶,结果是 $t^{-\frac{1}{2}} \approx \frac{15}{8} - \frac{5}{4}t + \frac{3}{8} t^2$,代入矩阵 $M^TM$

这里我们可以不断把计算得到的 $\text{msign}(M)$ 代入计算,得到更好的 $\text{msign}(M)$

为什么可以迭代呢?因为 $\text{msign}(M)$ 是一个不动点!已知 $M^TM = I$,代入右式后结果就是 $M$

稍微深入想一下,对于任意的 $F(M) \rightarrow I$ 的问题,我们都可以通过在 $F(M) = I$ 处泰勒展开来得到一个 $n$ 阶的迭代式,只要起始点不太远就大概率能收敛。真无敌了。

网络搜索

我们不从泰勒展开出发,然是直接将 $M_{t+1} = aM_t + bM_t(M_t^TM_t) + cM_t(M_t^TM_t)^2$ 作为一个优化问题,去求解 $a,b,c$

令 $M_0 = \frac{M}{||M||_F}$,这样不改变 SVD 后得到的 $U,V$,同时可以让 $X_0$ 的所有奇异值在 $[0, 1]$ 之间,更加稳定。

更新公式可以表示为:

显然,中间是对角阵,我们其实只是在迭代这个对角阵。并且对角阵的幂只是各个对角线元素各自求幂,因此问题可以简化成单个奇异值的迭代。我们只需要输入 初始奇异值 $\sigma$,迭代次数 $T$,在迭代 $g(x) = ax + bx^3 + cx^5$ $T$ 次后拟合最终结果到1即可

重参数化

在 $a,b,c$ 的初始值选择上有一个小技巧,重参数化 $g(x) = ax + bx^3 + cx^5 = x + kx(x^2-x_1^2)(x^2-x_2^2)$

这样的好处是,可以直观的表示出了迭代的5个不动点 $0, \pm x_1 \pm x_2$,选择 $x_1 < 1, x_2 > 1$,可以让迭代保持在 1 附近。用MSE作损失函数训练即可。

用随机矩阵作为训练集,是否合理呢?

从范数的视角理解 Muon

理解 SGD

接下来我们从邻近梯度下降(Proximal Gradient Descent)出发,从更广义的角度理解梯度下降法。为了简便我们以向量为例。

邻近梯度下降的公式是:

可以直观地理解,一方面我们希望损失 $\mathcal{L}(w)$ 最低;另一方面,我们又不希望 $w_{t+1}$ 离 $w_t$ 太远,以免崩溃。$\eta$ 就是一个调节这个探索范围大小的参数。

如果 $\eta$ 足够小,可以认为 $\Delta w = w_{t+1} - w_t$ 是很小的,因此我们可以对 $\mathcal{L}(w)$ 做泰勒展开:

其中,$\mathcal{L}(w_t)$ 是常量,不影响结果,可以省去。令 $g_t = \nabla_{w_t}\mathcal{L}(w_t)$,我们有:

为了求出这个argmin,我们对 $\Delta w$ 求导。如果 $||\cdot||$ 是 L2 范数,则 $||w||^2 = w^Tw$,有

解得最终的结果竟然就是梯度下降

这一推导告诉我们,梯度下降就可以视为 学习率加权的L2-范数约束下将损失函数近似为简单函数 的邻近梯度下降

Muon 的范数本质

回到矩阵 $W \in \mathbb{R}^{m \times n}$,更新规则应为:

我们选择 Frobenius 内积,$\langle G_t, \Delta W \rangle_F = \text{Tr}(G_t^T\Delta W)$

将 $\Delta W$ 解耦为范数和方向,设 $\gamma = ||\Delta W||$, $\Phi = - \frac{\Delta W}{||\Delta W||}$,我们得到:

如果选择 F 范数,则和把矩阵展平后的 L2 范数一样,也就是我们在向量形式中推导的梯度下降

如果选择谱范数,则有

设 $G$ 的SVD分解为 $ U \Sigma V^T = \sum_{i=1}^r \sigma_i u_i v_i^T$,我们有

由于约束 $||\Phi||_2=1$,有 $||\Phi v_i||_2 \leq ||v_i||_2 = 1$,于是 $u_i^T \Phi v_i \leq 1$,所以

等号在所有 $u_i^T \Phi v_i$ 都等于1时取到,此时

简单代回,我们知道:

注意这不是原来 求min 的解,只是由约束条件推出的,它能告诉我们的是,Muon就是谱范数约束下的最速下降方向。具体的步长由学习率调控。

而 谱范数 的约束强于 F范数,恒成立 $||\Phi||_2 \leq ||\Phi||_F$,这可能是暗示 Muon 优越性的一个点。

Adam参数迁移

在搜寻 Muon 超参数时,苏神介绍了一种直接从已经调好的AdamW参数迁移的方法。(Kimi应该还是搜了一遍的)

具体而言,他们观察到 Adam 更新量的 RMS (Root Mean Square) 较为稳定,通常在 0.2 ~ 0.4 之间。RMS 定义为:

因此,希望将Muon的更新量也对齐到这个范围,最终取为0.2

这样,可以复用 Adam 的参数,令 $\eta_t : = \frac{0.2\eta_t}{\text{RMS}(M_t)}$ 即可

更进一步,Muon的动量的RMS值是可以显式算出来的:

考虑到随机矩阵严格低秩的概率比较小,这里可以认为 $r = \min(n, m)$,从而有 $\text{RMS}(M) = \sqrt{\frac{1}{\max(n,m)}}$,最终的更新公式就是:

从稳定秩的角度考虑,$M = UV^T$ 天然就是满秩,所以取 $r = \min(n, m)$ 完全没问题

比较有意思的一个点就是,这个更新公式也指出了,不同形状的参数需要有不同的学习率。

参考文献

Muon优化器赏析:从向量到矩阵的本质跨越

Muon续集:为什么我们选择尝试Muon?