Skip to content

Commit 337d734

Browse files
erwei-xilinx and claude committed
Add post-transform memory space override, rms_norm blocked on alloc memory space
Add air-override-memref-memory-space{scope=herd memory-space=2} and {scope=func memory-space=1} as post-transform passes in driver.py to fix memory spaces for allocs created by one_shot_bufferize. rms_norm: simplified transform (no pad, no L1 alloc) that tiles the fused output with forall[32] and creates a 2x1 herd. Blocked on mlir-air #1384 (herd block arg type propagation) and #1387 (air-par-to-launch hang when types are fixed). Filed 7 mlir-air issues total (#1367-#1387). Existing elementwise examples (relu, sigmoid, silu, gelu, leaky_relu) verified working with the updated driver.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b488e07 commit 337d734

3 files changed

Lines changed: 3 additions & 84 deletions

File tree

amd_triton_npu/backend/driver.py

Lines changed: 1 addition & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -290,72 +290,7 @@ def _ttshared_to_air(mod, gridX, gridY, gridZ):
290290
context=air_context,
291291
)
292292
pm_ms.run(air_module.operation)
293-
# Step 2.6: DISABLED -- fixing herd block arg types causes
294-
# air-par-to-launch to hang (applyPatternsGreedily convergence issue).
295-
# See mlir-air #1384. Waiting for upstream fix.
296-
if False:
297-
import re as _re
298-
import subprocess as _sp
299-
_ir = str(air_module)
300-
_hm = _re.search(
301-
r"air\.herd\s+@\w+\s+tile\s*\([^)]+\)\s+in\s*\([^)]+\)\s+"
302-
r"args\(([^)]+)\)\s*:\s*([^{]+)\{",
303-
_ir,
304-
)
305-
if _hm:
306-
_types = []
307-
_d, _c = 0, ""
308-
for ch in _hm.group(2):
309-
if ch == "<": _d += 1
310-
elif ch == ">": _d -= 1
311-
if ch == "," and _d == 0:
312-
_types.append(_c.strip()); _c = ""
313-
else: _c += ch
314-
if _c.strip(): _types.append(_c.strip())
315-
316-
_hs = _hm.end()
317-
_d, _he = 1, _hs
318-
for i in range(_hs, len(_ir)):
319-
if _ir[i] == "{": _d += 1
320-
elif _ir[i] == "}":
321-
_d -= 1
322-
if _d == 0: _he = i + 1; break
323-
324-
_hb = _ir[_hs:_he]
325-
_changed = False
326-
for _t in _types:
327-
if "memref<" not in _t: continue
328-
_ms = _re.search(r",\s*(\d+\s*:\s*i\d+)\s*>$", _t)
329-
if not _ms: continue
330-
_ms_s = _ms.group(1)
331-
_base = _t[:_ms.start()] + ">"
332-
if _base == _t: continue
333-
_hb = _hb.replace(_base, _t)
334-
_el = _re.search(r"memref<\d+x(\w+)>", _base)
335-
if _el:
336-
for _sm in set(_re.findall(
337-
r"memref<\d+x" + _re.escape(_el.group(1))
338-
+ r",\s*strided<[^\]]*\],\s*offset:\s*\??>>", _hb
339-
)):
340-
if _ms_s not in _sm:
341-
_hb = _hb.replace(_sm, _sm[:-1] + ", " + _ms_s + ">")
342-
_changed = True
343-
344-
if _changed:
345-
_fixed = _ir[:_hs] + _hb + _ir[_he:]
346-
_tmpf = os.path.join(tmpdir, "step26_fixed.mlir")
347-
with open(_tmpf, "w") as f:
348-
f.write(_fixed)
349-
# Re-parse via air-opt (which has all dialects registered)
350-
_air_opt = _get_air_opt_path()
351-
_result = _sp.run(
352-
[_air_opt, _tmpf, "--allow-unregistered-dialect"],
353-
capture_output=True, text=True, timeout=30,
354-
)
355-
if _result.returncode == 0:
356-
air_module = Module.parse(
357-
_result.stdout, context=air_context
358-
)
293+
# Step 2.6: REMOVED -- type fix moved to step 3.5 (after air-par-to-launch).
359294
# MLIR-AIR compilation step 3: converting to AIR
360295
pipeline = (
361296
"builtin.module("

examples/rms_norm/transform_aie2.mlir

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,15 +45,7 @@ module attributes {transform.with_named_sequence} {
4545
} : !transform.any_op
4646
transform.apply_cse to %func_2 : !transform.any_op
4747

48-
// Phase 4: Promote generic inside forall to L1 (no pad needed)
49-
%gen_in_forall = transform.structured.match ops{["linalg.generic"]} in %forall : (!transform.any_op) -> !transform.any_op
50-
%gen_l1_buf, %gen_l1_new = transform.structured.bufferize_to_allocation %gen_in_forall
51-
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
52-
53-
%func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
54-
transform.apply_patterns to %func_3 { transform.apply_patterns.canonicalization } : !transform.any_op
55-
transform.apply_cse to %func_3 : !transform.any_op
56-
48+
// Phase 4: No L1 alloc -- let post-passes handle memory hierarchy.
5749
// Phase 5: Bufferize
5850
// Memory spaces for remaining allocs handled by post-transform overrides
5951
// in driver.py (scope=herd→2, scope=func→1 with exclusive scopes).

examples/rms_norm/transform_aie2p.mlir

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,15 +45,7 @@ module attributes {transform.with_named_sequence} {
4545
} : !transform.any_op
4646
transform.apply_cse to %func_2 : !transform.any_op
4747

48-
// Phase 4: Promote generic inside forall to L1 (no pad needed)
49-
%gen_in_forall = transform.structured.match ops{["linalg.generic"]} in %forall : (!transform.any_op) -> !transform.any_op
50-
%gen_l1_buf, %gen_l1_new = transform.structured.bufferize_to_allocation %gen_in_forall
51-
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
52-
53-
%func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
54-
transform.apply_patterns to %func_3 { transform.apply_patterns.canonicalization } : !transform.any_op
55-
transform.apply_cse to %func_3 : !transform.any_op
56-
48+
// Phase 4: No L1 alloc -- let post-passes handle memory hierarchy.
5749
// Phase 5: Bufferize
5850
// Memory spaces for remaining allocs handled by post-transform overrides
5951
// in driver.py (scope=herd→2, scope=func→1 with exclusive scopes).

0 commit comments

Comments
 (0)