绪论

Triton是一门适配python的高性能GPU编程语言(暂时只认为是语言),学习路线可以从完成官方的tutorial开始。我的博客里主要想讲一些不一样的。

CUDA Version: 12.2

Triton Version: 3.1.0

GPU相关知识

想必大家上来跑tutorial遇到的第一个问题是,获取DEVICE的接口报错了!

1
2
3
4
import triton.runtime import driver
DEVICE = driver.active.get_active_torch_device()

# >>> AttributeError: 'CudaDriver' object has no attribute 'get_active_torch_device'

查阅源码后发现,应该是nvidia那边的接口变掉了,导致triton中无法重载:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# triton/python/triton/backends/driver.py
class DriverBase(metaclass=ABCMeta):

@classmethod
@abstractmethod
def is_active(self):
pass

@abstractmethod
def get_current_target(self):
pass

@abstractmethod
def get_active_torch_device(self):
pass

@abstractmethod
def get_benchmarker(self) -> Benchmarker:
"""
Return the benchmarking function that this backend should use by default.
"""
raise NotImplementedError

def __init__(self) -> None:
pass

driver.active
# >>> <nvi.CudaDriver object at 0x7ff294e43d30>

因此,我们可以用别的API来代替:

1
2
3
4
5
6
driver.active.get_current_target()
# >>> GPUTarget(backend='cuda', arch=90, warp_size=32)
DEVICE = driver.active.get_current_target().backend
# >>> 'cuda'
DEVICE_ID = driver.active.get_current_device()
# >>> 0

在有些时候,我们还需要更多GPU的信息来辅助并行编程:

1
2
3
4
5
6
properties = driver.active.utils.get_device_properties(DEVICE_ID)
# >>> {'max_shared_mem': 232448, 'max_num_regs': 65536, 'multiprocessor_count': 132, 'warpSize': 32, 'sm_clock_rate': 1980000, 'mem_clock_rate': 2619000, 'mem_bus_width': 5120}
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]

解释一下这四个主要的GPU参数:

  • NUM_SM 是指的GPU中Streaming Multiprocessor(SM)的数量。SM是GPU上的核心处理单元,包含完整的内存、寄存器等。整个CUDA编程的核心就是将任务分成多个BLOCKs,然后这些BLOCKs会被均匀地分给所有SM执行。我使用的GPU型号是H800,总共有132个SM。

  • NUM_REGS 是指每个SM中可用的寄存器(registers)的最大数量。寄存器每个线程不共享,如果单个线程所需要使用的寄存器很多,则同时在一个SM上运行的线程数量就会减少。

  • SIZE_SMEM 是指每个SM可用的共享内存(Shared Memory)的大小,通常以字节(Bytes)为单位。SMEM访问速度远快于显存,同一个SM中的所有线程均可共享,可用于数据交换等。linux系统中也有共享内存的概念,可以看作是“基于内存的文件系统”,通常指的是 /dev/shm 这块区域,用于进程间通信等操作,就不用把数据写到硬盘里去了。

1
2
3
df -h /dev/shm
>>> 文件系统 大小 已用 可用 已用% 挂载点
>>> tmpfs 64G 0 64G 0% /dev/shm
  • WARP_SIZE 比较复杂一点,我们首先需要理解Warp的概念。Warp是GPU上线程调度的基本单元,一个Warp中的所有线程会执行相同的命令。这并不代表一个Warp中所有线程是完全一样的,而是说,如果Warp中有一半的指令做的是A,而另一半的指令做的是A->B,则在第一阶段所有线程会同时处理,而在第二阶段有一半的线程会陪着另一半空转。因此,避免Warp分歧也是一个很重要的优化点。WARP_SIZE 则表示一个Warp中包含的线程数量,基本上是32。

Vector Addition

源代码很简单就不解释了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch

import triton
import triton.language as tl
from triton.runtime import autotune

DEVICE = triton.runtime.driver.active.get_current_target().backend


@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)


