174 changes: 174 additions & 0 deletions examples/blackwell_matmul/matmul_v3.py
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os

import pandas
import tilus
import torch
from tilus import float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv

if not tilus.target.get_current_target().supports(tilus.target.nvgpu_sm100a):
# skip this example if the current target does not support nvgpu_sm100a
exit(0)

tilus.option.cache_dir(os.path.join(os.path.dirname(__file__), "cache"))
tilus.option.debug.dump_ir()


@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmul(tilus.Script):
def __init__(self, block_m: int, block_n: int, block_k: int, stages: int):
super().__init__()
self.block_m = block_m
self.block_n = block_n
self.block_k = block_k
self.stages = stages

def __call__(
self,
m_size: int32,
n_size: int,
k_size: int,
a_ptr: ~float16,
b_ptr: ~float16,
c_ptr: ~float16,
):
self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
self.attrs.warps = 4

offset_m: int32 = self.block_m * self.blockIdx.x
offset_n: int32 = self.block_n * self.blockIdx.y

g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
s_a = self.shared_tensor(
dtype=float16, shape=[self.stages, self.block_m, self.block_k]
)
s_b = self.shared_tensor(
dtype=float16, shape=[self.stages, self.block_n, self.block_k]
)

# allocate a tensor in tensor memory (tmem)
t_acc = self.tcgen05.alloc(
dtype=float32, shape=[self.block_m, self.block_n], init=0.0
)

# allocate the per-stage mbarriers (each warp tracks its own phase bits below)
consumer_barriers = self.mbarrier.alloc(
count=[1 for _ in range(self.stages)]
) # whether the data is ready for consumption
producer_barriers = self.mbarrier.alloc(
count=[1 for _ in range(self.stages)]
) # whether the data is ready to be filled

with self.thread_group(group_index=0, group_size=32):
# tma warp
stage: int32 = 0
producer_phases = self.register_tensor(
dtype=uint32, shape=[self.stages], init=1
) # all stages are ready to be filled at the beginning
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
self.mbarrier.wait(
producer_barriers[stage], phase=producer_phases[stage]
) # wait until the stage is ready to be filled
producer_phases[stage] ^= 1
with self.single_thread():
self.tma.global_to_shared(
src=g_a,
dst=s_a[stage],
offsets=[offset_m, offset_k],
mbarrier=consumer_barriers[stage],
)
self.tma.global_to_shared(
src=g_b,
dst=s_b[stage],
offsets=[offset_n, offset_k],
mbarrier=consumer_barriers[stage],
)
self.mbarrier.arrive(consumer_barriers[stage])
stage = (stage + 1) % self.stages

# wait for the mma warp to finish consuming the remaining in-flight stages
for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
self.mbarrier.wait(
producer_barriers[stage], phase=producer_phases[stage]
) # wait until the stage is ready to be filled
producer_phases[stage] ^= 1
stage = (stage + 1) % self.stages

with self.thread_group(group_index=1, group_size=32):
# mma warp
consumer_phases = self.register_tensor(
dtype=uint32, shape=[self.stages], init=0
) # all stages are not ready for consumption at the beginning
stage: int32 = 0
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
self.mbarrier.wait(
consumer_barriers[stage], phase=consumer_phases[stage]
) # wait until the stage is ready for consumption
consumer_phases[stage] ^= 1
with self.single_thread():
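# s_b[stage] holds a [block_n, block_k] tile; transposing it yields the
# [block_k, block_n] operand, so t_acc accumulates a [block_m, block_n] product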
self.tcgen05.mma(s_a[stage], s_b[stage].transpose(), t_acc)
self.tcgen05.commit(mbarrier=producer_barriers[stage])
stage = (stage + 1) % self.stages

self.sync()

# load the result from tensor memory into registers
r_acc = self.tcgen05.load(
t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
)

g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

# all allocated tensor memory must be deallocated
self.sync()
self.tcgen05.dealloc(t_acc)


def main(bench=True):
matmul = BlackwellMatmul()

headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
rows = []

for m_size, n_size, k_size in [
[4096, 4096, 4096],
[4096, 4096, 14336],
[8192, 8192, 8192],
[10240, 10240, 10240],
]:
print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda")
b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda")
c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

matmul(m_size, n_size, k_size, a, b, c)
torch.cuda.synchronize()

c_ref = a @ b.T

torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2)

# benchmark
if bench:
for name, func in [
("torch", lambda: a @ b.T),
("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c)),
]:
latency = benchmark_func(func, warmup=5, repeat=20)
tflops = 2 * m_size * n_size * k_size / latency * 1e-9
rows.append([m_size, n_size, k_size, name, latency, tflops])

if bench:
df = pandas.DataFrame(rows, columns=headers)
print(df)


if __name__ == "__main__":
main(bench=True)
# ncu_run(main, bench=False, kernel_regex="hidet|nvjet")
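
A note on the handshake above: the two warps coordinate through per-stage mbarriers and per-warp phase bits, with the producer starting at phase 1 (every slot free) and the consumer at phase 0 (no slot filled). Below is a minimal CPU-only sketch of the same protocol, assuming nothing beyond the Python standard library; the ToyMbarrier class, thread roles, and buffer contents are illustrative stand-ins, not part of the tilus API, and the wait-until-phase-differs predicate is a simplification of real mbarrier phase-completion semantics.

import threading

class ToyMbarrier:
    """A toy stand-in for a CUDA mbarrier: a phase bit plus a condition variable."""

    def __init__(self) -> None:
        self.phase = 0
        self.cv = threading.Condition()

    def arrive(self) -> None:
        with self.cv:
            self.phase ^= 1  # completing an arrive flips the phase bit
            self.cv.notify_all()

    def wait(self, phase: int) -> None:
        with self.cv:
            # block while the barrier is still in the phase we are waiting on
            self.cv.wait_for(lambda: self.phase != phase)

STAGES = 3
consumer_barriers = [ToyMbarrier() for _ in range(STAGES)]  # data ready to read
producer_barriers = [ToyMbarrier() for _ in range(STAGES)]  # slot free to fill
buffers = [None] * STAGES

def producer(num_tiles: int) -> None:
    phases = [1] * STAGES  # all slots start free, matching init=1 above
    for tile in range(num_tiles):
        stage = tile % STAGES
        producer_barriers[stage].wait(phases[stage])
        phases[stage] ^= 1
        buffers[stage] = tile  # stand-in for the TMA copy
        consumer_barriers[stage].arrive()

def consumer(num_tiles: int, out: list) -> None:
    phases = [0] * STAGES  # no slot has data yet, matching init=0 above
    for tile in range(num_tiles):
        stage = tile % STAGES
        consumer_barriers[stage].wait(phases[stage])
        phases[stage] ^= 1
        out.append(buffers[stage])  # stand-in for the MMA
        producer_barriers[stage].arrive()

result: list = []
t_producer = threading.Thread(target=producer, args=(8,))
t_consumer = threading.Thread(target=consumer, args=(8, result))
t_producer.start()
t_consumer.start()
t_producer.join()
t_consumer.join()
assert result == list(range(8))
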
8 changes: 6 additions & 2 deletions python/tilus/backends/context.py
@@ -35,6 +35,10 @@ def __init__(self, codegen):
def __post_init__(self):
pass

+ @property
+ def contexts(self):
+ return self.codegen.contexts
+
def host_prepend(self, stmt: Expr | HidetStmt) -> None:
"""Prepend a statement to the host function.

@@ -43,7 +47,7 @@ def host_prepend(self, stmt: Expr | HidetStmt) -> None:
stmt: Expr or HidetStmt
The statement to be prepended.
"""
- self.codegen.host_builder.scope_stack[-1].insert(0, stmt)
+ self.codegen.host_builder.scope_stack[0].insert(0, stmt)

def host_append(self, stmt: Expr | HidetStmt) -> None:
"""Append a statement to the host function.
@@ -63,7 +67,7 @@ def kernel_prepend(self, stmt: Expr | HidetStmt) -> None:
stmt: Expr or HidetStmt
The statement to be prepended.
"""
- self.codegen.builder.scope_stack[-1].insert(0, stmt)
+ self.codegen.builder.scope_stack[0].insert(0, stmt)

def kernel_append(self, stmt: Expr | HidetStmt) -> None:
"""Append a statement to the kernel function.
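
On the scope_stack[0] change above: a toy sketch, assuming (as the builder code suggests) that scope_stack[0] is the outermost function-body scope and scope_stack[-1] the innermost open scope. Prepending at index 0 places a statement ahead of the entire function body, which is what finalize-time setup such as barrier initialization needs; index -1 would only reach the top of whichever scope happens to be open when finalize() runs.

# toy scope stack: index 0 is the function body, -1 the innermost open scope
scope_stack = [["body_stmt"], ["loop_stmt"]]

scope_stack[0].insert(0, "init_barriers")  # lands before the whole body
scope_stack[-1].insert(0, "inner_only")    # lands only at the top of the open scope

print(scope_stack)
# [['init_barriers', 'body_stmt'], ['inner_only', 'loop_stmt']]
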
6 changes: 5 additions & 1 deletion python/tilus/backends/contexts/contexts.py
@@ -16,7 +16,9 @@
from tilus.backends.contexts.global_view_ctx import GlobalTensorViewContext
from tilus.backends.contexts.gmem_alloc_ctx import GlobalMemoryAllocationContext
from tilus.backends.contexts.invariant_ctx import InvariantTrackingContext
+ from tilus.backends.contexts.mbarrier_alloc_ctx import BarrierAllocContext
from tilus.backends.contexts.smem_alloc_ctx import SharedMemoryAllocationContext
+ from tilus.backends.contexts.sync_ctx import SyncContext
from tilus.backends.contexts.tcgen05_ctx import Tcgen05EmitContext


@@ -32,6 +34,8 @@ def __init__(self, codegen):
self.invariant_ctx: InvariantTrackingContext = InvariantTrackingContext(codegen)
self.smem_alloc_ctx: SharedMemoryAllocationContext = SharedMemoryAllocationContext(codegen)
self.tcgen05_ctx: Tcgen05EmitContext = Tcgen05EmitContext(codegen)
+ self.barrier_alloc_ctx: BarrierAllocContext = BarrierAllocContext(codegen)
+ self.sync_ctx: SyncContext = SyncContext(codegen)

def contexts(self) -> list[BaseEmitContext]:
"""Get all contexts as a list.
@@ -56,5 +60,5 @@ def finalize(self):

This method is called when the codegen is finished for all instructions.
"""
- for ctx in self.contexts():
+ for ctx in reversed(self.contexts()):
ctx.finalize()
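
On reversed(): a later context's finalize() may call into an earlier one — BarrierAllocContext.finalize() below requests shared memory from smem_alloc_ctx — so later contexts must finalize first. Combined with kernel_prepend pushing to the front, this also puts the earlier contexts' setup statements first in the emitted kernel, as this minimal sketch shows (the context names are taken from __init__ above; the _setup suffix is hypothetical):

prelude: list[str] = []

def kernel_prepend(stmt: str) -> None:
    prelude.insert(0, stmt)  # mirrors BaseEmitContext.kernel_prepend: push to the front

contexts = ["smem_alloc_ctx", "tcgen05_ctx", "barrier_alloc_ctx", "sync_ctx"]
for name in reversed(contexts):  # finalize later contexts first
    kernel_prepend(name + "_setup")

print(prelude)
# ['smem_alloc_ctx_setup', 'tcgen05_ctx_setup', 'barrier_alloc_ctx_setup', 'sync_ctx_setup']
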
70 changes: 70 additions & 0 deletions python/tilus/backends/contexts/mbarrier_alloc_ctx.py
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Sequence

from hidet.ir.builders import StmtBuilder
from hidet.ir.dtypes import uint32, uint64
from hidet.ir.expr import Expr, Var
from hidet.ir.primitives.cuda.barrier import fence_view_async_shared
from hidet.ir.primitives.cuda.cvta import cvta_generic_to_shared
from hidet.ir.primitives.cuda.smem import dynamic_shared_memory

from tilus.backends.context import BaseEmitContext
from tilus.extensions.hidet.ir.primitives.cuda.mbarrier import mbarrier_init_shared
from tilus.ir.layout import ops
from tilus.ir.tensor import SharedTensor


class BarrierAllocContext(BaseEmitContext):
"""Context used to manage the allocation of barriers."""

def __post_init__(self):
self.counts: list[Expr] = []
self.barriers: list[Var] = []
self.barrier_addr: Var = Var("barriers", type=uint32)

def finalize(self):
# allocate shared memory for all barriers
num_barriers = len(self.counts)

if num_barriers == 0:
# No barriers to allocate
return

tensor = SharedTensor(dtype=uint64, shape=(num_barriers,), optional_layout=ops.shared_row_major(num_barriers))
virtual_smem_addr = self.contexts.smem_alloc_ctx.allocate_shared_tensor(tensor, nbytes=tensor.nbytes)
sb = StmtBuilder()
sb.declare(
v=self.barrier_addr,
init=cvta_generic_to_shared(dynamic_shared_memory(byte_offset=virtual_smem_addr, dtype=uint64)),
)

for i in range(num_barriers):
sb.declare(v=self.barriers[i], init=self.barrier_addr + uint32(i * uint64.nbytes))
sb.append(mbarrier_init_shared(mbarrier_addr=self.barriers[i], arrive_count=uint32(self.counts[i])))
sb.append(fence_view_async_shared())
self.kernel_prepend(sb.finish())

def allocate_barriers(self, counts: Sequence[Expr | int]) -> list[Var]:
"""
Allocate a list of barriers with given counts. Each barrier is a 64-bit data structure stored in shared memory.
This function returns one variable per barrier, each holding that barrier's address in the shared space.
"""
barrier_vars = [Var("barrier_{}".format(c), type=uint32) for c in counts]
self.counts.extend([uint32(c) if isinstance(c, int) else c for c in counts])
self.barriers.extend(barrier_vars)
return barrier_vars
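
A small sketch of the address layout finalize() emits above, assuming only that each mbarrier occupies uint64.nbytes == 8 contiguous bytes from the converted shared-space base; the function name and base value here are hypothetical:

BARRIER_NBYTES = 8  # each mbarrier is one uint64 in shared memory

def barrier_addresses(base: int, num_barriers: int) -> list[int]:
    # mirrors: barriers[i] = barrier_addr + uint32(i * uint64.nbytes)
    return [base + i * BARRIER_NBYTES for i in range(num_barriers)]

print(barrier_addresses(0, 3))  # [0, 8, 16]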