Skip to content

Commit 97c02ff

Browse files
[Gluon] Add Cluster Launch Control (CLC) support for Blackwell GPUs (#9361)
This adds support for NVIDIA's Cluster Launch Control (CLC) feature on Blackwell (SM100+) GPUs, enabling dynamic work distribution for persistent kernels. CLC allows running workers to cancel not-yet-launched clusters and take over their work, improving load balancing when SM availability varies. New Gluon API (triton.experimental.gluon.language.nvidia.blackwell.clc): - try_cancel(result, mbar): Issue async CLC request to cancel a pending cluster - is_canceled(result): Check if cancellation succeeded (returns non-zero) - get_first_ctaid(result, dim): Get the canceled cluster's first CTA ID MLIR ops added: - ttng.clc_try_cancel: Lowers to clusterlaunchcontrol.try_cancel.async PTX - ttng.clc_is_canceled: Lowers to clusterlaunchcontrol.query_cancel.is_canceled - ttng.clc_get_first_ctaid: Lowers to clusterlaunchcontrol.query_cancel.get_first_ctaid All ops include SM100+ compute capability checks and emit errors on older GPUs. Tutorial included demonstrating CLC matmul achieving 92.5% of cuBLAS performance on 8192x8192x8192 FP16 matrices. <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. 
- [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: Peter Bell <peterbell10@openai.com>
1 parent ed4ef36 commit 97c02ff

11 files changed

Lines changed: 885 additions & 1 deletion

File tree

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,99 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
8787
let hasVerifier = 1;
8888
}
8989

90+
//
// Cluster Launch Control (CLC) Ops - Blackwell SM100+
//
def TTNG_CLCTryCancelOp : TTNG_Op<"clc_try_cancel", []> {
  let summary = "Issue CLC try_cancel to cancel a pending cluster";

  let description = [{
    Issues a clusterlaunchcontrol.try_cancel instruction to atomically cancel
    a pending cluster launch. The result is written asynchronously to the
    result buffer and the mbarrier is signaled on completion.

    This is used for dynamic persistent kernels on Blackwell (SM100+).

    The result buffer must be 16-byte aligned shared memory.
    The mbarrier must be 8-byte aligned shared memory.
  }];

  let arguments = (ins
    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$mbarrier,
    I1Attr:$multicast
  );

  let assemblyFormat = [{
    $result `,` $mbarrier attr-dict `:` qualified(type($result)) `,` qualified(type($mbarrier))
  }];
  let hasVerifier = 1;
}

def TTNG_CLCLoadResultOp : TTNG_Op<"clc_load_result", []> {
  let summary = "Load CLC response from shared memory into registers";

  let description = [{
    Loads the 128-bit CLC response from shared memory into two i64 registers.
    This allows subsequent is_canceled and get_first_ctaid operations to
    operate on registers without re-reading shared memory.
  }];

  let arguments = (ins
    Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
  );

  let results = (outs I128:$clcResult);

  let assemblyFormat = [{
    $src attr-dict `:` qualified(type($src)) `->` type($clcResult)
  }];
  let hasVerifier = 1;
}

def TTNG_CLCIsCanceledOp : TTNG_Op<"clc_is_canceled", [Pure]> {
  let summary = "Check if CLC response indicates successful cancellation";

  let description = [{
    Decodes the CLC response to check if a cluster was successfully
    canceled. Returns true if canceled, false otherwise.
  }];

  let arguments = (ins I128:$clcResult);

  let results = (outs I1:$is_canceled);

  let assemblyFormat = [{
    $clcResult attr-dict `:` type($clcResult) `->` type($is_canceled)
  }];
}

def TTNG_CLCGetProgramIdOp : TTNG_Op<"clc_get_program_id", [Pure]> {
  let summary = "Get CTA ID coordinate from CLC response";

  let description = [{
    Decodes the CLC response to get the first CTA ID coordinate of the
    canceled cluster. The dim attribute specifies which dimension (0=x, 1=y, 2=z).
  }];

  let arguments = (ins
    I128:$clcResult,
    TT_ProgramDim:$dim
  );

  let results = (outs I32:$result);

  let assemblyFormat = [{
    $clcResult `,` $dim attr-dict `:` type($clcResult) `->` type($result)
  }];

  let builders = [
    OpBuilder<(ins "Value":$clcResult, "int":$axis), [{
      build($_builder, $_state, clcResult, ProgramIDDim(axis));
    }]>
  ];
}
182+
90183
//
91184
// WarpGroupDot Op
92185
//

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h"
3535
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc"
3636
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
37+
#include "triton/Tools/StrUtil.h"
3738
#include "llvm/Support/Casting.h"
3839
#include "llvm/Support/ErrorHandling.h"
3940

@@ -1152,6 +1153,30 @@ LogicalResult TensormapCreateOp::verify() {
11521153
return success();
11531154
}
11541155

1156+
// -- CLCTryCancelOp --
1157+
static LogicalResult verifyCLCResultMemdesc(Location loc, MemDescType desc) {
1158+
auto int_ty = dyn_cast<IntegerType>(desc.getElementType());
1159+
if (!int_ty || int_ty.getWidth() != 64) {
1160+
return emitError(loc)
1161+
<< "Expected CLC result buffer to have type int64, but got"
1162+
<< desc.getElementType();
1163+
}
1164+
if (desc.getShape().size() != 1 || desc.getShape()[0] != 2) {
1165+
return emitError(loc)
1166+
<< "Expected CLC result buffer to have shape [2], but got ["
1167+
<< triton::join(desc.getShape(), ", ") << "]";
1168+
}
1169+
return success();
1170+
}
1171+
1172+
LogicalResult CLCTryCancelOp::verify() {
1173+
return verifyCLCResultMemdesc(getLoc(), getResult().getType());
1174+
}
1175+
1176+
LogicalResult CLCLoadResultOp::verify() {
1177+
return verifyCLCResultMemdesc(getLoc(), getSrc().getType());
1178+
}
1179+
11551180
} // namespace nvidia_gpu
11561181
} // namespace triton
11571182
} // namespace mlir

