diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml
index 4ea3efda67..73a2c73392 100644
--- a/.github/workflows/ascend-build-and-test.yml
+++ b/.github/workflows/ascend-build-and-test.yml
@@ -64,6 +64,10 @@ jobs: python3 14-accuracy-comparison.py #python3 15-embedding_gather_demo.py popd + # hint tests + pushd third_party/ascend/tutorials/hint + python3 test_comment_hint.py + popd # pytest_ut pushd third_party/ascend/unittest/pytest_ut python3 -m pytest . \
diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py
new file mode 100644
index 0000000000..e7860cf64d
--- /dev/null
+++ b/python/triton/compiler/hint_manager.py
@@ -0,0 +1,184 @@
+"""Backend-specific compiler hint dispatch for FlagTree.
+
+The compiler calls :func:`hint_trigger` with a hook name, and the call is
+routed to the hint handler of the detected backend. If no handler could be
+loaded, or the handler does not define the hook, the call returns ``None``.
+"""
+import importlib
+import inspect
+import sys
+
+
+def _report_call_mismatch(handler, hook_name, method, args, kwargs, error):
+    """Print a diagnostic to stderr when ``method`` cannot accept the call.
+
+    Nothing is printed if the arguments bind to the signature: the
+    ``TypeError`` was then raised inside the hook body, and blaming the call
+    site would be misleading.
+    """
+    try:
+        signature = inspect.signature(method)
+    except (TypeError, ValueError):
+        # Not every callable exposes an introspectable signature.
+        expected = "(unknown)"
+    else:
+        expected = str(signature)
+        try:
+            signature.bind(*args, **kwargs)
+        except TypeError:
+            pass  # genuine call-site mismatch, reported below
+        else:
+            return
+
+    actual_args = f"{len(args)} positional"
+    actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"
+
+    print(f"\n[Hint Trigger Mismatch] {handler.__class__.__name__}.{hook_name}", file=sys.stderr)
+    print(f" > Expect : {expected}", file=sys.stderr)
+    print(f" > Actual : {actual_args}, {actual_kwargs}", file=sys.stderr)
+    print(f" > Reason : {error}\n", file=sys.stderr)
+
+
+class BaseHintHandler:
+    """No-op handler; backend handlers subclass it and add hook methods."""
+
+    def trigger(self, hook_name, *args, **kwargs):
+        """Call the hook named ``hook_name`` if this handler defines it.
+
+        Returns the hook's result, or ``None`` when the attribute is missing
+        or not callable. A ``TypeError`` from the call is re-raised; if the
+        arguments do not fit the hook's signature a diagnostic is printed first.
+        """
+        # Hooks are found dynamically by name.
+        method = getattr(self, hook_name, None)
+        if not callable(method):
+            return None
+        try:
+            return method(*args, **kwargs)
+        except TypeError as e:
+            _report_call_mismatch(self, hook_name, method, args, kwargs, e)
+            raise
+
+
+# backend name -> (handler module, handler class, label used in the warning)
+_HANDLER_REGISTRY = {
+    "npu": ("triton.backends.ascend.ascend_hint_handler", "AscendHintHandler", "Ascend"),
+    "aipu": ("triton.backends.aipu.aipu_hint_handler", "AipuHintHandler", "aipu"),
+    "cuda": ("triton.backends.nvidia.nvidia_hint_handler", "NvidiaHintHandler", "Nvidia"),
+}
+
+
+class HintManager:
+    """Holds the hint handler selected for one backend name."""
+
+    def __init__(self, backend_name):
+        self.backend_name = backend_name
+        # load Handler with backend name
+        self.handler = self._load_handler(backend_name)
+
+    def _load_handler(self, backend):
+        """Return the handler for ``backend``, else a no-op ``BaseHintHandler``.
+
+        An unknown backend gets the base handler silently; a known backend
+        whose handler raises ``ImportError`` gets it after a stderr warning.
+        """
+        entry = _HANDLER_REGISTRY.get(backend)
+        if entry is None:
+            return BaseHintHandler()
+        module_name, class_name, label = entry
+        try:
+            module = importlib.import_module(module_name)
+            return getattr(module, class_name)()
+        except ImportError as e:
+            print(f"[FlagTree] Warning: Failed to load {label} Hint Handler: {e}", file=sys.stderr)
+            return BaseHintHandler()
+
+
+# supported backend with matched version
+SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"]
+
+# Order in which ``torch.<name>.is_available()`` is probed as a fallback.
+_TORCH_PROBE_ORDER = ("aipu", "npu", "cuda")
+
+# TODO : npu will have conflicts if more backend involved
+# mapping name
+BACKEND_ALIASES = {
+    "ascend": "npu",
+    "huawei": "npu",
+    "nvidia": "cuda",
+}
+
+
+def normalize_backend_name(name: str) -> str:
+    """Lower-case ``name`` and map a known alias to its canonical name.
+
+    A falsy ``name`` yields ``""``.
+    """
+    if not name:
+        return ""
+    name = name.lower()
+    return BACKEND_ALIASES.get(name, name)
+
+
+def hint_get_flagtree_backend() -> str:
+    """Detect the active backend; return its canonical name or ``""``.
+
+    The active Triton driver is asked first, then ``torch.<backend>`` is
+    probed. Detection is best effort: a missing ``torch`` or an unusable
+    driver yields "not detected" instead of an exception, so that
+    ``hint_trigger`` degrades to a no-op rather than aborting the caller.
+    """
+    try:
+        import torch
+    except ImportError:
+        torch = None
+
+    detected_backend = ""
+
+    # Priority 1: Triton Driver
+    try:
+        from triton.runtime import driver
+        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
+            device = driver.active.get_active_torch_device()
+            if torch is not None and isinstance(device, torch.device):
+                detected_backend = device.type
+            # unimplemented support
+            elif isinstance(device, str):
+                detected_backend = device
+    except (ImportError, RuntimeError):
+        # NOTE(review): RuntimeError is presumably what the driver raises when it
+        # cannot select a single active backend -- confirm. Fall through to priority 2.
+        pass
+
+    # TODO : some backend may not support priority 1, so keep priority 2 is necessary
+    # Priority 2: Torch Global State
+    if not detected_backend and torch is not None:
+        for candidate in _TORCH_PROBE_ORDER:
+            module = getattr(torch, candidate, None)
+            if module and hasattr(module, "is_available") and module.is_available():
+                detected_backend = candidate
+                break
+
+    # (Normalization and Validation)
+    canonical_backend = normalize_backend_name(detected_backend)
+
+    if canonical_backend not in SUPPORTED_BACKENDS:
+        return ""
+
+    return canonical_backend
+
+
+# lazy load after first call hint trigger
+_global_hint_manager = None
+
+
+def hint_trigger(hook_name, *args, **kwargs):
+    """Dispatch ``hook_name`` to the handler of the detected backend.
+
+    The module-wide ``HintManager`` is created on the first call.
+    """
+    global _global_hint_manager
+
+    if _global_hint_manager is None:
+        _global_hint_manager = HintManager(hint_get_flagtree_backend())
+    return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)
diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py
new file mode 100644
index 0000000000..65e492c6ca
--- /dev/null
+++ b/third_party/ascend/backend/ascend_hint_handler.py
@@ -0,0 +1,95 @@
+# Ascend implementation of the FlagTree hint hooks.
+# Stored at third_party/ascend/backend/; loaded by name from triton.compiler.hint_manager.
+from triton.compiler.hint_manager import BaseHintHandler
+import triton.language as language
+import ast
+from triton.compiler.code_generator import _is_triton_value
+
+
+class AscendHintHandler(BaseHintHandler):
+    """Hooks that turn ``#@hint:`` source comments into IR attributes.
+
+    All hooks are static methods invoked by name via ``BaseHintHandler.trigger``.
+    """
+
+    @staticmethod
+    def _hints_at(code_generator, node):
+        """Return the hint text recorded for ``node.lineno``, or ``None``.
+
+        ``None`` is also returned when ``node`` has no ``lineno`` or
+        ``code_generator`` has no ``jit_fn``.
+        """
+        if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')):
+            return None
+        # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
+        function_def = code_generator.jit_fn.parse()
+        line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+        return line_flagtree_hints.get(node.lineno)
+
+    @staticmethod
+    def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
+        """Annotate values assigned from a ``tl.load`` line hinted ``dot_pad_only_k``.
+
+        For every Triton value in ``values`` a ``dot_pad_only_k`` unit-attribute
+        annotation is created through ``code_generator.builder``.
+        """
+        # flagtree: After normal processing, check if we need to add hint annotation
+        flagtree_hints = AscendHintHandler._hints_at(code_generator, node)
+        if not flagtree_hints or 'dot_pad_only_k' not in flagtree_hints:
+            return
+
+        # Only a call spelled literally ``tl.load(...)`` qualifies.
+        call = node.value
+        is_tl_load = (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)
+                      and isinstance(call.func.value, ast.Name) and call.func.value.id == 'tl'
+                      and call.func.attr == 'load')
+        if not is_tl_load:
+            return
+
+        # Add hint annotation to the loaded tensor(s)
+        for _name, value in zip(names, values):
+            if _is_triton_value(value):
+                hint_val = code_generator.builder.get_unit_attr()
+                code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
+
+    @staticmethod
+    def check_override_bind_sub_block(code_generator, node, bind_sub_block):
+        """Return ``True`` if ``node``'s line carries a ``bind_sub_block`` hint.
+
+        Otherwise the incoming ``bind_sub_block`` is returned unchanged.
+        """
+        # flagtree: After normal processing, check if we need to override bind_sub_block
+        flagtree_hints = AscendHintHandler._hints_at(code_generator, node)
+        if flagtree_hints and 'bind_sub_block' in flagtree_hints:
+            return True
+        return bind_sub_block
+
+    @staticmethod
+    def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
+        """Store ``bind_sub_block`` on ``for_op`` as a bool attribute of that name."""
+        for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))
+
+    @staticmethod
+    def maps_line_numbers_to_comment_hints(jit_fn):
+        """Map line numbers of ``jit_fn.src`` to the text following ``#@hint:``.
+
+        Spaces are removed from each comment before matching, so
+        ``# @hint: name`` is recognised too.
+        """
+        import tokenize
+        from io import StringIO
+
+        line_flagtree_hints = {}
+        for token in tokenize.generate_tokens(StringIO(jit_fn.src).readline):
+            if token.type != tokenize.COMMENT:
+                continue
+            comment = token.string.replace(" ", "").strip()
+            if comment.startswith('#@hint:'):
+                # Record the hint under the line number of the comment
+                line_flagtree_hints[token.start[0]] = comment[len('#@hint:'):].strip()
+        return line_flagtree_hints
+
+    @staticmethod
+    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
+        """Attach the line-number-to-hint mapping to the function definition node."""
+        tree.body[0].line_flagtree_hints = line_flagtree_hints
diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py
index 2747e689cf..4af02c6173 100644
--- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py
+++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py
@@ -23,6 +23,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +from .hint_manager import hint_trigger # Central registry for all 'with' statement handlers WITH_DISPATCH = {}
@@ -548,6 +549,9 @@ def visit_Assign(self, node): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) + # switch into hintmanager + hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values) + def visit_AugAssign(self, node): name = node.target.id lhs = ast.Name(id=name, ctx=ast.Load())
@@ -997,6 +1001,11 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # hint manager + new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block) + if new_bind_sub_block is not None: + bind_sub_block = new_bind_sub_block + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0:
@@ 
-1072,6 +1081,9 @@ def visit_For(self, node): tle = importlib.import_module("triton.experimental.tle", package=__package__) if (IteratorClass is extension.parallel or IteratorClass is tle.dsa.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) + # hint manager + if bind_sub_block: + hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0))
diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py b/third_party/ascend/backend/spec/triton/runtime/jit.py
index 45178a40bb..da8ba230eb 100644
--- a/third_party/ascend/backend/spec/triton/runtime/jit.py
+++ b/third_party/ascend/backend/spec/triton/runtime/jit.py
@@ -756,10 +756,20 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. def parse(self): + # hint manager + # after removing flagtree backend specialization, hiding the implementation into hintmanager + from ..compiler.hint_manager import hint_trigger + line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self) + tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) + + # hint manager + # Attach the line number to comment mapping to the function definition node + hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints) + return tree def __call__(self, *args, **kwargs):
diff --git a/third_party/ascend/tutorials/hint/test_comment_hint.py b/third_party/ascend/tutorials/hint/test_comment_hint.py
new file mode 100644
index 0000000000..ecbc9cad89
--- /dev/null
+++ b/third_party/ascend/tutorials/hint/test_comment_hint.py
@@ -0,0 +1,176 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+Comment Hint Test
+=================
+
+Tests the #@hint: comment annotation mechanism for the Ascend backend.
+
+This verifies that:
+1. #@hint:dot_pad_only_k on tl.load lines generates AnnotationOp with dot_pad_only_k attr in TTIR
+2. #@hint:bind_sub_block on for loops generates bind_sub_block attr on scf.for in TTIR
+"""
+
+import triton
+import triton.language as tl
+from triton._C.libtriton import ir, ascend
+from triton._C.libtriton.ascend import ir as ascend_ir
+from triton.backends.ascend.compiler import AscendBackend, NPUOptions, min_dot_size
+from triton.backends.compiler import GPUTarget
+
+
+# ---------------------------------------------------------------------------
+# Kernel with #@hint:dot_pad_only_k on tl.load
+# and #@hint:bind_sub_block on the K loop. Each hint is a trailing comment on
+# the line of the statement it annotates.
+# ---------------------------------------------------------------------------
+@triton.jit
+def matmul_hint_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, BLOCK_K)
+
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_K)):  #@hint:bind_sub_block
+        k_mask = offs_k < K - k * BLOCK_K
+        a = tl.load(a_ptrs, mask=offs_m[:, None] < M and k_mask[None, :], other=0.0)  #@hint:dot_pad_only_k
+        b = tl.load(b_ptrs, mask=k_mask[:, None] and offs_n[None, :] < N, other=0.0)  #@hint:dot_pad_only_k
+        acc = tl.dot(a, b, acc)
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+
+    c = acc.to(tl.float16)
+    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# ---------------------------------------------------------------------------
+# Helper: compile kernel to TTIR string using the full backend pipeline
+# ---------------------------------------------------------------------------
+def get_ttir_str(kernel_fn, signature, constants):
+    """Build the IR module for ``kernel_fn`` with ``ASTSource.make_ir`` and return it as a string."""
+    # Use the ascend backend's compile flow which properly invokes
+    # the spec ASTSource.make_ir -> ascend ast_to_ttir with hint support
+    from triton.compiler.compiler import ASTSource
+
+    src = ASTSource(kernel_fn, signature, constants)
+    context = ir.context()
+    ir.load_dialects(context)
+    ascend_ir.load_dialects(context)
+    ascend.load_dialects(context)
+
+    options = NPUOptions()
+    target = GPUTarget("npu", options.arch, 64)
+    backend = AscendBackend(target)
+    backend.load_dialects(context)
+    codegen_fns = backend.get_codegen_implementation()
+    module_map = backend.get_module_map()
+
+    module = src.make_ir(options, codegen_fns, module_map, context)
+    return str(module)
+
+
+# ---------------------------------------------------------------------------
+# Test 1: Verify IR contains hint annotations
+# ---------------------------------------------------------------------------
+def test_ir_hint_annotations():
+    """Assert that both hint names appear in the IR generated for ``matmul_hint_kernel``."""
+    print("=" * 60)
+    print("Test 1: Verify IR hint annotations")
+    print("=" * 60)
+
+    signature = {
+        "a_ptr": "*fp16",
+        "b_ptr": "*fp16",
+        "c_ptr": "*fp16",
+        "M": "i32",
+        "N": "i32",
+        "K": "i32",
+        "stride_am": "i32",
+        "stride_ak": "i32",
+        "stride_bk": "i32",
+        "stride_bn": "i32",
+        "stride_cm": "i32",
+        "stride_cn": "i32",
+    }
+    constants = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}
+
+    ttir_str = get_ttir_str(matmul_hint_kernel, signature, constants)
+
+    # Check for dot_pad_only_k annotation in IR
+    has_dot_pad = "dot_pad_only_k" in ttir_str
+    # Check for bind_sub_block attribute on for op in IR
+    has_bind_sub = "bind_sub_block" in ttir_str
+
+    print(f" dot_pad_only_k found in IR: {has_dot_pad}")
+    print(f" bind_sub_block found in IR: {has_bind_sub}")
+
+    if has_dot_pad:
+        print(" [PASS] dot_pad_only_k hint correctly attached to IR")
+    else:
+        print(" [WARN] dot_pad_only_k not found in IR - hint may not have been processed")
+
+    if has_bind_sub:
+        print(" [PASS] bind_sub_block hint correctly attached to IR")
+    else:
+        print(" [WARN] bind_sub_block not found in IR - hint may not have been processed")
+
+    # Print a snippet of the IR for debugging
+    print("\n--- TTIR snippet (first 2000 chars) ---")
+    print(ttir_str[:2000])
+    print("--- end ---\n")
+
+    assert has_dot_pad, "dot_pad_only_k annotation not found in generated TTIR"
+    assert has_bind_sub, "bind_sub_block attribute not found in generated TTIR"
+    print(" [PASS] All IR hint checks passed\n")
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    test_ir_hint_annotations()
+    print("All comment hint tests passed!")