Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion amd_triton_npu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,165 @@ def detect_npu_version():
raise RuntimeError("Unsupported or unrecognized NPU device found.")


def _inject_transform_library(user_script):
"""
Process library references in a user transform script.

Two mechanisms:
1. transform.include calls are expanded inline (parameter substitution,
SSA renaming) to avoid enableExpensiveChecks segfaults in mlir-air.
2. foreach_match @name symbol references are resolved by injecting the
referenced named_sequence definitions into the module (these cannot
be inlined because foreach_match resolves symbols at runtime).

Args:
user_script: The user's transform script as a string

Returns:
str: The processed script
"""
has_includes = "transform.include" in user_script
has_foreach_match = "foreach_match" in user_script
if not has_includes and not has_foreach_match:
return user_script

lib_path = os.path.join(os.path.dirname(__file__), "transform_library.mlir")
if not os.path.isfile(lib_path):
return user_script

with open(lib_path, "r") as lib_f:
lib_content = lib_f.read()

import re

# Parse library: full sequence text (for injection) and decomposed parts (for inlining)
# Match sequences with {transform.readonly} param (standard) or {transform.consumed} param
full_seq_pattern = re.compile(
r"((?://[^\n]*\n)*" # optional leading comments
r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)" # signature
r"(?:\s*->\s*!transform\.any_op)?" # optional return
r"\s*\{.*?\n\})", # body
re.DOTALL,
)
full_sequences = {} # name -> full text (for injection)
for m in full_seq_pattern.finditer(lib_content):
full_sequences[m.group(2)] = m.group(1)

# Parse inlinable sequences (readonly param only, for transform.include expansion)
inline_seq_pattern = re.compile(
r"transform\.named_sequence\s+@(\w+)\s*\(\s*"
r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.readonly\}\s*\)"
r"(\s*->\s*!transform\.any_op)?"
r"\s*\{(.*?)\n\}",
re.DOTALL,
)
sequences = {}
for match in inline_seq_pattern.finditer(lib_content):
name = match.group(1)
param = match.group(2)
has_result = match.group(3) is not None
body = match.group(4)
sequences[name] = (param, body, has_result)

if not sequences:
return user_script

# Match: [%result = ] transform.include @name failures(...) (%actual) : (...) -> (...)
include_pattern = re.compile(
r"(?:(%\w+)\s*=\s*)?"
r"transform\.include\s+@(\w+)\s+"
r"failures\(\w+\)\s*"
r"\((%\w+)\)\s*"
r":\s*\(!transform\.any_op\)\s*->\s*"
r"(?:!transform\.any_op|\(\s*\))"
)

max_depth = 20
_counter = [0]

def _expand(text, depth=0):
if depth > max_depth or "transform.include" not in text:
return text

def _replace_include(m):
result_var = m.group(1)
seq_name = m.group(2)
actual_arg = m.group(3)

if seq_name not in sequences:
return m.group(0)

param, body, has_result = sequences[seq_name]
expanded = body.replace(f"%{param}", actual_arg)

yield_match = re.search(
r"transform\.yield(?:\s+(%\w+)\s*:\s*!transform\.any_op)?",
expanded,
)
if yield_match:
yielded_var = yield_match.group(1)
expanded = expanded[: yield_match.start()].rstrip()
if result_var and yielded_var:
expanded = expanded.replace(yielded_var, result_var)

# Unique suffix to avoid SSA name collisions across expansions
suffix = f"_lib{_counter[0]}"
_counter[0] += 1
local_vars = set(re.findall(r"%(\w+)", expanded))
# Don't rename externally-scoped identifiers: the actual argument
# and the caller's result variable (if any) are defined outside
# the inlined body and must keep their original names.
actual_name = actual_arg.lstrip("%")
result_name = result_var.lstrip("%") if result_var else ""
skip = {actual_name, result_name, "__", ""}
for var in local_vars:
if var not in skip and not var.startswith("_lib"):
expanded = re.sub(
rf"(?<!\w)%{re.escape(var)}(?!\w)",
f"%{var}{suffix}",
expanded,
)
Comment thread
erwei-xilinx marked this conversation as resolved.

return expanded

text = include_pattern.sub(_replace_include, text)
return _expand(text, depth + 1)

result = _expand(user_script) if has_includes else user_script

# Inject named sequences referenced by foreach_match (symbol references
# that cannot be inlined — they must exist as definitions in the module).
if has_foreach_match:
# Extract foreach_match blocks (may span multiple lines)
# Match: transform.foreach_match ... @matcher -> @action ... : ...
foreach_refs = set()
fm_pattern = re.compile(
r"transform\.foreach_match.*?(?=transform\.|$)", re.DOTALL
)
for fm in fm_pattern.finditer(result):
foreach_refs.update(re.findall(r"@(\w+)", fm.group(0)))
# Remove __transform_main and any non-library refs
foreach_refs.discard("__transform_main")
needed = {n for n in foreach_refs if n in full_sequences}
if needed:
module_marker = "module attributes {transform.with_named_sequence} {"
idx = result.find(module_marker)
if idx != -1:
insert_pos = idx + len(module_marker)
injection = "\n\n".join(
full_sequences[n] for n in full_sequences if n in needed
)
result = (
result[:insert_pos]
+ "\n\n"
+ injection
+ "\n\n"
+ result[insert_pos:]
)

return result


def _get_transform_ir_string():
"""
Get the transform IR string for tiling operations.
Expand All @@ -204,6 +363,9 @@ def _get_transform_ir_string():
read the transform IR from that file. Otherwise, use the default
hardcoded transform IR string.

If the script uses `transform.include`, the shared transform library
(transform_library.mlir) is automatically injected.

Returns:
str: The transform IR string to use for tiling
"""
Expand All @@ -218,7 +380,8 @@ def _get_transform_ir_string():
)
with open(custom_script_path, "r") as f:
print(f"Using custom tiling script from: {custom_script_path}")
return f.read()
user_script = f.read()
return _inject_transform_library(user_script)

# Default hardcoded transform IR string
matmul_tiling_size_l1_m = 32
Expand Down
Loading
Loading