@autotune(
configs=[
triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
triton.Config({'BLOCK_SIZE': 512}, num_warps=4),
triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
],
key=['n_elements'],
)
@triton.jit
def add_kernel_autotune(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)


def add(
x: torch.Tensor,
y: torch.Tensor,
block_size: int = None,
num_warps: int = 4,
autotune: bool = False
) -> torch.Tensor:
output = torch.empty_like(x)
assert x.device.type == DEVICE and y.device.type == DEVICE and output.device.type == DEVICE
n_elements = output.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) # noqa: E731
if autotune:
add_kernel_autotune[grid](x, y, output, n_elements)
else:
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size, num_warps=num_warps)

return output


def main():
torch.manual_seed(0)
size = 98432
x = torch.randn(size, device=DEVICE)
y = torch.randn(size, device=DEVICE)

output_triton = add(x, y, block_size=256)
output_pytorch = x + y

print(output_triton)
print(output_pytorch)
print(f'{torch.max(torch.abs(output_triton - output_pytorch))}')


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[2**i for i in range(12, 28, 1)],
x_log=True,
line_arg='provider',
line_vals=['torch', 'triton_bs128', 'triton_bs256', 'triton_bs512', 'triton_nw4', 'triton_nw8', 'triton_nw16', 'triton_autotune'],
line_names=['Torch', 'Triton BS=128', 'Triton BS=256', 'Triton BS=512', 'Triton NW=4', 'Triton NW=8', 'Triton NW=16', 'Triton Autotune'],
styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-'), ('orange', '-'), ('cyan', '-'), ('magenta', '-'), ('yellow', '-')],
ylabel='GB/s',
plot_name='vector-add-performance',
args={},
))
def benchmark(size, provider):
x = torch.rand(size, device=DEVICE, dtype=torch.float32)
y = torch.rand(size, device=DEVICE, dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
elif provider == 'triton_bs128':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=128), quantiles=quantiles)
elif provider == 'triton_bs256':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=256), quantiles=quantiles)
elif provider == 'triton_bs512':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=512), quantiles=quantiles)
elif provider == 'triton_nw4':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=512, num_warps=4), quantiles=quantiles)
elif provider == 'triton_nw8':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=512, num_warps=8), quantiles=quantiles)
elif provider == 'triton_nw16':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, block_size=512, num_warps=16), quantiles=quantiles)
elif provider == 'triton_autotune':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, autotune=True), quantiles=quantiles)
gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: E731
return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
# main()
benchmark.run(print_data=True, show_plots=True, save_path="./results/01_vector_addition")

性能调优

比起源码,我额外增加了 from triton.runtime import autotune ,它的作用就是对于不同size的输入,会在首次执行时搜一遍所有可能的配置,找到其中效率最高的,后续对同样的size就会用固定的配置。

简单尝试一下的话,就会发现影响程序性能的因素有两个 BLOCK_SIZEnum_warps。其实理论上应该是 BLOCK_SIZEnum_warpsinput_size 三者的关系决定了性能。我们用控制变量的方式来测一下它们的影响。

Vector Addition Triton Kernel Performance。测BS时默认NW为4;测NW时默认BS为512


从图中我们能看出两个明显掉点的曲线,一个是 BLOCK_SIZE 为128时,此时的 BLOCK_SIZE 太小,分成的warps太多,导致管理调度 $2^{27} / 128 / 32$ 个块成了开销瓶颈;另一个是 num_warps 为16时,过大的线程块导致了对寄存器等资源的竞争更加剧烈,甚至可能发生寄存器溢出,显著影响效率。其实我也还不太会分析具体的原因,下面是详细的测试表格:

