本篇是mamba系列blog的第二篇文章,系列文章见:

剩余预计还有一篇文章正在生产中~

State space model

前一篇文章介绍的HiPPO为我们奠定了在离散时间系统下维护模型状态的方法。视 $c_k$ 为模型当前的状态,$f_k$ 为模型当前的输入序列,即可通过公式更新得到模型后续的状态 $c_{k+1}$,这是一个RNN-like的结构。无独有偶,在工程学科中的状态空间模型(State Space Model)也有类似的结构:

直观理解,$A$ 代表了模型当前状态如何影响后续状态,$B$ 代表了模型当前输入如何影响后续状态,$C$ 代表了模型当前状态如何影响模型输出,$D$ 代表了模型当前输入如何影响模型输出。

我们可以对这建立在连续信号上的公式离散化,得到离散数据上的形式:

$\bar{A}, \bar{B}, \bar{C}, \bar{D}$ 的具体形式可以通过不同的离散化方法(前向欧拉、后向欧拉、双线性等)从 $A,B,C,D$ 得到。这种形式的SSM可以在离散数据上计算,并且计算方式是step-by-step的。

例如,对于双线性离散化方法有:

为什么我们需要离散化

个人看法

这个问题很关键,它直接决定了Mamba的改进目标。

试想一下,以文本为例,我们的输入输出都是离散的模态,那所谓的连续究竟是在何处呢?是模型内部的状态演化。上述的一系列以SSM为动力学建模的模型,都假设了模型存在一个连续的状态。但这实际上有保证吗?

考虑 $h(t’) = h(t) + x(t)$,对于离散的冲击 $x(t)$,$h(t)$ 显然就无法保持连续了。所以,需要对该更新方程进行离散化,在离散方程 $h_{t+1} = \bar{A}h_t + \bar{B}x_t$ 上的冲击 $x_t$,其实被隐式地理解为了在区间 $[t, t+1)$内的持续影响。

这一步是否合理有待商榷,毕竟把离散冲击当成持续影响这一步本身就有点抽象。在mamba中,为了解决这一个问题而提出的方法是,构建输入依赖的离散采样率 $\Delta$,通过预测 $\Delta_k$ 动态调整离散化区间的长度:

  • 若 $\Delta_k$ 较小,系统更频繁地响应输入,接近原来的离散冲击。

  • 若 $\Delta_k$ 较大,系统将输入“平滑扩散”到更长区间

所以理论上,它在某些时刻应该是不能保证连续的,只是想了办法让脉冲的影响可控。

Update

经xr哥提醒,开头的token embedding可能是离散的,但是到中间层时,hidden state可能已经比较连续了。我记得S4的续作H3就有结合transformer和SSM,难道就是开头用transformer,中间用SSM?这样就可以减少离散脉冲对理论结构的影响了。所以Jamba这种混合模型,就该是开头transformer,中间SSM?

感觉mamba不同层之间的 $\Delta$ 应该会有显著差异吧

后续要再去钻研下

建模

我们首先忽略 $\bar{D}$ ,这是一个简单的残差连接。

对于离散形式公式做简单推导:

可以看到第 $k$ 步的输出具有通项形式:

这是一个卷积核为序列长度的卷积,可以记作:

因此,State Space Model既可以建模为RNN,也可以建模为CNN。在我的理解中,prefill阶段就可以使用卷积形式,结合离散傅里叶变换DFT,只需要先变换到频域上再逐点相乘即可($\mathcal{F}\{f*g\} = F(f) \cdot F(G)$),时间复杂度为 $O(L\log L)$;后续推理继续用RNN形式自回归地生成,复杂度为 $O(L)$。

这里我考虑得非常不严谨,实际上我自己也没想清楚。时间复杂度瓶颈还可能受到,矩阵乘法的复杂度,对矩阵进行傅里叶变换的复杂度,算子实现等的影响。如果想要严格讨论时间复杂度,可能配合代码分析更加合适。

