Skip to content

Commit dedf841

Browse files
erwei-xilinxclaude
andcommitted
Split transform library into categorized files and add composable sequences
Replace the single transform_library.mlir with a transform_library/ directory containing 6 categorized files (canonicalization, elementwise, bufferization, air_mapping, vectorization, type_casting). Add 5 new composable sequences that eliminate remaining inline code in example scripts: - @vectorize_generics_at_16 / @vectorize_generics_at_32 - @cast_bf16_only_ops / @cast_cmpf_and_select_ops - @post_vectorize_reduce_cleanup All 16 elementwise example scripts now use only library calls with no inline code. Softmax scripts use @post_bufferize_cleanup and @post_vectorize_reduce_cleanup for their common tail phases. The driver.py inliner is simplified: library loading reads from the directory, and the injection logic now transitively resolves dependencies so library sequences that reference other library sequences work correctly. Tested: 14/15 aie2p examples pass (matvec fails due to pre-existing numerical bug unrelated to this change). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c9ffb4f commit dedf841

25 files changed

Lines changed: 454 additions & 544 deletions

amd_triton_npu/backend/driver.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def _inject_transform_library(user_script):
202202
203203
Two mechanisms:
204204
1. transform.include calls are expanded inline (parameter substitution,
205-
SSA renaming) to avoid enableExpensiveChecks segfaults in mlir-air.
205+
SSA renaming) to avoid segfaults in mlir-air's transform interpreter
206+
when resolving transform.include across region boundaries.
206207
2. foreach_match @name symbol references are resolved by injecting the
207208
referenced named_sequence definitions into the module (these cannot
208209
be inlined because foreach_match resolves symbols at runtime).
@@ -218,32 +219,35 @@ def _inject_transform_library(user_script):
218219
if not has_includes and not has_foreach_match:
219220
return user_script
220221

221-
lib_path = os.path.join(os.path.dirname(__file__), "transform_library.mlir")
222-
if not os.path.isfile(lib_path):
222+
# Load library content from transform_library/ directory
223+
lib_dir = os.path.join(os.path.dirname(__file__), "transform_library")
224+
if not os.path.isdir(lib_dir):
223225
return user_script
224-
225-
with open(lib_path, "r") as lib_f:
226-
lib_content = lib_f.read()
226+
parts = []
227+
for fname in sorted(os.listdir(lib_dir)):
228+
if fname.endswith(".mlir"):
229+
with open(os.path.join(lib_dir, fname), "r") as f:
230+
parts.append(f.read())
231+
lib_content = "\n".join(parts)
227232

228233
import re
229234

230-
# Parse library: full sequence text (for injection) and decomposed parts (for inlining)
231-
# Match sequences with {transform.readonly} param (standard) or {transform.consumed} param
235+
# Parse all named sequences: full text (for injection) and decomposed (for inlining)
232236
full_seq_pattern = re.compile(
233-
r"((?://[^\n]*\n)*" # optional leading comments
234-
r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)" # signature
235-
r"(?:\s*->\s*!transform\.any_op)?" # optional return
236-
r"\s*\{.*?\n\})", # body
237+
r"((?://[^\n]*\n)*"
238+
r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)"
239+
r"(?:\s*->\s*!transform\.any_op)?"
240+
r"\s*\{.*?\n\})",
237241
re.DOTALL,
238242
)
239-
full_sequences = {} # name -> full text (for injection)
243+
full_sequences = {}
240244
for m in full_seq_pattern.finditer(lib_content):
241245
full_sequences[m.group(2)] = m.group(1)
242246

243-
# Parse inlinable sequences (readonly param only, for transform.include expansion)
247+
# Parse inlinable sequences (readonly or consumed param, for transform.include)
244248
inline_seq_pattern = re.compile(
245249
r"transform\.named_sequence\s+@(\w+)\s*\(\s*"
246-
r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.readonly\}\s*\)"
250+
r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.(?:readonly|consumed)\}\s*\)"
247251
r"(\s*->\s*!transform\.any_op)?"
248252
r"\s*\{(.*?)\n\}",
249253
re.DOTALL,
@@ -256,10 +260,10 @@ def _inject_transform_library(user_script):
256260
body = match.group(4)
257261
sequences[name] = (param, body, has_result)
258262

259-
if not sequences:
263+
if not sequences and not full_sequences:
260264
return user_script
261265

262-
# Match: [%result = ] transform.include @name failures(...) (%actual) : (...) -> (...)
266+
# Inline transform.include calls to avoid mlir-air segfaults
263267
include_pattern = re.compile(
264268
r"(?:(%\w+)\s*=\s*)?"
265269
r"transform\.include\s+@(\w+)\s+"
@@ -269,11 +273,10 @@ def _inject_transform_library(user_script):
269273
r"(?:!transform\.any_op|\(\s*\))"
270274
)
271275

272-
max_depth = 20
273276
_counter = [0]
274277

275278
def _expand(text, depth=0):
276-
if depth > max_depth or "transform.include" not in text:
279+
if depth > 20 or "transform.include" not in text:
277280
return text
278281

279282
def _replace_include(m):
@@ -297,13 +300,9 @@ def _replace_include(m):
297300
if result_var and yielded_var:
298301
expanded = expanded.replace(yielded_var, result_var)
299302

300-
# Unique suffix to avoid SSA name collisions across expansions
301303
suffix = f"_lib{_counter[0]}"
302304
_counter[0] += 1
303305
local_vars = set(re.findall(r"%(\w+)", expanded))
304-
# Don't rename externally-scoped identifiers: the actual argument
305-
# and the caller's result variable (if any) are defined outside
306-
# the inlined body and must keep their original names.
307306
actual_name = actual_arg.lstrip("%")
308307
result_name = result_var.lstrip("%") if result_var else ""
309308
skip = {actual_name, result_name, "__", ""}
@@ -324,18 +323,22 @@ def _replace_include(m):
324323

325324
# Inject named sequences referenced by foreach_match (symbol references
326325
# that cannot be inlined — they must exist as definitions in the module).
327-
if has_foreach_match:
328-
# Extract foreach_match blocks (may span multiple lines)
329-
# Match: transform.foreach_match ... @matcher -> @action ... : ...
330-
foreach_refs = set()
331-
fm_pattern = re.compile(
332-
r"transform\.foreach_match.*?(?=transform\.|$)", re.DOTALL
333-
)
334-
for fm in fm_pattern.finditer(result):
335-
foreach_refs.update(re.findall(r"@(\w+)", fm.group(0)))
336-
# Remove __transform_main and any non-library refs
337-
foreach_refs.discard("__transform_main")
338-
needed = {n for n in foreach_refs if n in full_sequences}
326+
if has_foreach_match or "foreach_match" in result:
327+
all_refs = set(re.findall(r"@(\w+)", result))
328+
all_refs.discard("__transform_main")
329+
# Transitively resolve dependencies
330+
needed = set()
331+
worklist = [n for n in all_refs if n in full_sequences]
332+
while worklist:
333+
name = worklist.pop()
334+
if name in needed:
335+
continue
336+
needed.add(name)
337+
for dep in re.findall(r"@(\w+)", full_sequences[name]):
338+
if dep in full_sequences and dep not in needed:
339+
worklist.append(dep)
340+
# Inject definitions for all unresolved @name references
341+
# (matchers/actions referenced by foreach_match, plus their deps)
339342
if needed:
340343
module_marker = "module attributes {transform.with_named_sequence} {"
341344
idx = result.find(module_marker)

0 commit comments

Comments
 (0)