Triton Tutorial
绪论
Triton是一门适配python的高性能GPU编程语言(暂时只认为是语言),学习路线可以从完成官方的tutorial开始。我的博客里主要想讲一些不一样的。
CUDA Version: 12.2
Triton Version: 3.1.0
GPU相关知识
想必大家上来跑tutorial遇到的第一个问题是,获取DEVICE的接口报错了!
1 | import triton.runtime import driver |
查阅源码后发现,应该是nvidia那边的接口变掉了,导致triton中无法重载:
1 | # triton/python/triton/backends/driver.py |
因此,我们可以用别的API来代替:
1 | driver.active.get_current_target() |
在有些时候,我们还需要更多GPU的信息来辅助并行编程:
1 | properties = driver.active.utils.get_device_properties(DEVICE_ID) |
解释一下这四个主要的GPU参数:
NUM_SM
是指的GPU中Streaming Multiprocessor(SM)的数量。SM是GPU上的核心处理单元,包含完整的内存、寄存器等。整个CUDA编程的核心就是将任务分成多个BLOCKs,然后这些BLOCKs会被均匀地分给所有SM执行。我使用的GPU型号是H800,总共有132个SM。NUM_REGS
是指每个SM中可用的寄存器(registers)的最大数量。寄存器每个线程不共享,如果单个线程所需要使用的寄存器很多,则同时在一个SM上运行的线程数量就会减少。SIZE_SMEM
是指每个SM可用的共享内存(Shared Memory)的大小,通常以字节(Bytes)为单位。SMEM访问速度远快于显存,同一个SM中的所有线程均可共享,可用于数据交换等。linux系统中也有共享内存的概念,可以看作是“基于内存的文件系统”,通常指的是/dev/shm
这块区域,用于进程间通信等操作,就不用把数据写到硬盘里去了。
1 | df -h /dev/shm |
WARP_SIZE
比较复杂一点,我们首先需要理解Warp的概念。Warp是GPU上线程调度的基本单元,一个Warp中的所有线程会执行相同的命令。这并不代表一个Warp中所有线程是完全一样的,而是说,如果Warp中有一半的指令做的是A,而另一半的指令做的是A->B,则在第一阶段所有线程会同时处理,而在第二阶段有一半的线程会陪着另一半空转。因此,避免Warp分歧也是一个很重要的优化点。WARP_SIZE
则表示一个Warp中包含的线程数量,基本上是32。
Vector Addition
源代码很简单就不解释了
1 | import torch |
性能调优
比起源码,我额外增加了 from triton.runtime import autotune
,它的作用就是对于不同size的输入,会在首次执行时搜一遍所有可能的配置,找到其中效率最高的,后续对同样的size就会用固定的配置。
简单尝试一下的话,就会发现影响程序性能的因素有两个 BLOCK_SIZE
与 num_warps
。其实理论上应该是 BLOCK_SIZE
、 num_warps
和 input_size
三者的关系决定了性能。我们用控制变量的方式来测一下它们的影响。
Vector Addition Triton Kernel Performance。测BS时默认NW为4;测NW时默认BS为512
从图中我们能看出两个明显掉点的曲线,一个是 BLOCK_SIZE
为128时,此时的 BLOCK_SIZE
太小,分成的warps太多,导致管理调度 $2^{27} / 128 / 32$ 个块成了开销瓶颈;另一个是 num_warps
为16时,过大的线程块导致了对寄存器等资源的竞争更加剧烈,甚至可能发生寄存器溢出,显著影响效率。其实我也还不太会分析具体的原因,下面是详细的测试表格:
size | Torch | Triton BS=128 | Triton BS=256 | Triton BS=512 | Triton NW=4 | Triton NW=8 | Triton NW=16 | Triton Autotune |
---|---|---|---|---|---|---|---|---|
4096.000000 | 9.035294 | 9.197604 | 9.142857 | 9.142857 | 9.142857 | 9.142857 | 9.088757 | 8.982456 |
8192.000000 | 17.964912 | 18.070588 | 17.860465 | 17.757226 | 18.070588 | 17.757226 | 18.070588 | 17.757226 |
16384.000000 | 35.514452 | 35.310345 | 35.310345 | 35.310345 | 35.310345 | 35.310345 | 35.720930 | 35.310345 |
32768.000000 | 68.266666 | 68.266666 | 69.423731 | 69.033707 | 69.033707 | 70.217145 | 69.818181 | 69.818181 |
65536.000000 | 132.843245 | 135.779009 | 135.779009 | 133.565214 | 133.565214 | 133.565214 | 135.032965 | 136.533331 |
131072.000000 | 253.360834 | 252.061538 | 255.999991 | 258.694729 | 260.063494 | 252.061538 | 253.360834 | 258.694729 |
262144.000000 | 457.227922 | 444.814490 | 455.111110 | 465.895721 | 463.698115 | 461.521112 | 453.013839 | 465.895721 |
524288.000000 | 750.412251 | 747.558951 | 774.047204 | 771.011790 | 771.011790 | 765.011652 | 741.916954 | 777.106702 |
1048576.000000 | 1228.800031 | 1159.929234 | 1187.963788 | 1221.167675 | 1228.800031 | 1184.385557 | 1156.517652 | 1217.387051 |
2097152.000000 | 1687.622326 | 1569.724635 | 1676.827323 | 1698.557221 | 1687.622326 | 1684.008546 | 1569.724635 | 1694.896509 |
4194304.000000 | 2154.608134 | 1927.529447 | 2148.721353 | 2160.527432 | 2163.499294 | 2145.789924 | 1852.607766 | 2163.499294 |
8388608.000000 | 2532.792214 | 2204.434417 | 2538.924930 | 2534.833158 | 2532.792214 | 2543.029933 | 2160.527432 | 2545.087416 |
16777216.000000 | 2755.784589 | 2349.311512 | 2763.045981 | 2756.388411 | 2755.784589 | 2763.045981 | 2394.920460 | 2758.200902 |
33554432.000000 | 2941.994716 | 2428.196121 | 2949.580950 | 2944.059933 | 2942.682906 | 2950.618366 | 2558.022378 | 2949.580950 |
67108864.000000 | 3021.832823 | 2487.232999 | 3032.757656 | 3021.832823 | 3022.740106 | 3032.392472 | 2644.858132 | 3032.027035 |
134217728.000000 | 3067.880595 | 2514.570784 | 3071.437476 | 3066.945668 | 3067.506556 | 3072.937670 | 2676.788108 | 3072.562746 |
如果采用 BS=4096,NW=16
的配置的话,性能并不会显著下降,看起来似乎不大符合原来“寄存器溢出导致性能下降”的假设
gemini的回答是:
当 Kernel 需要处理一个包含4096个元素的向量 (如 offsets = tl.arange(0, BLOCK_SIZE)) 时,编译器清楚地知道,不可能一次性将所有4096个元素(对于每个线程来说,是它负责的那部分)都完整地、同时地放在寄存器中进行操作。这远超出了单个线程或 Warp 的实际寄存器容量。因此,编译器很可能会生成一种高度流水线化 (pipelined) 或分块 (tiled/chunked) 的执行代码。它会将这4096个元素的操作分解成更小的、可管理的批次。例如,加载一小批数据到寄存器,计算,存储结果,然后再处理下一小批。
这种流水线化的处理方式,其瞬时 (instantaneous) 寄存器需求可能相对较低。也就是说,在任何特定时刻,每个线程为正在处理的小数据块所活跃使用的寄存器数量可能不多。编译器处理512个元素时,它可能尝试一种不同的优化策略。也许它认为可以将更大部分的“向量”同时保持活跃在寄存器中,或者采用一种不那么积极分解成小块的策略,因为它认为这在某些情况下(如有足够寄存器时)可能更有效。
这种策略如果本身对寄存器的需求就比较高,那么当 num_warps 强制一个较紧的寄存器预算时,就更容易导致寄存器溢出。
暂时不知道对错,感觉我还需要一些性能调优工具。
我们可以发现一个有趣的现象,即固定 num_warps
时,BLOCK_SIZE
设为256是最佳的;固定 BLOCK_SIZE
时,num_warps
设为8是最佳的。
由于性能受到多种因素的影响,因此影响程序效率的参数往往存在着被称为 sweet spot(甜点区) 的最佳范围,高了也不行,低了也不行。
一个比较好的方法是采用autotune,试几种可选参数后丢进去。这样程序会在初次运行时自动测试几种配置,后续对于同样的size均采用最佳配置即可。
值得一提的是,在计算非常简单的 Vector Addition 的场景下,性能瓶颈不在计算而在于带宽,因此最佳性能差距不会太大。
Fused Softmax
遇到的一个神人问题是,源代码中的 kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
报错
报错原因是,代码中已经预先编译过内核了
1 | kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(1, )) |
而在 softmax_kernel
的原始定义中,BLOCK_SIZE
和 num_stages
被声明为了 tl.constexpr
类型,表示编译时常量,即在内核中编译的过程中已经按照这个常量编译了,不应该在运行时再次传入。
1 |
|
因此,源代码应该修改为
1 | kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols) |
最后,完整的代码如下:
1 | import torch |
代码说明
这个kernel比 Vector Addition
要稍微复杂一点,因为它涉及一个伪2D的并行。
说是伪2D是因为,它的 BLOCK_SIZE
取值是 triton.next_power_of_2(n_cols)
,因此,每行不需要再单独划分。
主要的循环 for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
的功能是:
如果总共有 P 个程序并行,当前程序为 i,则取出第 $[i, P+i, 2P+i, \cdots]$ 行处理
后续每行取
BLOCK_SIZE
其实就已经取完了,并没有在行上并行
对于 @triton.jit
装饰的函数,其中的中间变量,会放在寄存器或者共享内存中,直到 tl.store()
才会写回(如果占太多了也有可能自动offload到显存)。共享内存一般是SRAM,硬件特性决定了比一般的显存(HBM)要快很多。
num_stages
是一个流水线处理的参数,表示并行程度。如果 num_stages>1
,则在程序处理第一条数据时,会同时开始加载第二条数据,这样可以加速,但是对SRAM的大小又有了要求。
因此有 num_stages = 4 if SIZE_SMEM > 200000 else 2
benchmark
从图中可以看出,还是要快不少的,不同 num_stages
之间差距不大。
Fused Softmax Triton Kernel Performance