本篇是mamba系列blog的第四篇文章,系列文章见:

剩余预计还有一篇文章正在生产中~

计算公式

我们都知道SSM的公式:

这里的 $t$ 其实就代表了 seq_len 这一维。我们将公式展开,就可以得到一个卷积的形式

我们记 $A_{t:s} := A_tA_{t-1}…A_{s+1}$,则最终的输出 $y_t = C_t^Th_t = C_t^T\sum_{s=0}^t A_{t:s}B_sx_s$。

如果写成向量形式,则有 $y=Mx$,$M$ 有如下格式:

其中,分为对角块与方阵。其中的方阵部分还可以进一步简化:

更一般的形式为,对于矩阵 $M_{j:j’,i’:i}$,其中 $j’ > j \geq i > i’$:

后续的工作,就是写算子把这些东西全部高效地算出来。整体思路就是,按上述分块,先块内算,再块间算。

Triton 算子

有关 Triton 的基本知识欢迎参考这篇Triton Tutorial

_chunk_cumsum_fwd

需要先介绍一下 $A$ 的离散化(记 $\tilde{A} $ 为离散化之前的矩阵),见 Part III of Mathematical Structure of Mamba - S4D

我们知道了,为了保证 $A$ 的实部总为负,我们需要用指数形式来处理离散化:$e^{\Delta A}$。而指数形式的乘积实际上就对应了指数的累计和,这也是这个算子在做的事情。

经过 DSS 与 S4D 这两篇文章的沉淀,作者发现A使用对角线比起NPLR也没差很多。特别是从mamba开始抛弃HiPPO后,就更没有使用NPLR的必要了,直接快进到对角线。

$dt: (B,S,H)$, $A: (H,)$,可以视作每个头都有一个控制状态衰减的变量。

这个算子做的事情其实比较简单,就是计算 $e^{\sum \Delta_i A_i}$,不过是分块做的。

grid

我们把 $dt$ 按照seqlen这一维切开,切成 chunk_size 大小的块。同时,head这一维也会切分,每份大小是 BLOCK_SIZE_H,这个参数后续会用autotune去搜。最后得到一个三维的grid,代表了 [batch, seqlen, heads]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads = dt.shape
assert A.shape == (nheads,)
if dt_bias is not None:
assert dt_bias.shape == (nheads,)
nchunks = math.ceil(seqlen / chunk_size)
dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt, A, dt_bias, dt_out, dA_cumsum,
batch, seqlen, nheads, chunk_size,
dt_limit[0], dt_limit[1],
dt.stride(0), dt.stride(1), dt.stride(2),
A.stride(0),
dt_bias.stride(0) if dt_bias is not None else 0,
dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
dt_softplus,
HAS_DT_BIAS=dt_bias is not None,
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
)
return dA_cumsum, dt_out

之后简单看一下整个triton算子的逻辑,我会以 tl.load() 为基础,看每份数据是怎么读进来的;然后再介绍它们之间如何计算。

dt

因为每块数据处理第 pid_b 条的 chunk_size 条数据,所以首先定位到 dt_ptr,再取下大小为 [BLOCK_SIZE_H, BLOCK_SIZE_CHUNK] 的一块。

这里比较需要注意的就是,在读入数据的时候交换了 $S$ 和 $H$ 这两维,是为了后续计算方便。只要stride是正确的话,读进来的数据是一定正确的。

1
2
3
4
5
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)

A

A 就没啥好说的,就一个维度,分成 BLOCK_SIZE_H 大小即可

1
2
A_ptrs = A_ptr + offs_h * stride_A_head
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)

计算

主要就是计算 $dt * A$,然后求累计和

1
2
dA = dt * A[:, None]
dA_cs = tl.cumsum(dA, axis=1)

值得一提的是,最后存下来的结果的维度是 [batch, nheads, nchunks, chunk_size],等于是把seqlen这一维给拆开了。并且,块间的交互也还没做,比如说本来累计和应该是 [1,2,3,4],现在的分块结果可能是 [1,2,1,2],后续还需要一步块间RNN传递。

_chunk_state_fwd

这段代码主要是算states的,对应的是原计算公式中的

先明确一下各输入的维度。

x: [batch, seq_len, nheads, headdim],其中 nheads * headdim = d_model

dt: [batch, nheads, nchunks, chunk_size],其中 nchunks * chunk_size = seq_len

B: [batch, seq_len, ngroups, dstate],其中 ngroups 是为了tp的参数,一般为1;dstate 就是状态向量的维度

dA_cumsumdt 一样

grid

grid的分法看起来有点奇怪,第0维把 headdim * dstate 分成了一组,这是因为 $ABx$ 的结果就是 headdim * dstate,可以这么理解,表示输入 x 的某个维度,对内部状态 h 的某个维度的影响。

