# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os

import pandas
import tilus
import torch
from tilus import float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv

if not tilus.target.get_current_target().supports(tilus.target.nvgpu_sm100a):
    # skip this example if the current target does not support nvgpu_sm100a
    exit(0)

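# cache compiled kernels next to this script and dump the generated IR for inspection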
tilus.option.cache_dir(os.path.join(os.path.dirname(__file__), "cache"))
tilus.option.debug.dump_ir()


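# each @tilus.autotune decorator below registers candidate values for the named
# constructor parameters; tilus picks the best-performing combination of
# (block_m, block_n), block_k, and stages when the kernel is tuned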
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmul(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int, stages: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

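        # one thread block computes one [block_m, block_n] tile of C; the 4 warps
        # (128 threads) are split below: warp 0 issues TMA copies, warp 1 issues
        # tcgen05 MMAs, and all warps cooperate in the epilogue after the k-loop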
        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
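        # b is viewed as [n_size, k_size], i.e. both operands are k-major and the
        # kernel computes C = A @ B^T; s_a/s_b hold `stages` in-flight k-slices so
        # that TMA copies overlap with the MMAs consuming earlier stages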
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )
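        # t_acc accumulates the full [block_m, block_n] tile in fp32 across all
        # k iterations and is only read back to registers after the k-loop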

        # allocate one producer/consumer mbarrier pair per pipeline stage
        consumer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready for consumption
        producer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready to be filled
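        # pipeline handshake: consumer_barriers[i] tells the mma warp that the
        # tiles for stage i have been copied in; producer_barriers[i] tells the
        # tma warp that stage i may be overwritten. Each warp keeps a per-stage
        # phase bit and flips it after every completed wait.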

        with self.thread_group(group_index=0, group_size=32):
            # tma warp: copies one [block_m, block_k] tile of A and one
            # [block_n, block_k] tile of B into shared memory per k iteration
            stage: int32 = 0
            producer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=1
            )  # all stages are ready to be filled at the beginning
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                with self.single_thread():
                    self.tma.global_to_shared(
                        src=g_a,
                        dst=s_a[stage],
                        offsets=[offset_m, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.tma.global_to_shared(
                        src=g_b,
                        dst=s_b[stage],
                        offsets=[offset_n, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.mbarrier.arrive(consumer_barriers[stage])
                stage = (stage + 1) % self.stages

            # drain the pipeline: wait until the last in-flight mma stages have
            # released their buffers, so all mma work has finished before the epilogue
            for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                stage = (stage + 1) % self.stages

        with self.thread_group(group_index=1, group_size=32):
            # mma warp: consumes the staged tiles and accumulates into t_acc
            consumer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=0
            )  # no stage is ready for consumption at the beginning
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    consumer_barriers[stage], phase=consumer_phases[stage]
                )  # wait until the stage is ready for consumption
                consumer_phases[stage] ^= 1
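                # a single thread issues the mma; commit makes its completion
                # arrive on producer_barriers[stage], releasing the stage back
                # to the tma warp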
                with self.single_thread():
                    self.tcgen05.mma(s_a[stage], s_b[stage].transpose(), t_acc)
                    self.tcgen05.commit(mbarrier=producer_barriers[stage])
                stage = (stage + 1) % self.stages

        self.sync()

        # load the result from tensor memory to registers
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

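        # cast the fp32 accumulator to fp16 and write this block's tile of C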
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)


def main(bench=True):
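    # constructing the script gives a callable kernel; the autotuned
    # (block_m, block_n, block_k, stages) configuration is selected by tilus
    # from the candidates declared above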
    matmul = BlackwellMatmul()

    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    rows = []

    for m_size, n_size, k_size in [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
        [8192, 8192, 8192],
        [10240, 10240, 10240],
    ]:
        print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
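        # b is allocated as [n_size, k_size] to match the kernel's k-major view
        # of B, so the reference result is a @ b.T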
        a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda")
        b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda")
        c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

        matmul(m_size, n_size, k_size, a, b, c)
        torch.cuda.synchronize()

        c_ref = a @ b.T

        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2)

        # benchmark
        if bench:
            for name, func in [
                ("torch", lambda: a @ b.T),
                ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c)),
            ]:
                latency = benchmark_func(func, warmup=5, repeat=20)
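                # latency is in milliseconds: 2*m*n*k FLOPs / (latency * 1e-3)
                # / 1e12 TFLOP/s simplifies to the 1e-9 factor below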
                tflops = 2 * m_size * n_size * k_size / latency * 1e-9
                rows.append([m_size, n_size, k_size, name, latency, tflops])

    if bench:
        df = pandas.DataFrame(rows, columns=headers)
        print(df)


if __name__ == "__main__":
    main(bench=True)
    # ncu_run(main, bench=False, kernel_regex="hidet|nvjet")