Skip to content

Commit 2059fe0

Browse files
committed
update func attrs
1 parent 2343cd9 commit 2059fe0

3 files changed

Lines changed: 17 additions & 57 deletions

File tree

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,7 @@ def visit_Call(self, node):
12431243
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
12441244

12451245
# 4. Get current line number and hints
1246-
hint_trigger("func1", self, node, names, values)
1246+
flagtree_hints = hint_trigger("get_node_hints", self, node)
12471247

12481248
# 5. Handle JIT function calls
12491249
if isinstance(fn, JITFunction):
@@ -1258,7 +1258,7 @@ def visit_Call(self, node):
12581258
extra_kwargs['_generator'] = self
12591259
try:
12601260
# Special handling for tl.load with hints
1261-
hint_trigger("func2", self, node, names, values)
1261+
hint_trigger("inject_kwargs_with_hints", fn, flagtree_hints, node.lineno, kws)
12621262

12631263
ret = fn(*args, **extra_kwargs, **kws)
12641264
# builtin functions return plain tuples for readability

python/triton/runtime/jit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,9 @@ def preload(self, specialization_data):
706706
# Our unit tests do this, for example.
707707
def parse(self):
708708
# Maps line numbers to comment hints
709-
hint_trigger("maps_line_numbers_to_comment_hints", self)
709+
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
710+
if line_flagtree_hints is None:
711+
line_flagtree_hints = {}
710712

711713
tree = ast.parse(self.src)
712714
assert isinstance(tree, ast.Module)

third_party/aipu/backend/aipu_hint_handler.py

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,77 +5,35 @@
55
from triton.compiler.code_generator import _is_triton_value
66

77
class AipuHintHandler(BaseHintHandler):
8-
8+
# because aipu is diff from ascend in 2 aspects
9+
# 1. not backend_spec, modify triton src violently
10+
# 2. modify builder, semantic, core, and so on. pollute the src, which cant be involved in hint manager
11+
# for this, we just move changes in codegen & jit into hintmanager.
912

1013
@staticmethod
11-
def func1(code_generator, node, names, values):
14+
def get_node_hints(code_generator, node):
1215
line_num = node.lineno
13-
function_def = self.jit_fn.parse()
16+
function_def = code_generator.jit_fn.parse()
1417
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
1518
flagtree_hints = line_flagtree_hints.get(line_num)
19+
return flagtree_hints
1620

17-
def func2():
21+
@staticmethod
22+
def inject_kwargs_with_hints(fn, flagtree_hints, line_num, kws):
1823
if fn.__name__ == "load" and flagtree_hints is not None:
1924
print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}")
2025
if 'flagtree_hints' not in kws:
2126
kws['flagtree_hints'] = ""
2227
if flagtree_hints not in kws['flagtree_hints']:
2328
kws['flagtree_hints'] = flagtree_hints
2429

25-
@staticmethod
26-
def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
27-
if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')):
28-
return
29-
30-
if not flagtree_hints:
31-
return
32-
33-
# 3. AIPU 特有的 Hint 处理逻辑
34-
# [请对照 PR 修改此处字符串] 假设 AIPU 有一个 hint 叫 'global_memory_cache' 或者类似的
35-
target_hint_key = 'aipu_specific_hint_key' # <--- 请替换为 PR 里的真实字符串
36-
37-
if target_hint_key in flagtree_hints:
38-
# 检查是否作用于 tl.load / tl.store
39-
if (isinstance(node.value, ast.Call) and
40-
isinstance(node.value.func, ast.Attribute) and
41-
isinstance(node.value.func.value, ast.Name) and
42-
node.value.func.value.id == 'tl'):
43-
44-
# 例如:如果是 load 操作
45-
if node.value.func.attr == 'load':
46-
for name, value in zip(names, values):
47-
if _is_triton_value(value):
48-
# 创建 AIPU 特有的 Annotation
49-
# print(f"[FLAGTREE][AIPU] Applying {target_hint_key} to {name}")
50-
hint_val = code_generator.builder.get_unit_attr()
51-
# 'aipu.hint_name' 需要与后端 LLVM IR 处理逻辑对应
52-
code_generator.builder.create_annotation(value.handle, target_hint_key, hint_val)
53-
54-
@staticmethod
55-
def check_override_bind_sub_block(code_generator, node, bind_sub_block):
56-
"""
57-
对应 CodeGenerator.visit_For 中决定是否开启 bind_sub_block 的逻辑
58-
"""
59-
if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')):
60-
return bind_sub_block
61-
62-
line_num = node.lineno
63-
function_def = code_generator.jit_fn.parse()
64-
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
65-
flagtree_hints = line_flagtree_hints.get(line_num)
66-
67-
# 检查 AIPU 是否也支持通过 hint 强制开启/关闭 sub_block 绑定
68-
if flagtree_hints and 'bind_sub_block' in flagtree_hints:
69-
# 如果 AIPU 后端支持此特性,则返回 True
70-
return True
71-
72-
return bind_sub_block
73-
7430
@staticmethod
7531
def maps_line_numbers_to_comment_hints(jit_fn):
32+
import tokenize
33+
from io import StringIO
7634
# Maps line numbers to comment hints
7735
line_flagtree_hints = {}
78-
code_str = self.src
36+
code_str = jit_fn.src
7937
g = tokenize.generate_tokens(StringIO(code_str).readline)
8038
for tok_type, tok_text, start, end, _ in g:
8139
if tok_type == tokenize.COMMENT:

0 commit comments

Comments
 (0)