结构化状态空间

在S4的前置工作中,作者发现随机初始化的SSM性能非常差,但使用HiPPO Matrix对 $A$ 初始化:

就取得了非常大的性能提升。例如,仅仅对矩阵 $A$ 从随机初始化改成HiPPO初始化就在sequential MNIST上把性能从60%提升到了98%,这步可以称为是结构化。

注意这里把之前公式中,外面的负号放进来了。

高效算法

不过,对SSM进行直接计算存在效率上的问题:

  • 对于Recurrent形式的SSM,由于参数可学习,因此在每一步上都需要重新计算 $\bar{A}=(I-\frac{\Delta}{2}A)^{-1}(I+\frac{\Delta}{2}A)$ ,需要进行矩阵-矩阵相乘。如果能找到某种基于矩阵-向量乘积的计算方法,计算量将大幅降低。

  • 对于Convolutional形式的SSM,除了 $DFT$ 的加速,计算瓶颈还在卷积核的计算上。长度为 $L$ 的卷积核 $\bar{K}$ 包含矩阵 $\bar{A}^{L-1}$ ,即HiPPO矩阵 $\bar{A}$ 的 $L-1$ 次幂。需要找到该高阶矩阵幂的快速算法。

对于原始的更新公式:

如果令 $(A,B,C)\sim (V^{-1}AV, V^{-1}B, CV)$,这实际上是对state进行了一个线性变换 $x_k=V\bar{x}_k$

如果我们能找到性质良好的矩阵 $V$ 对HiPPO矩阵 $A$ 进行规范化,比如将 $A$ 对角化,就可以很容易计算 $V^{-1}AV$ 的高次幂

事实上确实存在这样的对角化方法。S4的作者证明,对于所有的HiPPO Matrix(HiPPO-LegS、HiPPO-LegT、HiPPO-LagT),都可以用矩阵 $V_{ij} = \begin{pmatrix} i+j \\ i-j \end{pmatrix}$ 进行对角化(这里指的是组合数)。然而直接对 $A$ 进行这样的对角化会有数值稳定性上的问题,因为很容易得到:

即矩阵 $V$ 中最大元素的值的大小几乎是指数增加的,反映在实际运行时,序列长度 $L$ 只要稍微大一点,卷积核中就会出现 NaN

Normal Plus Low-Rank

直接使用 $V_{ij}=\begin{pmatrix} i+j \\ i-j \end{pmatrix}$ 虽然可以进行对角化,但会有稳定性问题。如果矩阵 $A$ 为正规矩阵,则可以用酉矩阵对 $A$ 进行酉对角化。不幸的是所有HiPPO matrix都不是正规矩阵。

虽然无法直接对 $A$ 进行稳定的对角化,但我们观察到,所有的HiPPO matrix都能被分解为一个正规矩阵和一个低秩矩阵的和。

对矩阵相加一个每个元素均为 $\frac{1}{2}(2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}}$ 的矩阵(低秩,秩为1),可以得到:

这个矩阵等于一个单位阵 $-\frac{1}{2}I$ 加一个斜对称矩阵 $S$ ,斜对称矩阵是正规矩阵的一个特例,可以用酉矩阵进行对角化。虽然 $-\frac{1}{2}I+S$ 不再是斜对称矩阵,但它仍可以使用某些使 $S$ 对角化的酉矩阵进行对角化,即仍为正规矩阵。

正式地,我们可以将所有HiPPO Matrix表示为NPLR(Normal Plus Low-Rank):

其中 $PQ^T$ 为低秩分解,像上文HiPPO-LegS的低秩矩阵直接就是秩为1,可以很方便的分解为两个向量的乘积。

为了后续的频域计算,我们从这里开始将矩阵的定义扩展到复数空间。进行一些变换我们得到:

因此,计算 $A$ 的高次幂转化为对 $\Lambda - ({V^{\ast}} P)({V^{\ast}} Q)^{\ast} $ 的计算。这个矩阵具有更良好的形式,称为DPLR (Diagonal Plus Low-Rank)。然而即使是DPLR,到这一步仍然不好计算。接下来针对两种模式的SSM,即RNN view(推理)和CNN view(训练),分别用矩阵理论进行推导,找到高效的计算方法。

RNN view

根据前面的推导,我们知道所有HiPPO Matrices都可以表示为DPLR,不失一般性,我们记为 $A=\Lambda - PQ^{\ast}$。

在S4中,我们使用双线性离散化方法,假设步长为 $\Delta$,我们有:

分别计算两个相乘项,我们有:

使用Woodbury matrix identity: $(A+UCV)^{-1} = A^{-1} - A^{-1}U(C^{-1}+VA^{-1}U)^{-1}VA^{-1}$,我们可以得到:

其中 $A_0 = \frac{2}{\Delta}I + (\Lambda - PQ^{\ast}), D = (\frac{2}{\Delta}-\Lambda)^{-1}, A_1 = D-DP(I + Q^{\ast}DP) ^{-1}Q^{\ast}D$

代入可得:

再将上式代入离散SSM,可以得到:

注意到 $A_0, A_1$ 的组成均为单位阵、对角阵和低秩矩阵,因此 $A_1, A_0$的计算仅包含矩阵-向量乘法,RNN view计算一步的复杂度为 $O(N)$

CNN view

我们稍微对符号进行一点修改便于后面的推导(将 $C$ 记为列向量):

我们得到的长度为 $L$ 的卷积核为:

我们定义一个对应于卷积核的SSM生成函数(一个定义在频域的函数,可以通过DFT变换到时域的卷积核),在节点 $z$ 其表达式为:

特别地,截断SSM生成函数为:

对于 $M$ 个节点的向量 $\Omega \in \mathbb{C}^{M}$,我们有:

上面都是针对生成卷积核的函数的表达式,我们记卷积核及其频域形式如下:

我们取一组单位根 $\Omega = e^{-2\pi i \frac{k}{L}}, k\in[L]$,可得:

上式实际上是对卷积核进行离散傅里叶变换(DFT): $\hat{}K = \mathcal{F}_L{K}$,下面我们推导 $\hat{K}$ 的计算。

展开截断SSM生成函数:

其中 $\tilde{C}^{\ast} = C^{\ast}(I-\bar{A}^L)$,由于 $C$ 可学习,实践中可以直接将 $\tilde{C}^{\ast}$ 设为可学习向量。将上式展开到可以得到:

即:

使用Woodbury matrix identity: $(A+UCV)^{-1} = A^{-1} - A^{-1}U(C^{-1}+VA^{-1}U)^{-1}VA^{-1}$,我们可以得到:

其中 $R(z) = (\frac{2}{\Delta}\frac{1-z}{1+z} - \Lambda)^{-1}$。到这里,截断SSM生成函数的计算可以被最终归结为对 $R(z)$ 及其与其他矩阵的乘法计算,具体来说,我们希望高效地计算 $Q^{\ast}R(\Omega; \Lambda)P$,将 $\frac{2}{\Delta}\frac{1-z}{1+z}$ 视为 $\omega \in \Omega$,我们实际上是要计算 $\sum_j\frac{q_i^{\ast}p_j}{\omega-\lambda_j}$,而这是一个 Cauchy matrix-vector multiplication,已有高效算法。

具体来说,对于 Cauchy matrix $M$:

Cauchy matrix-vector product的计算复杂度为:

因此CNN view 计算复杂度为 $O((L+N)\log(L+N)\log \frac{1}{\epsilon})$,可以记为次线性形式:$\tilde{O}(L+N)$

Reference

组会内部分享报告 《A Gentle Introduction to HiPPO🦛 and its Friends》

Efficiently Modeling Long Sequences with Structured State Spaces