1
2
3
4
5
6
7
# mamba_ssm/ops/triton/ssd_chunk_state.py
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
batch * nchunks, nheads)

num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n

第1维是把 batchnchunks 组合到一起,然后到kernel里之后又光速分开了。看起来排序方式是 [nchunks, batch],不懂为啥要拼起来传进去。

1
2
3
pid_bc = tl.program_id(axis=1)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch

第2维就简明一点,直接对 nheads 分块

1
pid_h = tl.program_id(axis=2)

dA_cs_last

之后依然是读取数据,我们还是以 tl.load() 为支点读代码。首先是 dA_cumsum,也就是一个 chunk_size 大小中的前缀和 $\sum_{i=0}^t \Delta_iA_i$。

1
2
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)

dA_cumsum 的维度是 [batch, nheads, nchunks, chunk_size]。我们的读取方式很简单,分别朝第0、1、2维移动对应长度。注意,dA_cs_last 读的是这个chunk中的最后一个元素,因此叫last。

seq_idx

seq_idx 暂时跳过,我不知道这个是干什么的。

x

这里又很神秘的,对于 chunk_size 这一维还要切一下。读取方式是之前说明过的交换维度,只要stride是对的,数据就是对的,只不过是写入顺序的问题罢了。所以读取进来的 x[pid_b, (k:k+1)*BLOCK_SIZE_K, pid_h, (pid_m:pid_m+1) * BLOCK_SIZE_M]

1
2
3
4
5
6
7
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)

for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen

b

b 的取值是 [pid_b, (k: k + 1) * BLOCK_SIZE_K, pid_group, (pid_n: pid_n+1) * BLOCK_SIZE_N]

注意不要搞错了 b 的意义,这里的 b 可以看作是用来处理 x 的投影矩阵,还是要投到 dstate 上去的。

1
2
3
4
5
6
7
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)

for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen

dA_cs_k 与 dt_k

就是读进来一个chunk_size的元素,然后进一步把 chunk_size(seq_len) 这一维切了一下,不太重要

1
2
3
4
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

每一小块的具体计算

先把前缀和计算这一段的差,比如说 dA_cs_last = A_3 * A_2 * A_1 * A_0dA_cs_k = [A_0, A_1 * A_0]

这里的小块 b 是在 [seq_len, dstate] 维度上取出来的大小为 [BLOCK_SIZE_K, BLOCK_SIZE_N] 的块。这里的scale乘到了 seq_len 维度上,因为不同的时间步需要乘的A不同嘛。

后面和大小为 [BLOCK_SIZE_M, BLOCK_SIZE_K]x 做点积,得到是一个 [headdim, dstate] 的矩阵。然后按照开头公式所说的加起来。最后存下来的states的维度是 [batch, nchunks, nheads, headdim, dstate]

1
2
3
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
b *= scale[:, None]
acc += tl.dot(x, b)

_state_passing_fwd

上面一步的计算其实是算少了的,比如说你第一步有的块是只有 $A_7A_6A_5A_4$的,那你本来想算 $A_{7:0}B_0x_0$ 的就少算了。方法就是从算过的 $A_{3:0}B_0x_0$ 往下传。函数名字也很形象。

代码上有一点点改变,首先就是把states的后两维聚到一起了,方便计算,然后 dA_cumsum 只需要取最后一个就可以。

1
2
3
states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)

grid

上来grid又是乱序,我真的会谢。dim 这一维指的是states的后两维,可能是这个算子里不涉及这个的操作,所以抛到最外侧了吧(但是内存空间又没变,早知如此为什么不之前存的时候就换顺序呢?)。

1
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)

states & dA

函数的具体内容倒是简单,读入 statesdA

dA 的维度是 [batch, nheads, nchunks]state 的维度是 [batch, nchunks, nheads, headdim*dstate]

这里 dA 好像就是把一整个 nchunks 取出来,这样就是 nchunks 段乘积。

states 则直接在除了nchunks的方向上移动到对应位置,最后一维取了一个 BLOCK_SIZE 做并行

1
2
3
4
5
6
7
8
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim

for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)

一个比较朴素的问题是,怎么保证读的正确对应呢?这里用了一个循环来更新:

1
2
3
4
5
6
7
8
9
10
11
12
13
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale = tl.exp(dA_cs)

states = scale * states + new_states # 循环更新
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk

但是说实话,具体怎么并行的我也有点没想清楚,为啥这么取就能把所有块的算好呢。后续又该怎么取呢?有点神秘。

_bmm_chunk_fwd

并非善类

grid

分的维度是 [chunk_size, chunk_size, batch, nchunks]

Tridao瞎起名字,给我整笑了。外面传进去 Cdstate,在里面形参叫 Ak

1
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), batch, nchunks if not has_groups else nchunks * ngroups)