Skip to content

Commit 10f748f

Browse files
starrryz and motinwing authored
[HINT] Add Triton v3.3.x hint manager (#348)
* apply hint_manager in 3.3 aipu * update func attrs * fix code format problems * update hintmanager * fix hint manager import error --------- Co-authored-by: motinwing <858391970@qq.com>
1 parent 48f177a commit 10f748f

4 files changed

Lines changed: 196 additions & 24 deletions

File tree

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# ideally we wouldn't need any runtime component
1717
from ..runtime import JITFunction
1818
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
19+
from .hint_manager import hint_trigger
1920

2021
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
2122

@@ -1243,10 +1244,7 @@ def visit_Call(self, node):
12431244
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
12441245

12451246
# 4. Get current line number and hints
1246-
line_num = node.lineno
1247-
function_def = self.jit_fn.parse()
1248-
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
1249-
flagtree_hints = line_flagtree_hints.get(line_num)
1247+
flagtree_hints = hint_trigger("get_node_hints", self, node)
12501248

12511249
# 5. Handle JIT function calls
12521250
if isinstance(fn, JITFunction):
@@ -1261,12 +1259,7 @@ def visit_Call(self, node):
12611259
extra_kwargs['_generator'] = self
12621260
try:
12631261
# Special handling for tl.load with hints
1264-
if fn.__name__ == "load" and flagtree_hints is not None:
1265-
print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}")
1266-
if 'flagtree_hints' not in kws:
1267-
kws['flagtree_hints'] = ""
1268-
if flagtree_hints not in kws['flagtree_hints']:
1269-
kws['flagtree_hints'] = flagtree_hints
1262+
hint_trigger("inject_kwargs_with_hints", fn, flagtree_hints, node.lineno, kws)
12701263

12711264
ret = fn(*args, **extra_kwargs, **kws)
12721265
# builtin functions return plain tuples for readability
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import sys
2+
import importlib
3+
4+
5+
class BaseHintHandler:
    """Default hint handler; hooks are resolved dynamically by name.

    A subclass exposes a hook simply by defining a method with the hook's
    name. Hooks that the handler does not define resolve to ``None``.
    """

    def trigger(self, hook_name, *args, **kwargs):
        """Call the hook named ``hook_name`` if this handler provides it.

        Args:
            hook_name: Attribute name of the hook to look up on this handler.
            *args: Positional arguments forwarded to the hook.
            **kwargs: Keyword arguments forwarded to the hook.

        Returns:
            The hook's return value, or ``None`` when the handler has no
            callable attribute named ``hook_name``.

        Raises:
            TypeError: Re-raised unchanged from the hook call, after a
                diagnostic report has been written to stderr.
        """
        method = getattr(self, hook_name, None)
        if not callable(method):
            return None

        try:
            return method(*args, **kwargs)
        except TypeError as e:
            # A TypeError here may be a call-signature mismatch, so report
            # the expected signature next to what was actually passed.
            import inspect

            try:
                expected = str(inspect.signature(method))
            except Exception:
                expected = "(unknown)"

            actual_args = f"{len(args)} positional"
            actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"

            # Diagnostics go to stderr, like the other warnings in this module.
            print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}", file=sys.stderr)
            print(f"  > Expect : {expected}", file=sys.stderr)
            print(f"  > Actual : {actual_args}, {actual_kwargs}", file=sys.stderr)
            print(f"  > Reason : {e}\n", file=sys.stderr)

            # Bare raise keeps the original traceback intact.
            raise