python/src/gluon_ir.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,27 @@ void init_gluon_ir(py::module &&m) {
832832
})
833833
.def("create_cluster_wait",
834834
[](GluonOpBuilder &self) { self.create<ttng::ClusterWaitOp>(); })
835+
// CLC (Cluster Launch Control) ops - SM100+
836+
.def("create_clc_try_cancel",
837+
[](GluonOpBuilder &self, Value result, Value mbarrier,
838+
bool multicast) {
839+
self.create<ttng::CLCTryCancelOp>(result, mbarrier, multicast);
840+
})
841+
.def("create_clc_load_result",
842+
[](GluonOpBuilder &self, Value result) -> Value {
843+
auto i64Ty = self.getBuilder().getI64Type();
844+
return self.create<ttng::CLCLoadResultOp>(result);
845+
})
846+
.def("create_clc_is_canceled",
847+
[](GluonOpBuilder &self, Value clcResult) -> Value {
848+
auto i1Ty = self.getBuilder().getI1Type();
849+
return self.create<ttng::CLCIsCanceledOp>(clcResult);
850+
})
851+
.def("create_clc_get_program_id",
852+
[](GluonOpBuilder &self, Value clcResult, int dim) -> Value {
853+
auto i32Ty = self.getBuilder().getI32Type();
854+
return self.create<ttng::CLCGetProgramIdOp>(clcResult, dim);
855+
})
835856
.def("create_tcgen05_mma",
836857
[](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
837858
Value pred, std::vector<Value> &mbarriers,

python/src/ir.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,10 @@ void init_triton_ir(py::module &&m) {
10081008
[](TritonOpBuilder &self) -> Type {
10091009
return self.getBuilder().getI64Type();
10101010
})
1011+
.def("get_int128_ty",
1012+
[](TritonOpBuilder &self) -> Type {
1013+
return self.getBuilder().getIntegerType(128);
1014+
})
10111015
.def("get_fp8e4nv_ty",
10121016
[](TritonOpBuilder &self) -> Type {
10131017
return self.getBuilder().getType<Float8E4M3FNType>();

python/test/gluon/test_core.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
tcgen05_commit,
4141
tcgen05_copy,
4242
float2,
43+
clc,
4344
)
4445
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
4546

@@ -3421,3 +3422,50 @@ def test_tmem_reduction(red_op, use_abs, propagate_nan, M, N, num_warps):
34213422
# Verify reduction output
34223423
# Use equal_nan=True when testing NaN propagation
34233424
torch.testing.assert_close(expected_red, red_output, atol=1e-5, rtol=1e-5, equal_nan=use_nan)
3425+
3426+
3427+
@pytest.mark.parametrize("num_ctas", [1, 2])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_clc_basic(num_ctas):
    """Launch an oversubscribed grid and check CLC cancellation accounting.

    Every program either actually launches (WasLaunched) or is canceled by a
    running worker (IsCancelled); the two sets must exactly partition the grid.
    """

    @gluon.jit
    def clc_kernel(WasLaunched, IsCancelled, ProgramId, smem_size: ttgl.constexpr):
        # Large shared memory allocation to force 1 block per SM
        cga_layout: ttgl.constexpr = [[0]] if ttgl.num_ctas() == 2 else []
        layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout)
        dummy = ttgl.allocate_shared_memory(ttgl.int64, [smem_size // 8 - 32], layout)

        clc_result = ttgl.allocate_shared_memory(ttgl.int64, [2], layout)
        clc_mbar = mbarrier.allocate_mbarrier()
        mbarrier.init(clc_mbar, count=1)

        clc.try_cancel(clc_result, clc_mbar, multicast=True)
        # The CLC response is 16 bytes; wait for the async write to land.
        mbarrier.expect(clc_mbar, 16)
        mbarrier.wait(clc_mbar, 0)

        response = clc.load_result(clc_result)
        pid = ttgl.program_id(0)
        ttgl.store(WasLaunched + pid, True)
        ttgl.store(IsCancelled + pid, response.is_canceled())
        ttgl.store(ProgramId + pid, response.program_id(0))
        dummy._keep_alive()

    dev_props = torch.cuda.get_device_properties("cuda")
    num_sms = dev_props.multi_processor_count
    smem_size = dev_props.shared_memory_per_block_optin // num_ctas
    # 2x oversubscription so some clusters are guaranteed to be cancelable.
    grid = 2 * (num_sms // num_ctas)

    was_launched = torch.zeros([grid], dtype=torch.bool, device="cuda")
    is_cancelled = torch.zeros([grid], dtype=torch.bool, device="cuda")
    program_ids = torch.zeros([grid], dtype=torch.int32, device="cuda")
    clc_kernel[(grid, )](was_launched, is_cancelled, program_ids, smem_size, num_ctas=num_ctas)

    num_launched = torch.sum(was_launched).item()
    assert num_launched < grid

    num_cancelled = torch.sum(is_cancelled).item()
    assert num_launched + num_cancelled == grid

    # A canceled program's work must have been claimed by a worker that did
    # not itself launch under that program id.
    for pid in range(grid):
        if is_cancelled[pid]:
            assert not was_launched[program_ids[pid]]

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from triton.experimental.gluon.language._semantic import _check, _compute_tmem_reg_layout
99

1010
from . import tma
11+
from . import clc
1112
from ..hopper import fence_async_shared, mbarrier
1213
from ..ampere import async_copy, mma_v2
1314

@@ -20,6 +21,7 @@
2021
__all__ = [
2122
"allocate_tensor_memory",
2223
"async_copy",
24+
"clc",
2325
"fence_async_shared",
2426
"get_tmem_reg_layout",
2527
"mbarrier",
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Cluster Launch Control (CLC) for Blackwell (SM100+) dynamic persistent kernels.
3+
4+
CLC enables hardware-based dynamic work scheduling where running workers can
5+
cancel not-yet-launched clusters and take over their work via the
6+
clusterlaunchcontrol.try_cancel instruction.
7+
"""
8+
from __future__ import annotations
9+
10+
import triton.experimental.gluon.language._core as gl
11+
from triton.experimental.gluon.language._core import builtin, tensor, shared_memory_descriptor, base_value, base_type
12+
from typing import TYPE_CHECKING, List, Tuple
13+
14+
if TYPE_CHECKING:
15+
from triton._C.libtriton.gluon_ir import GluonOpBuilder
16+
from triton._C.libtriton import ir
17+
18+
__all__ = [
19+
"try_cancel",
20+
"load_result",
21+
"clc_result",
22+
]
23+
24+
25+
@builtin
def try_cancel(result: shared_memory_descriptor, barrier, multicast=False, _semantic=None):
    """
    Issue a CLC try_cancel request to atomically cancel a pending cluster launch.

    The response is written asynchronously into ``result`` and the mbarrier is
    signaled on completion; callers must wait on the barrier before reading.

    Args:
        result (shared_memory_descriptor): 16-byte aligned shared memory for the response
        barrier (shared_memory_descriptor): 8-byte aligned mbarrier for completion signaling
        multicast (bool): If True, broadcast result to all CTAs in cluster

    Only supported on SM100+ (Blackwell).
    """
    _semantic.builder.create_clc_try_cancel(result.handle, barrier.handle, multicast)
38+
39+
40+
@builtin
def load_result(src, _semantic=None):
    """
    Load the CLC response from shared memory into registers.

    Args:
        src (shared_memory_descriptor): The CLC response buffer

    Returns:
        clc_result: Object with is_canceled() and program_id(dim) methods
    """
    handle = _semantic.builder.create_clc_load_result(src.handle)
    return clc_result(handle)
53+
54+
55+
class clc_result_type(base_type):
    """Frontend type for a register-resident CLC response (lowered to i128)."""

    def to_ir(self, builder: GluonOpBuilder) -> None:
        return builder.get_int128_ty()

    # Fixed annotation: this unflattens to a clc_result, not a
    # shared_memory_descriptor (copy-paste error in the original).
    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[clc_result, int]:
        value = clc_result(handles[cursor])
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return "clc_result"

    def __eq__(self, other) -> bool:
        return type(self) is type(other)

    def mangle(self) -> str:
        return "CLC"
75+
76+
77+
class clc_result(base_value):
    """CLC response loaded into registers. Query without re-reading memory."""

    def __init__(self, handle):
        self.handle = handle
        self.type = clc_result_type()

    # Annotation unified with clc_result_type._unflatten_ir (ir.Value, not
    # the lowercase ir.value used in the original).
    def _flatten_ir(self, handles: List[ir.Value]) -> None:
        handles.append(self.handle)

    def _set_name(self, builder: ir.builder, name: str) -> None:
        self.handle.set_loc(builder.create_name_loc(name, self.handle.get_loc()))

    @builtin
    def is_canceled(self, _semantic=None):
        """
        Check if the CLC response indicates a successful cancellation.

        Returns:
            tensor: True if a cluster was successfully canceled, False otherwise
        """
        handle = _semantic.builder.create_clc_is_canceled(self.handle)
        return tensor(handle, gl.int1)

    @builtin
    def program_id(self, dim, _semantic=None):
        """
        Get the Program ID of the canceled cluster.

        Args:
            dim (int): Dimension to get (0=x, 1=y, 2=z)

        Returns:
            tensor: The Program ID for the specified dimension
        """
        handle = _semantic.builder.create_clc_get_program_id(self.handle, dim)
        return tensor(handle, gl.int32)

python/tutorials/gluon/07-persistence.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,8 @@ def test_persistent_matmul_pipelined(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_buf
834834
# Hopper and Blackwell: we are not double-buffering the accumulator and
835835
# leaving 256 columns of TMEM unused.
836836
# - On Blackwell, we can use `clusterlaunchcontrol` to dynamically schedule
837-
# work in conjunction with the GPU, getting the best of both worlds.
837+
# work in conjunction with the GPU, getting the best of both worlds. This is
838+
# explored further in tutorial 12.
838839
#
839840
# Main takeaways:
840841
#

0 commit comments

Comments
 (0)