size Torch Triton BS=128 Triton BS=256 Triton BS=512 Triton NW=4 Triton NW=8 Triton NW=16 Triton Autotune
4096.000000 9.035294 9.197604 9.142857 9.142857 9.142857 9.142857 9.088757 8.982456
8192.000000 17.964912 18.070588 17.860465 17.757226 18.070588 17.757226 18.070588 17.757226
16384.000000 35.514452 35.310345 35.310345 35.310345 35.310345 35.310345 35.720930 35.310345
32768.000000 68.266666 68.266666 69.423731 69.033707 69.033707 70.217145 69.818181 69.818181
65536.000000 132.843245 135.779009 135.779009 133.565214 133.565214 133.565214 135.032965 136.533331
131072.000000 253.360834 252.061538 255.999991 258.694729 260.063494 252.061538 253.360834 258.694729
262144.000000 457.227922 444.814490 455.111110 465.895721 463.698115 461.521112 453.013839 465.895721
524288.000000 750.412251 747.558951 774.047204 771.011790 771.011790 765.011652 741.916954 777.106702
1048576.000000 1228.800031 1159.929234 1187.963788 1221.167675 1228.800031 1184.385557 1156.517652 1217.387051
2097152.000000 1687.622326 1569.724635 1676.827323 1698.557221 1687.622326 1684.008546 1569.724635 1694.896509
4194304.000000 2154.608134 1927.529447 2148.721353 2160.527432 2163.499294 2145.789924 1852.607766 2163.499294
8388608.000000 2532.792214 2204.434417 2538.924930 2534.833158 2532.792214 2543.029933 2160.527432 2545.087416
16777216.000000 2755.784589 2349.311512 2763.045981 2756.388411 2755.784589 2763.045981 2394.920460 2758.200902
33554432.000000 2941.994716 2428.196121 2949.580950 2944.059933 2942.682906 2950.618366 2558.022378 2949.580950
67108864.000000 3021.832823 2487.232999 3032.757656 3021.832823 3022.740106 3032.392472 2644.858132 3032.027035
134217728.000000 3067.880595 2514.570784 3071.437476 3066.945668 3067.506556 3072.937670 2676.788108 3072.562746

如果采用 BS=4096,NW=16 的配置的话,性能并不会显著下降,看起来似乎不大符合原来“寄存器溢出导致性能下降”的假设

gemini的回答是:

  • 当 Kernel 需要处理一个包含4096个元素的向量 (如 offsets = tl.arange(0, BLOCK_SIZE)) 时,编译器清楚地知道,不可能一次性将所有4096个元素(对于每个线程来说,是它负责的那部分)都完整地、同时地放在寄存器中进行操作。这远超出了单个线程或 Warp 的实际寄存器容量。因此,编译器很可能会生成一种高度流水线化 (pipelined) 或分块 (tiled/chunked) 的执行代码。它会将这4096个元素的操作分解成更小的、可管理的批次。例如,加载一小批数据到寄存器,计算,存储结果,然后再处理下一小批。
    这种流水线化的处理方式,其瞬时 (instantaneous) 寄存器需求可能相对较低。也就是说,在任何特定时刻,每个线程为正在处理的小数据块所活跃使用的寄存器数量可能不多。

  • 编译器处理512个元素时,它可能尝试一种不同的优化策略。也许它认为可以将更大部分的“向量”同时保持活跃在寄存器中,或者采用一种不那么积极分解成小块的策略,因为它认为这在某些情况下(如有足够寄存器时)可能更有效。
    这种策略如果本身对寄存器的需求就比较高,那么当 num_warps 强制一个较紧的寄存器预算时,就更容易导致寄存器溢出。

暂时不知道对错,感觉我还需要一些性能调优工具。

我们可以发现一个有趣的现象,即固定 num_warps 时,BLOCK_SIZE 设为256是最佳的;固定 BLOCK_SIZE 时,num_warps 设为8是最佳的。

由于性能受到多种因素的影响,因此影响程序效率的参数往往存在着被称为 sweet spot(甜点区) 的最佳范围,高了也不行,低了也不行。

一个比较好的方法是采用autotune,试几种可选参数后丢进去。这样程序会在初次运行时自动测试几种配置,后续对于同样的size均采用最佳配置即可。