33+
34+
35+
class HintManager:
    """Holds the hint handler selected for a given backend name."""

    def __init__(self, backend_name):
        """Remember ``backend_name`` and load the handler that matches it."""
        self.backend_name = backend_name
        self.handler = self._load_handler(backend_name)

    def _load_handler(self, backend):
        """Return the handler instance for ``backend``.

        Falls back to a plain ``BaseHintHandler`` when the backend is not
        one of the known names, or when importing or constructing its
        handler raises ``ImportError`` (a warning is printed to stderr in
        that case).
        """
        # (backend name, handler module, handler class, label used in the warning)
        known_handlers = (
            ('npu', "triton.backends.ascend.ascend_hint_handler", "AscendHintHandler", "Ascend"),
            ('aipu', "triton.backends.aipu.aipu_hint_handler", "AipuHintHandler", "aipu"),
            ('cuda', "triton.backends.nvidia.nvidia_hint_handler", "NvidiaHintHandler", "Nvidia"),
        )

        for key, module_path, class_name, label in known_handlers:
            if backend != key:
                continue
            try:
                handler_module = importlib.import_module(module_path)
                return getattr(handler_module, class_name)()
            except ImportError as e:
                print(f"[FlagTree] Warning: Failed to load {label} Hint Handler: {e}", file=sys.stderr)
                return BaseHintHandler()

        return BaseHintHandler()
66+
67+
68+
# Canonical names of the supported backends (with matched version).
69+
SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"]
70+
71+
# TODO: "npu" may cause conflicts once more backends are involved.
72+
# Maps alternative backend names to their canonical backend name.
73+
BACKEND_ALIASES = {
74+
"ascend": "npu",
75+
"huawei": "npu",
76+
"nvidia": "cuda",
77+
}
78+
79+
80+
def normalize_backend_name(name: str) -> str:
    """Lower-case ``name`` and translate it through ``BACKEND_ALIASES``.

    A falsy ``name`` yields the empty string. Names without an alias entry
    are returned lower-cased and otherwise unchanged.
    """
    if name:
        lowered = name.lower()
        return BACKEND_ALIASES.get(lowered, lowered)
    return ""
85+
86+
87+
def hint_get_flagtree_backend() -> str:
    """Detect the active backend and return its canonical name.

    Detection order:
      1. The device reported by the active Triton driver, when the driver
         exposes ``get_active_torch_device``.
      2. The first of ``torch.aipu`` / ``torch.npu`` / ``torch.cuda`` whose
         ``is_available()`` returns true.

    Returns:
        A name from ``SUPPORTED_BACKENDS``, or ``""`` when no supported
        backend could be detected.
    """
    # torch may not be installed; backend detection is best effort and must
    # not make every hint_trigger() call fail with ImportError.
    try:
        import torch
    except ImportError:
        torch = None

    detected_backend = ""

    # Priority 1: Triton driver
    try:
        from triton.runtime import driver
        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
            device = driver.active.get_active_torch_device()
            if torch is not None and isinstance(device, torch.device):
                detected_backend = device.type
            # String devices are accepted as-is (noted as "unimplemented
            # support" by the original author).
            elif isinstance(device, str):
                detected_backend = device
    except ImportError:
        pass

    # TODO: some backends may not support priority 1, so priority 2 must be kept.
    # Priority 2: torch global state
    if not detected_backend and torch is not None:
        for candidate in ("aipu", "npu", "cuda"):
            module = getattr(torch, candidate, None)
            if module and hasattr(module, "is_available") and module.is_available():
                detected_backend = candidate
                break

    # Normalization and validation. An empty name is never in
    # SUPPORTED_BACKENDS, so one membership test covers both checks.
    canonical_backend = normalize_backend_name(detected_backend)
    if canonical_backend not in SUPPORTED_BACKENDS:
        return ""

    return canonical_backend
124+
125+
126+
# Created lazily by the first hint_trigger() call.
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
    """Forward ``hook_name`` and its arguments to the global handler.

    The module-level ``HintManager`` is built on first use from the backend
    reported by ``hint_get_flagtree_backend()`` and reused afterwards.
    Returns whatever the handler's ``trigger`` returns.
    """
    global _global_hint_manager

    manager = _global_hint_manager
    if manager is None:
        manager = HintManager(hint_get_flagtree_backend())
        _global_hint_manager = manager
    return manager.handler.trigger(hook_name, *args, **kwargs)

python/triton/runtime/jit.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from ..runtime.driver import driver
1313
from types import ModuleType
1414
from .._utils import find_paths_if, get_iterable_path
15-
import tokenize
16-
from io import StringIO
1715

1816
TRITON_MODULE = __name__[:-len(".runtime.jit")]
1917

