|
5 | 5 | from triton.compiler.code_generator import _is_triton_value |
6 | 6 |
|
7 | 7 | class AipuHintHandler(BaseHintHandler): |
8 | | - |
| 8 | + # because aipu is diff from ascend in 2 aspects |
| 9 | + # 1. not backend_spec, modify triton src violently |
| 10 | + # 2. modify builder, semantic, core, and so on. pollute the src, which cant be involved in hint manager |
| 11 | + # for this, we just move changes in codegen & jit into hintmanager. |
9 | 12 |
|
10 | 13 | @staticmethod |
11 | | - def func1(code_generator, node, names, values): |
| 14 | + def get_node_hints(code_generator, node): |
12 | 15 | line_num = node.lineno |
13 | | - function_def = self.jit_fn.parse() |
| 16 | + function_def = code_generator.jit_fn.parse() |
14 | 17 | line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) |
15 | 18 | flagtree_hints = line_flagtree_hints.get(line_num) |
| 19 | + return flagtree_hints |
16 | 20 |
|
17 | | - def func2(): |
| 21 | + @staticmethod |
| 22 | + def inject_kwargs_with_hints(fn, flagtree_hints, line_num, kws): |
18 | 23 | if fn.__name__ == "load" and flagtree_hints is not None: |
19 | 24 | print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}") |
20 | 25 | if 'flagtree_hints' not in kws: |
21 | 26 | kws['flagtree_hints'] = "" |
22 | 27 | if flagtree_hints not in kws['flagtree_hints']: |
23 | 28 | kws['flagtree_hints'] = flagtree_hints |
24 | 29 |
|
25 | | - @staticmethod |
26 | | - def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): |
27 | | - if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')): |
28 | | - return |
29 | | - |
30 | | - if not flagtree_hints: |
31 | | - return |
32 | | - |
33 | | - # 3. AIPU 特有的 Hint 处理逻辑 |
34 | | - # [请对照 PR 修改此处字符串] 假设 AIPU 有一个 hint 叫 'global_memory_cache' 或者类似的 |
35 | | - target_hint_key = 'aipu_specific_hint_key' # <--- 请替换为 PR 里的真实字符串 |
36 | | - |
37 | | - if target_hint_key in flagtree_hints: |
38 | | - # 检查是否作用于 tl.load / tl.store |
39 | | - if (isinstance(node.value, ast.Call) and |
40 | | - isinstance(node.value.func, ast.Attribute) and |
41 | | - isinstance(node.value.func.value, ast.Name) and |
42 | | - node.value.func.value.id == 'tl'): |
43 | | - |
44 | | - # 例如:如果是 load 操作 |
45 | | - if node.value.func.attr == 'load': |
46 | | - for name, value in zip(names, values): |
47 | | - if _is_triton_value(value): |
48 | | - # 创建 AIPU 特有的 Annotation |
49 | | - # print(f"[FLAGTREE][AIPU] Applying {target_hint_key} to {name}") |
50 | | - hint_val = code_generator.builder.get_unit_attr() |
51 | | - # 'aipu.hint_name' 需要与后端 LLVM IR 处理逻辑对应 |
52 | | - code_generator.builder.create_annotation(value.handle, target_hint_key, hint_val) |
53 | | - |
54 | | - @staticmethod |
55 | | - def check_override_bind_sub_block(code_generator, node, bind_sub_block): |
56 | | - """ |
57 | | - 对应 CodeGenerator.visit_For 中决定是否开启 bind_sub_block 的逻辑 |
58 | | - """ |
59 | | - if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')): |
60 | | - return bind_sub_block |
61 | | - |
62 | | - line_num = node.lineno |
63 | | - function_def = code_generator.jit_fn.parse() |
64 | | - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) |
65 | | - flagtree_hints = line_flagtree_hints.get(line_num) |
66 | | - |
67 | | - # 检查 AIPU 是否也支持通过 hint 强制开启/关闭 sub_block 绑定 |
68 | | - if flagtree_hints and 'bind_sub_block' in flagtree_hints: |
69 | | - # 如果 AIPU 后端支持此特性,则返回 True |
70 | | - return True |
71 | | - |
72 | | - return bind_sub_block |
73 | | - |
74 | 30 | @staticmethod |
75 | 31 | def maps_line_numbers_to_comment_hints(jit_fn): |
| 32 | + import tokenize |
| 33 | + from io import StringIO |
76 | 34 | # Maps line numbers to comment hints |
77 | 35 | line_flagtree_hints = {} |
78 | | - code_str = self.src |
| 36 | + code_str = jit_fn.src |
79 | 37 | g = tokenize.generate_tokens(StringIO(code_str).readline) |
80 | 38 | for tok_type, tok_text, start, end, _ in g: |
81 | 39 | if tok_type == tokenize.COMMENT: |
|
0 commit comments