值得一提的是,在计算非常简单的 Vector Addition 的场景下,性能瓶颈不在计算而在于带宽,因此最佳性能差距不会太大。

Fused Softmax

遇到的一个神人问题是,源代码中的 kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages) 报错

报错原因是,代码中已经预先编译过内核了

1
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(1, ))

而在 softmax_kernel 的原始定义中,BLOCK_SIZEnum_stages 被声明为了 tl.constexpr 类型,表示编译时常量,即在内核中编译的过程中已经按照这个常量编译了,不应该在运行时再次传入。

1
2
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):

因此,源代码应该修改为

1
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols)

最后,完整的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch

import triton
import triton.language as tl
from triton.runtime import driver

target = driver.active.get_current_target()
DEVICE = target.backend
DEVICE_ID = driver.active.get_current_device()

properties = driver.active.utils.get_device_properties(DEVICE_ID)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]


def naive_softmax(x):
# read MN elements, write M elements
x_max = x.max(dim=1)[0]
# read MN+M elements, write MN elements
z = x - x_max[:, None]
# read MN elements, write MN elements
numerator = torch.exp(z)
# read MN elements, write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements, write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)

input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

row_minus_max = row - tl.max(row, axis=0)

numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator

output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)


def softmax(x, num_stages_to_use):
n_rows, n_cols = x.shape
BLOCK_SIZE = triton.next_power_of_2(n_cols)

num_warps = 8
# num_stages = 4 if SIZE_SMEM > 200000 else 2
num_stages = num_stages_to_use

y = torch.empty_like(x)
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()

n_regs = kernel.n_regs
size_smem = kernel.metadata.shared

occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy

num_programs = min(num_programs, n_rows)

kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols)
return y


def main():
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x, 4)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 129)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['torch', 'triton_ns2', 'triton_ns3', 'triton_ns4'],
line_names=[
"Torch",
"Triton (NS=2)",
"Triton (NS=3)",
"Triton (NS=4)",
],
styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-')],
ylabel="GB/s",
plot_name="softmax-performance-vs-num_stages",
args={'M': 4096},
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE).Stream()
getattr(torch, DEVICE).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
elif provider == 'triton_ns2':
ms = triton.testing.do_bench(lambda: softmax(x, num_stages_to_use=2))
elif provider == 'triton_ns3':
ms = triton.testing.do_bench(lambda: softmax(x, num_stages_to_use=3))
elif provider == 'triton_ns4':
ms = triton.testing.do_bench(lambda: softmax(x, num_stages_to_use=4))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: E731
return gbps(ms)


if __name__ == "__main__":
# main()
benchmark.run(show_plots=True, print_data=True, save_path="./results/02_fused_softmax")

代码说明

这个kernel比 Vector Addition 要稍微复杂一点,因为它涉及一个伪2D的并行。

说是伪2D是因为,它的 BLOCK_SIZE 取值是 triton.next_power_of_2(n_cols),因此,每行不需要再单独划分。

主要的循环 for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): 的功能是:

  • 如果总共有 P 个程序并行,当前程序为 i,则取出第 $[i, P+i, 2P+i, \cdots]$ 行处理

  • 后续每行取 BLOCK_SIZE 其实就已经取完了,并没有在行上并行

对于 @triton.jit 装饰的函数,其中的中间变量,会放在寄存器或者共享内存中,直到 tl.store() 才会写回(如果占太多了也有可能自动offload到显存)。共享内存一般是SRAM,硬件特性决定了比一般的显存(HBM)要快很多。

num_stages 是一个流水线处理的参数,表示并行程度。如果 num_stages>1,则在程序处理第一条数据时,会同时开始加载第二条数据,这样可以加速,但是对SRAM的大小又有了要求。

因此有 num_stages = 4 if SIZE_SMEM > 200000 else 2

benchmark

从图中可以看出,还是要快不少的,不同 num_stages 之间差距不大。

Fused Softmax Triton Kernel Performance