|
| 1 | +""" |
| 2 | +Cluster Launch Control (CLC) for Blackwell (SM100+) dynamic persistent kernels. |
| 3 | +
|
| 4 | +CLC enables hardware-based dynamic work scheduling where running workers can |
| 5 | +cancel not-yet-launched clusters and take over their work via the |
| 6 | +clusterlaunchcontrol.try_cancel instruction. |
| 7 | +""" |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import triton.experimental.gluon.language._core as gl |
| 11 | +from triton.experimental.gluon.language._core import builtin, tensor, shared_memory_descriptor, base_value, base_type |
| 12 | +from typing import TYPE_CHECKING, List, Tuple |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from triton._C.libtriton.gluon_ir import GluonOpBuilder |
| 16 | + from triton._C.libtriton import ir |
| 17 | + |
| 18 | +__all__ = [ |
| 19 | + "try_cancel", |
| 20 | + "load_result", |
| 21 | + "clc_result", |
| 22 | +] |
| 23 | + |
| 24 | + |
@builtin
def try_cancel(result: shared_memory_descriptor, barrier, multicast=False, _semantic=None):
    """
    Emit a CLC try_cancel request that atomically cancels a pending cluster launch.

    Args:
        result (shared_memory_descriptor): 16-byte aligned shared memory that receives the response
        barrier (shared_memory_descriptor): 8-byte aligned mbarrier used for completion signaling
        multicast (bool): If True, broadcast the result to all CTAs in the cluster

    Nothing is returned; the request is emitted through the builder.

    Only supported on SM100+ (Blackwell).
    """
    # Resolve the builder entry point first, then the operand handles.
    emit_try_cancel = _semantic.builder.create_clc_try_cancel
    response_handle = result.handle
    barrier_handle = barrier.handle
    emit_try_cancel(response_handle, barrier_handle, multicast)
| 38 | + |
| 39 | + |
@builtin
def load_result(src: shared_memory_descriptor, _semantic=None) -> clc_result:
    """
    Load the CLC response from shared memory into registers.

    Args:
        src (shared_memory_descriptor): The CLC response buffer

    Returns:
        clc_result: Object with is_canceled() and program_id(dim) methods
    """
    handle = _semantic.builder.create_clc_load_result(src.handle)
    # Wrap the raw IR value so callers can query it via the clc_result methods.
    return clc_result(handle)
| 53 | + |
| 54 | + |
class clc_result_type(base_type):
    """
    Type of a `clc_result` value.

    The type carries no parameters: every instance compares equal to every
    other instance of the same class, and it lowers to the IR type returned
    by `builder.get_int128_ty()`.
    """

    def to_ir(self, builder: GluonOpBuilder) -> ir.type:
        """Return the IR type used to represent a CLC result."""
        return builder.get_int128_ty()

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[clc_result, int]:
        """Rebuild a `clc_result` from `handles[cursor]` and return it with the advanced cursor."""
        value = clc_result(handles[cursor])
        # A CLC result occupies exactly one IR value.
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        """Append this type's single IR type to `out`."""
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return "clc_result"

    def __eq__(self, other) -> bool:
        return type(self) is type(other)

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None; keep the type hashable
        # and consistent with the class-identity equality above.
        return hash(type(self))

    def mangle(self) -> str:
        """Return the fixed tag used for this type in mangled names."""
        return "CLC"
| 75 | + |
| 76 | + |
class clc_result(base_value):
    """A CLC response held in registers, so it can be queried without re-reading memory."""

    def __init__(self, handle):
        # `handle` is the IR value produced for the loaded response.
        self.type = clc_result_type()
        self.handle = handle

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append this value's single IR handle to `handles`."""
        ir_value = self.handle
        handles.append(ir_value)

    def _set_name(self, builder: ir.builder, name: str) -> None:
        """Replace the handle's location with a name location wrapping the current one."""
        current_loc = self.handle.get_loc()
        named_loc = builder.create_name_loc(name, current_loc)
        self.handle.set_loc(named_loc)

    @builtin
    def is_canceled(self, _semantic=None):
        """
        Report whether the CLC response indicates a successful cancellation.

        Returns:
            tensor: int1 value, True if a cluster was successfully canceled, False otherwise
        """
        builder = _semantic.builder
        canceled = builder.create_clc_is_canceled(self.handle)
        return tensor(canceled, gl.int1)

    @builtin
    def program_id(self, dim, _semantic=None):
        """
        Return the Program ID of the canceled cluster along one dimension.

        Args:
            dim (int): Dimension to get (0=x, 1=y, 2=z)

        Returns:
            tensor: int32 Program ID for the specified dimension
        """
        builder = _semantic.builder
        pid = builder.create_clc_get_program_id(self.handle, dim)
        return tensor(pid, gl.int32)
0 commit comments