@@ -705,26 +703,19 @@ def preload(self, specialization_data):
705703
# the user might want to monkey-patch self.src dynamically.
706704
# Our unit tests do this, for example.
707705
def parse(self):
706+
from ..compiler.hint_manager import hint_trigger
708707
# Maps line numbers to comment hints
709-
line_flagtree_hints = {}
710-
code_str = self.src
711-
g = tokenize.generate_tokens(StringIO(code_str).readline)
712-
for tok_type, tok_text, start, end, _ in g:
713-
if tok_type == tokenize.COMMENT:
714-
comment = tok_text.replace(" ", "").strip()
715-
if comment.startswith('#@hint:'):
716-
flagtree_hints = comment[len('#@hint:'):].strip()
717-
# Record the line number of the comment
718-
line_num = start[0]
719-
line_flagtree_hints[line_num] = flagtree_hints
708+
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
709+
if line_flagtree_hints is None:
710+
line_flagtree_hints = {}
720711

721712
tree = ast.parse(self.src)
722713
assert isinstance(tree, ast.Module)
723714
assert len(tree.body) == 1
724715
assert isinstance(tree.body[0], ast.FunctionDef)
725716

726717
# Attach the line number to comment mapping to the function definition node
727-
tree.body[0].line_flagtree_hints = line_flagtree_hints
718+
hint_trigger("attach_line_number_to_comment_mapping", tree, line_flagtree_hints)
728719
return tree
729720

730721
def __call__(self, *args, **kwargs):
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# NOTE: this file should be stored under third_party/aipu/backend/
2+
from triton.compiler.hint_manager import BaseHintHandler
3+
import triton.language as language
4+
import ast
5+
from triton.compiler.code_generator import _is_triton_value
6+
7+
8+
class AipuHintHandler(BaseHintHandler):
    """Hint hooks for the aipu backend."""

    # Author's notes: aipu differs from ascend in two aspects.
    # 1. It has no backend_spec and modifies the Triton sources directly.
    # 2. It modifies the builder, semantic, core and so on; those source
    #    changes cannot be routed through the hint manager.
    # For this reason only the changes made in codegen and jit are moved
    # into the hint manager.

    @staticmethod
    def get_node_hints(code_generator, node):
        """Return the hint recorded for the line of ``node``, or ``None``.

        The mapping is read from the ``line_flagtree_hints`` attribute of
        the first body node of ``code_generator.jit_fn.parse()``.
        """
        target_line = node.lineno
        func_node = code_generator.jit_fn.parse().body[0]
        hints_by_line = getattr(func_node, 'line_flagtree_hints', {})
        return hints_by_line.get(target_line)

    @staticmethod
    def inject_kwargs_with_hints(fn, flagtree_hints, line_num, kws):
        """Store ``flagtree_hints`` in ``kws`` for a call to ``load``.

        Does nothing unless ``fn.__name__`` is ``"load"`` and a hint is
        present. ``kws`` is modified in place.
        """
        if fn.__name__ != "load" or flagtree_hints is None:
            return
        print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}")
        current = kws.setdefault('flagtree_hints', "")
        # Overwrite unless the existing value already contains the hint.
        if flagtree_hints not in current:
            kws['flagtree_hints'] = flagtree_hints

    @staticmethod
    def maps_line_numbers_to_comment_hints(jit_fn):
        """Map line numbers to the text of ``#@hint:`` comments in ``jit_fn.src``.

        Spaces are removed from each comment before matching, so both
        ``# @hint: x`` and ``#@hint:x`` are recognized.
        """
        import tokenize
        from io import StringIO

        prefix = '#@hint:'
        hints_by_line = {}
        reader = StringIO(jit_fn.src).readline
        for token in tokenize.generate_tokens(reader):
            if token.type != tokenize.COMMENT:
                continue
            compact = token.string.replace(" ", "").strip()
            if compact.startswith(prefix):
                # Keyed by the line on which the comment starts.
                hints_by_line[token.start[0]] = compact[len(prefix):].strip()

        return hints_by_line

    @staticmethod
    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
        """Attach the mapping to the first body node of ``tree``, if any."""
        if not tree.body:
            return
        tree.body[0].line_flagtree_hints = line_flagtree_hints

0 commit comments

Comments
 (0)