Introduction

Greg Yang 大名鼎鼎的 Tensor Program,给出了模型如何高效 feature learning 的一个有效先验与相关理论实践。本篇 blog 介绍一个简化、或者说是更本质的版本,被概括为 Spectral Condition,谱条件。

Spectral Condition

谱条件总共分为两个部分:1)如何设置初始化, 2)如何 scale 学习率

它们的目的分别是控制 $\lVert W \rVert$ 与 $\lVert \Delta W \rVert$ 为 $\Theta(1)$。控制在 $\Theta (1)$ 的先验其实很朴素,太大更新就会炸掉,太小则捕捉不到feature的信息,总之是为了保证模型进行适当的学习。

当然,这里还有很多问题,比如说范数怎么取,之类的,后面会逐步介绍。

Sparse/Dense Vector & Natural Norm

作者区分了两类向量:

Sparse vector

指的是,只有 $\Theta(1)$ 数量的分量非零,比如 one-hot 向量。它的 $\mathcal{l}_2$ 范数本身只有 $\Theta(1)$ 级别。

Dense vector

有 $\Theta(m)$ 数量的分量非零,很多分量都在贡献长度。因此,$\mathcal{l}_2$ 范数满足

Natural Norm

上述的 Dense Vector 就会有一个问题,它的 $\mathcal{l}_2$ 范数大小会随着网络宽度而增长,不能忠实地反映元素的数量级,而我们想要保持的什么 $\Delta W \sim \Theta(1)$ 这种性质都不应和网络宽度耦合在一起。

因此,作者定义了一个自然范数,用于修正这种偏差。即对于 Dense Vector 而言,它的自然 $\mathcal{l}_2$ 范数为:

这其实就是 RMS norm,它只关心每个分量平均有多大。而对于 Sparse Vector 则不用修正

Natural Spectral Norm

在定义了 Natural Norm 之后,我们自然可以诱导一个 Natural Spectral Norm

对于输入输出均为 Dense 的算子而言,有:

这里原文疑似写反了......

而对于 embedding 这种,就是

Initialization

我们希望矩阵的自然谱范数满足:

Linear & lm_head

该条件等价于

对于一个正态分布初始化的 Gaussian 矩阵 $W \in \mathbb{R}^{m \times n}$,若元素标准差为 $\sigma$,通常有:

这个结论我们可以考虑随机矩阵中的 Marchenko-Pastur 定律。对于 $W$ 这样一个均值为 $0$,方差为 $\sigma^2$ 的 i.i.d. 随机矩阵,在 $m, n \rightarrow \infty$ 且 $\gamma = \frac{n}{m}$ 收敛到一个常数的极限下,它的样本协方差矩阵

的特征值分布将依概率收敛于 Marchenko-Pastur 分布,支撑集是:

而我们知道奇异值的平方就是 $W^TW$ 的特征值,因此

实在是一个非常漂亮的结论。当然还有其它解释方法,不过我觉得随机矩阵理论是一个很本质的解释。

而要让 $\sigma(\sqrt{m} + \sqrt{n}) \sim \sqrt{\frac{m}{n}}$,我们有:

这里本来也不是精确的数值关系,原论文干脆搞了一个近似:

大概差一个常数数值,用起来可能差不多吧,很难说。

对于 lm_head 层,我们有

而我们知道 $V \gg d$,所以对于 lm_head,有 $\sigma \sim \frac{1}{\sqrt{d}}$

Embedding

对于 Embedding 层,该条件等价于

然而,这个推导是有问题的,原因在于对于自然谱范数

事实上满足这个 $\lVert x \rVert=1$ 的条件的集合,包含了大量模型永远不会遇到的方向,Embedding 层实际上只会遇到 one-hot 向量而已。所以实际上,我们只需考虑

这里的 $w_i$ 其实就是 Embedding 矩阵中的一列,其中每个元素也有 $w_{ji} \sim \mathcal{N}(0, \sigma^2)$,因此有

所以

Code

实践下来还有一个对于 $\text{c_proj}$ 层的缩放,主要是考虑到残差的叠加问题。还有 embedding 层也是,有说不要把 $\sigma$ 设为 $\Theta(1)$,而是更小一点保证稳定。Anyway,实践牵扯的东西很多,大家可以在自己的setting下试一下看看是否需要这些 trick。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import math

import torch
import torch.nn as nn


def _spectral_linear_std(fan_out: int, fan_in: int) -> float:
""" This is a practical approximation from paper"""
return (1.0 / math.sqrt(fan_in)) * min(1.0, math.sqrt(fan_out / fan_in))


def _spectral_linear_std_noapprox(fan_out: int, fan_in: int) -> float:
""" This is the theoretical spectral bound """
return (1.0 / math.sqrt(fan_in)) * (math.sqrt(fan_out) / (math.sqrt(fan_in) + math.sqrt(fan_out)))


def _spectral_embedding_std_stable(n_embd: int) -> float:
""" This is a practical approximation to keep stable training """
return 1.0 / math.sqrt(n_embd)


def _init_spectral_module(module_name: str, module: nn.Module, config) -> None:
if isinstance(module, nn.Linear):
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
if config.spectral_linear_init_std == "approx":
std = _spectral_linear_std(fan_out, fan_in)
elif config.spectral_linear_init_std == "noapprox":
std = _spectral_linear_std_noapprox(fan_out, fan_in)
else:
raise ValueError(f"Invalid spectral linear init std: {config.spectral_linear_init_std}")

if module_name.endswith("c_proj"):
std = std / math.sqrt(2 * config.n_layer)
torch.nn.init.normal_(
module.weight,
mean=0.0,
std=std,
)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
if config.spectral_embedding_init_std == "stable":
std = _spectral_embedding_std_stable(config.n_embd)
elif config.spectral_embedding_init_std == "exact":
std = 1.0
else:
raise ValueError(f"Invalid spectral embedding init std: {config.spectral_embedding_init_std}")
torch.nn.init.normal_(
module.weight,
mean=0.0,
std=std,
)

Learning rate scaling

TODO

Conclusion

我自己的实验框架应用了 Spectral Condition 后确实拿到了收益;同时也踩了一些坑,比如说 mHC 引入的那些参数,不可一股脑应用 mup。自己试一遍会理解更深刻。

对于一般的线性层,做的是 feature $\rightarrow$ feature 的操作,那自然谱范数就是一个最合适的选择。我觉得一个很重要的思想是,我们应当把矩阵当成一个算子而不是一系列参数来看待,控制算子强度的 谱范数 会比 F范数 这些信息更核心。

对于像 embedding 这种层,我们需要注意其谱范数定义域的差异,即此时输入只有 $e_i$ 而非任意的 $\lVert x \rVert = 1$。不过我其实有点好奇,lmhead 应该和中间的 linear 享受同样的策略吗?因为它的输出其实是某种概率分布,而不是一个feature。控制它的谱范数似乎意义不大,只是会影响 softmax 的温度系数而已,也许这里有更好的控制策略偏置?总之,希望能在理解谱范数本质的基础上考虑去应用 mup,知道哪些 module 该套,哪些module 不该套,以及对于这些并非普通 linear 的 module 如何应用 mup 策略。