Scaling Law
偷学了一手 scaling law,遂记录一下
Critical Batch Size
原文链接:An Empirical Model of Large-Batch Training
在我们使用小 batch size 去计算梯度的时候,实际上可以视为对整体梯度的无偏估计,但存在一个噪声。batch size 越大,则噪声方差越小。然而,这种增长可能是有限的 —— 当 batch size 已经足够大时,增加 batch size 的收益可能较小。因此,有必要确定一个综合考虑下最合适的 batch size
小 batch size(蓝线)会产生更多的噪声,大 batch size (橙线)则相对稳定
对于一个模型 $\theta$,损失函数可以视为 $L(\theta) = \mathbb{E}_{x\sim \rho(x)}[L_x(\theta)]$,梯度估计为
设为无偏估计,则 $\mathbb{E}_{x_1 \cdots B \sim p} [G_{est}(\theta)] = G(\theta)$
而梯度方差(更准确的来说,协方差)则为
当 $ i \neq j$ 时,由于视 $x_i$ 与 $x_j$ 为 i.i.d ,所以
因此当且仅当 $i = j$ 时有贡献
其中 $\Sigma (\theta)$ 为单条样本的梯度协方差。这告诉我们,梯度噪声方差是随 batch size 线性收敛的。
再考虑模型优化过程,讨论简化的SGD的情况。设 $G=\nabla L(\theta)$ 为真实梯度,$H$ 为 Hessian 矩阵,$\eta$ 为学习率,批次梯度为 $g = G_{est}$,做泰勒展开
两边取期望
代入协方差
所以
对学习率求导
如果梯度没有噪声,则 $\Sigma=0$,最优学习率 $\eta_{max} = \frac{||G||^2}{G^THG}$
令 $B_{noise} = \frac{\text{tr}(H\Sigma)}{G^THG}$,则
此时有对 loss 的最优提升:
可以看到,当 $B$ 远小于 $B_{noise}$ 时,提升 $B$ 会得到较大的线性提升;而当 $B$ 远大于 $B_{noise}$ 时,提升 $B$ 对训练的影响会变小。
如果我们用一步 $\text{d}S$ 能取得的最优 loss 下降值为 $\text{d}L$,则如果我们用 batch size 为 $B$,则需要走 $(1 + \frac{B_{noise}(s)}{B(s)})\text{d}S$ 步(注意这里应为当前 step 的函数),处理的样本数为 $\text{d}E = B(s) \text{d}S$,所以
现在,我们面临一个权衡:
增加 $B$ 可以减少步数 $S$
但增加 $B$ 会增加每步的计算量 $E$
我们引入一个 “汇率” $r$ 来量化这个权衡,意为:如果对 batch size 做小改动 $\text{d}B$,为了节省一步 $\text{d}S$ 需要付出多少计算量 $\text{d}E$
注意到,一个最优的策略,汇率 $r$ 应当是一个常数。反之,则必然存在一处汇率高于平均值,一处汇率低于平均值,存在将后者的预算分配给前者的这样一个套利空间。因此,存在理论最优的汇率 $r^*$,使得:
则我们能找到最优的 batch size,随着 $B_{noise}$ 指标动态变化
现在,我们将这一式子代回 $S$ 与 $E$ 的表达式,同时约定:
则代入后有:
所以:
令右式为 $\gamma$,我们能得到一个很漂亮的幂律关系:
Tokens Limited
已知单步最优损失下降为
实际上,我们关注的并不是单步的 $\Delta L_{opt}$,而是总训练过程的 $\int_0^{S_{final}} \Delta L_{opt}(B(S))dS$
显然,这里有一个约束关系,我们考虑总tokens数恒定。
应用拉格朗日乘子法
取极值必要条件是 拉格朗日量 $F = \frac{\Delta L_{max}}{1 + \frac{B_{noise}}{B}} - \lambda B(S)$ 满足 欧拉-拉格朗日 方程
其中
所以
最终的结果是
这里的 $\Delta L_{max}$ 可以视为随训练变化的超参。根据经验,我们知道训练后期的梯度是很小的,而cosine learning rate的话,H 会升高,所以 $\Delta L_{max}$ 会很小。这意味着在训练后期,收益已经很小了,与其增大batch size 减少噪声,不如小batch size 多探索。至少我是这么理解的
H 比较难算,但如果只近似算对角线元素计算复杂度还是和 forward-backward 一样的,可以参考AdaHessian的近似。
OpenAI Scaling Law
原文链接:Scaling Laws for Neural Language Models
在上述推导中,我们得到了一个稍稍有点复杂的幂律关系,然而 OpenAI 却提出上述公式可以简化为:
也即
拟合出来的结果非常准
搜索并拟合不同参数量不同 Loss 下 Emin 和 Smin 的关系
从数学角度理解,这其实是 Cauchy–Schwarz 不等式的一个特例。其积分形式为:对于定义在积分域 $S$ 上的任意两个实函数 $f(s)$ 和 $g(s)$,以下不等式恒成立:
等号成立当且仅当存在常数 $c$ 使得 $f(s) = c g(s)$
这意味着,OpenAI 提出的近似要成立,必须满足 $B_{noise}(s) = c$,也即 $\frac{\text{tr}(H\Sigma)}{G^THG}=c$,显然,这是一个较为苛刻的条件。直观上来看,stable 学习率设定下的训练后期可能可以满足这个条件
同时,我们令 $B_{critic} = \frac{E_{min}}{S_{min}}$,可以视作 $B_{noise}(s)$ 的面积 除以 区间长度,即平均 $B_{noise}$。这是一个 Loss 曲线的指标,代表着该任务的平均噪声强度,越大则所需的batch size越大。
OpenAI 对该指标进行了拟合,得到结论
其中 $B_* \sim 2\cdot 10^8 $ tokens,$\alpha_B \sim 0.21$ 是常数
有了平均值的刻画,我们就不需要积分了,直接代入平均值,可以近似定义:
$S_{min}$ 代表着当 $B \rightarrow \infty$ 时,理论最优的 Steps 数,$S$ 可以看作实际 $B$ 对应的
通过这个公式计算得到的 $S_{min}$,OpenAI 拿去拟合了下述的 Power Law 公式
效果很不错,如图所示
搜索并拟合不同参数量下 Loss 和 Smin 的关系
题外话
上述 $B_{critic} = \frac{E_{min}}{S_{min}}$ 的定义是没问题的,但是 OpenAI 拿来拟合 $B_{critic}$ 的值是用公式
得到的,但这个公式是不准的,导致得到的 $B_{critic}(L) \approx \frac{B_*}{L^{\frac{1}{\alpha_B}}}$ 也不准,虽然实践中效果确实不差。在我看来,这些拟合的公式都是经验性的,量纲都对不齐,所以没事不要想着代来代去,除非很确定其中每一步的约束与近似。
还有,大部分推导都是在 SGD 上做的,这就已经和 Adam 有很大的区别了,不必奉为圭臬。当然,有个参考还是很好的。
Multi-Power Law
写不动了,这篇我也只是听了一下讲解粗看了一下,随便记一点(((
上述公式刻画了 Loss 与 N,S,B,C,但没有刻画与 learning rate 的关系。下面一篇就填补了这个空缺,并且实践效果比较好。
假设 $t$ 步时,使用 stable 策略的等效步数为 $Z(t) = \frac{S(t)}{\eta_0}, 其中 S(t) = \sum_{\tau=0}^t \eta_{\tau}$
作者将 loss 拆解为了两个部分:
其中后一项差值定义为 $LD(t) := L_{const}(Z(t)) - L(t)$,而前一项则根据过去经验定义为 $L_{const}(Z(t)) = L_0 + A \cdot (S(t) + S_W)^{-\alpha}$,$S_W$ 为 warmup 阶段的学习率之和
关于第二项,可以认为是学习率衰减造成的偏差,作者经过简单的拟合,发现线性拟合就还不错
不过,这个还不准,所以考虑了另一种分析框架。我们定义一些辅助轨迹
$L_0(t)$,用恒定学习率 $\eta_0$ 一直训练下去
$L_1(t)$,第1步用 $\eta_0$,之后用恒定学习率 $\eta_1$ 一直训练下去
$L_2(t)$,用 $[\eta_0, \eta_1]$,随后用恒定学习率 $\eta_2$ 一直训练下去
我们定义 $S_k(t) = \sum_{\tau=k}^t \eta_{\tau}$,则 $L_k$ 的等效学习步数 $t_k$ 满足
对于 $k$ 和 $k+1$ 过程,中间 Loss Reduction 为:
则通过差分得到最终的 $LD$ 项为
作者做了实验发现 $LD$ 会先上升然后到达一个界限,所以猜了一个幂律关系去拟合
进一步分析上述的 $\tilde{B}$,$\tilde{C}$,有 $\tilde{B} = B(\eta_A - \eta_B)$,$\tilde{C} = C \eta_B^{-\gamma}$,所以
最终拟合得到的一个有关 learning rate 的 scaling law
总的来说,MPL的拟合实际试下来比较准的,但它只能跨learning rate,所以应用范围比较有限。
写不动了,MPL 有机会再补充吧