Part I of Mathematical Structure of Mamba - Hippo
本篇是mamba系列blog的第一篇文章
Part I of Mathematical Structure of Mamba - Hippo
剩余预计还有一篇文章正在生产中~
序言
考虑一个很长的一维序列,当我们希望模型具有“记忆”能力的时候,我们实际上期望模型能在当前时间步对很久以前的时间步上的数据点具有无损恢复的能力。然而,模型的大小是不可能随序列长度增长的,也就是说我们需要使用有限的参数恢复出无限多时间步的数据点。
很显然无损恢复是不可能的。但是,我们可以用一种具有渐进收敛性的方法,尽可能减少这种损失。我们面临的第一个问题就是,怎样描述这个损失?一个非常自然的想法是直接把我们恢复出来的数据点和真实的数据点的距离求 $L_2$ 范数。Naive的 $L_2$ 范数假设了所有历史数据点是同等重要的。
到这里,我们起码有了一种最简单的求损失的方法。如果我们把所有真实数据点看作对时间步的函数,那么我们上面求损失的过程正是函数逼近的过程,即我们不知道真实函数的表达式,但我们获取了它的若干采样数据点,我们可以依赖这些数据点,选取一个已知表达式的函数来逼近它。这个用于逼近的函数包含有限多的待优化的参数,而参数的数量不随序列长度变化。
记忆问题转化为了函数逼近问题
数学框架
考虑定义在 $t>0$ 上的一维连续函数 $f(t)\in \mathbb{R}$,则 $t$ 时刻之前的历史可以表述为
对于每个 $t$ 时刻,我们希望找到一个已知表达式的函数 $g^{(t)}\in \text{span\{}\mathcal{G}\}$ 来逼近历史 $f_{\leq t}$,其中 $\mathcal{G}$ 是我们所有可选择函数构成的函数族。
在给定概率测度 $M^{(t)}$ ,对应概率密度 $\mu$ 满足 $\int_{-\infty}^{t}\mu^{(t)}(x)\text{d}x = 1$ 的情况下,两个函数的内积可以表示为:
同时,带概率测度的 $L_2$ 范数变为
要进行函数逼近,我们首先要选取一组基函数来构建函数族。最常见的选择是使用多项式基,来对概率测度 $M^{(t)}$ 构建一组 正交多项式 $\mathcal{G}=\{g_n\}_{n<N}$,满足 $\forall i,j \in [N], \langle g_i, g_j \rangle _{\mu^{(t)}}=\int_0^{\infty}g_i(x)g_j(x)\mu^{(t)}(x)\text{d}x=0$,来找到一个 $g^{(t)}=\sum_{k=0}^{N-1}c_k(t)g_k$ 去最小化在测度 $M^{(t)}$ 下与 f(t) 之间的损失,即:
对于上式, $f_{\leq t}$ 可以直接观测得到,$\mathcal{G}$ 可以通过构造得到(勒让德多项式、切比雪夫多项式、雅可比多项式等等),$\mu^{(t)}$可以直接设定(例如设定为所有历史同等重要,或只关注此时刻往前固定窗口大小的历史),那么我们唯一要优化的就是一组系数 $[c_k(t)]_{k\in[N]} \in \mathbb{R}^{N}$
后文中介绍了如何在测度加权的情况下推导出正交多项式
在某一时刻 $t$ 的一组参数 $c_k(t)$ ,我们可以简单通过最小二乘法求得
对 $c_i$ 求导,令导数为0
由于 $\int g_k^{(t)}(x)g_i^{(t)}(x)\mu^{(t)}(x)\text{d}x=1$ 当且仅当 $k=i$,所以
也即 $c_i(t)=\langle f_{\leq t}, g_i^{(t)} \rangle_M = \int f_{\leq t}(x)g_i^{(t)}(x)\mu^{(t)}(x)\text{d}x$
数学上,我们可以简单计算得到。但考虑时间复杂度,每次计算都需要遍历所有的历史求出 $f_{\leq t}$,$n$ 次操作的总复杂度依然是 $O(n^2)$ ,无法接受。
可以看到 $c(t)$ 是关于时间 $t$ 的函数,考虑到相比上个时刻 $T-\delta$ 仅有一个增量数据点 $f(T)$,自然考虑 $c(t)$ 的求解是否存在递推形式,也即关注 $c(t-\delta)$,$c(t)$,$f_{\leq t}$ 三者之间的关系。在连续情况下当 $\delta \rightarrow 0$ 时,我们关心的是 $\frac{d}{dt}c(t)$,$c(t)$,$f_{\leq t}$ 三者之间的关系,也就是求解一个关于 $c(t)$ 的常微分方程(ODE):
数学推导
根据前面的数学框架,我们需要找到一组正交多项式 $\{g_n\}_{n<N}$ 和测度 $M^{(t)}$,并求解常微分方程 $\frac{d}{dt}c(t)=u(t, c(t), f_{\leq t})$。选取不同的多项式和测度,我们会得到不同的ODE方程和解,此处仅介绍 HiPPO 框架中最成功的一种变体:HiPPO-LegS(Scaled Legendre,LegS),它使用勒让德多项式作为正交多项式,并选取均匀重要性的测度。
概率测度
HiPPO论文中提出的三种概率测度,从左往右分别为:LegT(将滑动窗口内的历史视为同等重要)、LagT(历史的重要性指数衰减)、LegS(将所有历史视为同等重要)
在某个时刻 $t$,将此前的所有历史视为同等重要,我们得到HiPPO-LegS的测度:
正交多项式
为了获得正交多项式,一般的做法是对一组函数进行Gram-Schmidt正交化。当我们对 $1, x, x^2, \cdots$ 进行正交化后,得到的就是勒让德多项式:$P_0(x), P_1(x), P_2(x), \cdots$。
标准勒让德多项式有一个性质,我们不加证明地给出:在区间 $[-1,1]$ 上,两个多项式 $P_n, P_m$ 具有以下关系:
$\delta_{nm}$ 表示当 $n \neq m$ 时等于 0,因为正交嘛;当 $n=m$ 时等于1
由于该性质定义在区间 $[-1,1]$ 上,我们进行变量代换 $x=\frac{2x’}{t}-1$ 放缩到我们关心的 $[0,t]$ 区间,代换后的公式为:
两个多项式在区间 $[0,t]$ 上积分等价于在给定测度 $\mu^{(t)}(x)=\frac{1_{[0,t]}}{t}$ 下在整个数轴上的积分,上式写为:
我们仅考虑 $n = m$ 的情况:
令 $g_n^{(t)}(x)=(2n+1)^{\frac{1}{2}}P_n(\frac{2x}{t}-1)$,我们有:
所以,在测度 $M^{(t)}$ 下归一化的勒让德正交多项式为:
怎么证明当 $n \neq m$ 时有 $\int g_n^{(t)}(x) g_m^{(t)}(x) \mu^{(t)}(x) \text{d}x = 0$ ?
常微分方程
要计算 $\frac{d}{dt}c(t)$, 我们需要对 $c(t) \in \mathbb{R}^N$ 的每一维 $c_n(t)$ 求导。根据 $c_n(t):=\langle f_{\leq t}, g_n^{(t)} \rangle_{\mu^{(t)}}$,同时我们希望避免对 $f_{\leq t}$ 进行计算,那么可以借助计算 $\frac{d}{dt}g_n^{(t)}$ 和 $\frac{d}{dt}\mu^{(t)}$ 得到(看了后面推导就知道为什么了)。
在此前的推导中,我们将 $t$ 视为一个固定的时刻而非变量,因此 $t$ 作为上标出现。而在接下来的ODE推导中,时间 $t$ 也是变量,因此下面的记号改为类似 $g(t, x)$,$\mu(t,x)$ 的形式。
$f_{\leq t}$ 是一个已经确定的函数,代表的是我们看到的采样点背后的真实逻辑,并不是 $t$ 的函数。
可以理解为,$f_{\leq t}$ 是通过 $f(1), …, f(t)$ 拟合出来的一个函数,求出这个函数的复杂度是 $O(t)$
补充一下勒让德多项式的递推关系式 $(2n+1)P_n=P_{n+1}’-P_{n-1}’$,以及推论 $(x+1)P_n’(x)=nP_n(x) + (2n-1)P_{n-1}(x) + (2n-3)P_{n-2}(x) + \cdots$,下面是勒让德多项式的表,可以自己验证一下。
标准勒让德多项式
对于 $\frac{\partial}{\partial t}\mu(t,x)$,我们有:
对于 $\frac{\partial}{\partial t}g_n(t,x)$,我们有:
进行变量代换 $z=\frac{2x}{t}-1$ 简化公式,并使用勒让德多项式的推论可以得到:
然后,我们可以来计算 $\frac{d}{dt}c_n(t)$
代入 $\frac{\partial}{\partial t}g_n(t,x)$,$\frac{d}{dt}\mu(t,x)$ 的计算结果,可以得到:
其中 $g_n(t,t) = (2n+1)^{\frac{1}{2}}P_n(1)=(2n+1)^{\frac{1}{2}}$,代入得到:
至此,我们完成了推导,并实现了初始目标,将 $\frac{d}{dt}c(t)$ 表示为了 $c(t)$ 与 $f_{\leq t}$ 的函数,且不依赖对观测函数 $f_{\leq t}$ 的积分,避免了对所有历史的计算。有了 $\frac{d}{dt}c_n(t)$ 的表达式,我们可以将向量 $c(t)$ 表示为矩阵计算形式:
这里的矩阵 $A$ 就是大名鼎鼎的HiPPO Matrix,S4(Structured State SpaceS)中“结构化”指的就是用结构化的矩阵 $A$ 来对 SSM 的参数初始化
HiPPO矩阵初始化代码
补充一下HiPPO矩阵初始化的代码,HiPPO矩阵是一个 $N \times N$ 维的下三角矩阵:
1 | def make_HiPPO(N): |
返回 $-A$ 是为了和后续的SSM状态方程对齐,把负号放到参数内部。
HiPPO 离散化
以上推导都是在假设 $f(t)$ 是关于 $t$ 的连续函数的情况下进行的。通过离散化,我们可以获得针对序列数据的离散ODE。连续函数离散化的方法有很多,例如前向欧拉、后向欧拉、双线性等等。在这里我们介绍形式相对简洁的前向欧拉方法以及离散化后的HiPPO-LegS,需要注意的是在HiPPO论文的实验以及S4中使用的都是双线性离散化方法,因为具有更好的数值稳定性,而mamba则是采用的零阶保持(zero-order hold, ZOH)方法。
已知 $\frac{d}{dt}c(t) = -\frac{1}{t}Ac(t) + \frac{1}{t}Bf(t)$ ,对于前向欧拉,假设采样间隔为 $\Delta t$,我们从时刻 $t$ 向前走一小步得到 $t+\Delta t$,对这一小段在等式两边积分得到:
对于等式右边我们将 $[t,t+\Delta t]$ 内 $\frac{d}{dt}c(t)$ 的值都保持为 $t$ 时刻的值来近似,有:
移项得到:
将 $t_k$, $t_k+\Delta t$ 时刻分别当成离散的 $k$,$k+1$ 时间步,即 $t_k=k \cdot \Delta t$,我们得到前向欧拉离散化的HiPPO-LegS:
双线性离散化就是 $c(t+\Delta t)-c(t)=\frac{\Delta t }{2}(\frac{dc(t)}{dt}+ \frac{dc(t+\Delta t)}{dt})$
在S4D部分会介绍零阶保持方法。
Reference
组会内部分享报告 《A Gentle Introduction to HiPPO🦛 and its Friends》