Part IV of Mathematical Structure of Mamba - Mamba&Mamba2
本篇是mamba系列blog的第四篇文章,系列文章见:
Part IV of Mathematical Structure of Mamba - Mamba&Mamba2
剩余预计还有一篇文章正在生产中~
计算公式
我们都知道SSM的公式:
这里的 $t$ 其实就代表了 seq_len
这一维。我们将公式展开,就可以得到一个卷积的形式
我们记 $A_{t:s} := A_tA_{t-1}…A_{s+1}$,则最终的输出 $y_t = C_t^Th_t = C_t^T\sum_{s=0}^t A_{t:s}B_sx_s$。
如果写成向量形式,则有 $y=Mx$,$M$ 有如下格式:
其中,分为对角块与方阵。其中的方阵部分还可以进一步简化:
更一般的形式为,对于矩阵 $M_{j:j’,i’:i}$,其中 $j’ > j \geq i > i’$:
后续的工作,就是写算子把这些东西全部高效地算出来。整体思路就是,按上述分块,先块内算,再块间算。
Triton 算子
有关 Triton 的基本知识欢迎参考这篇Triton Tutorial
_chunk_cumsum_fwd
需要先介绍一下 $A$ 的离散化(记 $\tilde{A} $ 为离散化之前的矩阵),见 Part III of Mathematical Structure of Mamba - S4D
我们知道了,为了保证 $A$ 的实部总为负,我们需要用指数形式来处理离散化:$e^{\Delta A}$。而指数形式的乘积实际上就对应了指数的累计和,这也是这个算子在做的事情。
经过 DSS 与 S4D 这两篇文章的沉淀,作者发现A使用对角线比起NPLR也没差很多。特别是从mamba开始抛弃HiPPO后,就更没有使用NPLR的必要了,直接快进到对角线。
$dt: (B,S,H)$, $A: (H,)$,可以视作每个头都有一个控制状态衰减的变量。
这个算子做的事情其实比较简单,就是计算 $e^{\sum \Delta_i A_i}$,不过是分块做的。
grid
我们把 $dt$ 按照seqlen这一维切开,切成 chunk_size 大小的块。同时,head这一维也会切分,每份大小是 BLOCK_SIZE_H
,这个参数后续会用autotune去搜。最后得到一个三维的grid,代表了 [batch, seqlen, heads]
1 | def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): |
之后简单看一下整个triton算子的逻辑,我会以 tl.load()
为基础,看每份数据是怎么读进来的;然后再介绍它们之间如何计算。
dt
因为每块数据处理第 pid_b
条的 chunk_size
条数据,所以首先定位到 dt_ptr
,再取下大小为 [BLOCK_SIZE_H, BLOCK_SIZE_CHUNK]
的一块。
这里比较需要注意的就是,在读入数据的时候交换了 $S$ 和 $H$ 这两维,是为了后续计算方便。只要stride是正确的话,读进来的数据是一定正确的。
1 | dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen |
A
A
就没啥好说的,就一个维度,分成 BLOCK_SIZE_H
大小即可
1 | A_ptrs = A_ptr + offs_h * stride_A_head |
计算
主要就是计算 $dt * A$,然后求累计和
1 | dA = dt * A[:, None] |
值得一提的是,最后存下来的结果的维度是 [batch, nheads, nchunks, chunk_size]
,等于是把seqlen这一维给拆开了。并且,块间的交互也还没做,比如说本来累计和应该是 [1,2,3,4],现在的分块结果可能是 [1,2,1,2],后续还需要一步块间RNN传递。
_chunk_state_fwd
这段代码主要是算states的,对应的是原计算公式中的
先明确一下各输入的维度。
x: [batch, seq_len, nheads, headdim]
,其中 nheads * headdim = d_model
dt: [batch, nheads, nchunks, chunk_size]
,其中 nchunks * chunk_size = seq_len
B: [batch, seq_len, ngroups, dstate]
,其中 ngroups
是为了tp的参数,一般为1;dstate
就是状态向量的维度
dA_cumsum
和 dt
一样
grid
grid的分法看起来有点奇怪,第0维把 headdim * dstate
分成了一组,这是因为 $ABx$ 的结果就是 headdim * dstate
,可以这么理解,表示输入 x
的某个维度,对内部状态 h
的某个维度的影响。
1 | # mamba_ssm/ops/triton/ssd_chunk_state.py |
第1维是把 batch
和 nchunks
组合到一起,然后到kernel里之后又光速分开了。看起来排序方式是 [nchunks, batch]
,不懂为啥要拼起来传进去。
1 | pid_bc = tl.program_id(axis=1) |
第2维就简明一点,直接对 nheads
分块
1 | pid_h = tl.program_id(axis=2) |
dA_cs_last
之后依然是读取数据,我们还是以 tl.load()
为支点读代码。首先是 dA_cumsum
,也就是一个 chunk_size
大小中的前缀和 $\sum_{i=0}^t \Delta_iA_i$。
1 | dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head |
dA_cumsum
的维度是 [batch, nheads, nchunks, chunk_size]
。我们的读取方式很简单,分别朝第0、1、2维移动对应长度。注意,dA_cs_last
读的是这个chunk中的最后一个元素,因此叫last。
seq_idx
seq_idx
暂时跳过,我不知道这个是干什么的。
x
这里又很神秘的,对于 chunk_size
这一维还要切一下。读取方式是之前说明过的交换维度,只要stride是对的,数据就是对的,只不过是写入顺序的问题罢了。所以读取进来的 x
是 [pid_b, (k:k+1)*BLOCK_SIZE_K, pid_h, (pid_m:pid_m+1) * BLOCK_SIZE_M]
1 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
b
b
的取值是 [pid_b, (k: k + 1) * BLOCK_SIZE_K, pid_group, (pid_n: pid_n+1) * BLOCK_SIZE_N]
注意不要搞错了 b
的意义,这里的 b
可以看作是用来处理 x
的投影矩阵,还是要投到 dstate 上去的。
1 | offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
dA_cs_k 与 dt_k
就是读进来一个chunk_size的元素,然后进一步把 chunk_size(seq_len) 这一维切了一下,不太重要
1 | dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize |
每一小块的具体计算
先把前缀和计算这一段的差,比如说 dA_cs_last = A_3 * A_2 * A_1 * A_0
, dA_cs_k = [A_0, A_1 * A_0]
这里的小块 b
是在 [seq_len, dstate]
维度上取出来的大小为 [BLOCK_SIZE_K, BLOCK_SIZE_N]
的块。这里的scale乘到了 seq_len
维度上,因为不同的时间步需要乘的A不同嘛。
后面和大小为 [BLOCK_SIZE_M, BLOCK_SIZE_K]
的 x
做点积,得到是一个 [headdim, dstate]
的矩阵。然后按照开头公式所说的加起来。最后存下来的states的维度是 [batch, nchunks, nheads, headdim, dstate]
1 | scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k |
_state_passing_fwd
上面一步的计算其实是算少了的,比如说你第一步有的块是只有 $A_7A_6A_5A_4$的,那你本来想算 $A_{7:0}B_0x_0$ 的就少算了。方法就是从算过的 $A_{3:0}B_0x_0$ 往下传。函数名字也很形象。
代码上有一点点改变,首先就是把states的后两维聚到一起了,方便计算,然后 dA_cumsum
只需要取最后一个就可以。
1 | states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], |
grid
上来grid又是乱序,我真的会谢。dim
这一维指的是states的后两维,可能是这个算子里不涉及这个的操作,所以抛到最外侧了吧(但是内存空间又没变,早知如此为什么不之前存的时候就换顺序呢?)。
1 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) |
states & dA
函数的具体内容倒是简单,读入 states
和 dA
。
dA
的维度是 [batch, nheads, nchunks]
;state
的维度是 [batch, nchunks, nheads, headdim*dstate]
这里 dA
好像就是把一整个 nchunks
取出来,这样就是 nchunks
段乘积。
states
则直接在除了nchunks的方向上移动到对应位置,最后一维取了一个 BLOCK_SIZE
做并行
1 | dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head |
一个比较朴素的问题是,怎么保证读的正确对应呢?这里用了一个循环来更新:
1 | for c in range(nchunks): |
但是说实话,具体怎么并行的我也有点没想清楚,为啥这么取就能把所有块的算好呢。后续又该怎么取呢?有点神秘。
_bmm_chunk_fwd
并非善类
grid
分的维度是 [chunk_size, chunk_size, batch, nchunks]
Tridao瞎起名字,给我整笑了。外面传进去 C
和 dstate
,在里面形参叫 A
和 k
,
1 | grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), batch, nchunks if not has_groups else nchunks * ngroups) |