@@ -202,7 +202,8 @@ def _inject_transform_library(user_script):
202202
203203 Two mechanisms:
204204 1. transform.include calls are expanded inline (parameter substitution,
205- SSA renaming) to avoid enableExpensiveChecks segfaults in mlir-air.
205+ SSA renaming) to avoid segfaults in mlir-air's transform interpreter
206+ when resolving transform.include across region boundaries.
206207 2. foreach_match @name symbol references are resolved by injecting the
207208 referenced named_sequence definitions into the module (these cannot
208209 be inlined because foreach_match resolves symbols at runtime).
@@ -218,32 +219,35 @@ def _inject_transform_library(user_script):
218219 if not has_includes and not has_foreach_match :
219220 return user_script
220221
221- lib_path = os .path .join (os .path .dirname (__file__ ), "transform_library.mlir" )
222- if not os .path .isfile (lib_path ):
222+ # Load library content from transform_library/ directory
223+ lib_dir = os .path .join (os .path .dirname (__file__ ), "transform_library" )
224+ if not os .path .isdir (lib_dir ):
223225 return user_script
224-
225- with open (lib_path , "r" ) as lib_f :
226- lib_content = lib_f .read ()
226+ parts = []
227+ for fname in sorted (os .listdir (lib_dir )):
228+ if fname .endswith (".mlir" ):
229+ with open (os .path .join (lib_dir , fname ), "r" ) as f :
230+ parts .append (f .read ())
231+ lib_content = "\n " .join (parts )
227232
228233 import re
229234
230- # Parse library: full sequence text (for injection) and decomposed parts (for inlining)
231- # Match sequences with {transform.readonly} param (standard) or {transform.consumed} param
235+ # Parse all named sequences: full text (for injection) and decomposed (for inlining)
232236 full_seq_pattern = re .compile (
233- r"((?://[^\n]*\n)*" # optional leading comments
234- r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)" # signature
235- r"(?:\s*->\s*!transform\.any_op)?" # optional return
236- r"\s*\{.*?\n\})" , # body
237+ r"((?://[^\n]*\n)*"
238+ r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)"
239+ r"(?:\s*->\s*!transform\.any_op)?"
240+ r"\s*\{.*?\n\})" ,
237241 re .DOTALL ,
238242 )
239- full_sequences = {} # name -> full text (for injection)
243+ full_sequences = {}
240244 for m in full_seq_pattern .finditer (lib_content ):
241245 full_sequences [m .group (2 )] = m .group (1 )
242246
243- # Parse inlinable sequences (readonly param only , for transform.include expansion )
247+ # Parse inlinable sequences (readonly or consumed param , for transform.include)
244248 inline_seq_pattern = re .compile (
245249 r"transform\.named_sequence\s+@(\w+)\s*\(\s*"
246- r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.readonly\}\s*\)"
250+ r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.(?: readonly|consumed) \}\s*\)"
247251 r"(\s*->\s*!transform\.any_op)?"
248252 r"\s*\{(.*?)\n\}" ,
249253 re .DOTALL ,
@@ -256,10 +260,10 @@ def _inject_transform_library(user_script):
256260 body = match .group (4 )
257261 sequences [name ] = (param , body , has_result )
258262
259- if not sequences :
263+ if not sequences and not full_sequences :
260264 return user_script
261265
262- # Match: [%result = ] transform.include @name failures(...) (%actual) : (...) -> (...)
266+ # Inline transform.include calls to avoid mlir-air segfaults
263267 include_pattern = re .compile (
264268 r"(?:(%\w+)\s*=\s*)?"
265269 r"transform\.include\s+@(\w+)\s+"
@@ -269,11 +273,10 @@ def _inject_transform_library(user_script):
269273 r"(?:!transform\.any_op|\(\s*\))"
270274 )
271275
272- max_depth = 20
273276 _counter = [0 ]
274277
275278 def _expand (text , depth = 0 ):
276- if depth > max_depth or "transform.include" not in text :
279+ if depth > 20 or "transform.include" not in text :
277280 return text
278281
279282 def _replace_include (m ):
@@ -297,13 +300,9 @@ def _replace_include(m):
297300 if result_var and yielded_var :
298301 expanded = expanded .replace (yielded_var , result_var )
299302
300- # Unique suffix to avoid SSA name collisions across expansions
301303 suffix = f"_lib{ _counter [0 ]} "
302304 _counter [0 ] += 1
303305 local_vars = set (re .findall (r"%(\w+)" , expanded ))
304- # Don't rename externally-scoped identifiers: the actual argument
305- # and the caller's result variable (if any) are defined outside
306- # the inlined body and must keep their original names.
307306 actual_name = actual_arg .lstrip ("%" )
308307 result_name = result_var .lstrip ("%" ) if result_var else ""
309308 skip = {actual_name , result_name , "__" , "" }
@@ -324,18 +323,22 @@ def _replace_include(m):
324323
325324 # Inject named sequences referenced by foreach_match (symbol references
326325 # that cannot be inlined — they must exist as definitions in the module).
327- if has_foreach_match :
328- # Extract foreach_match blocks (may span multiple lines)
329- # Match: transform.foreach_match ... @matcher -> @action ... : ...
330- foreach_refs = set ()
331- fm_pattern = re .compile (
332- r"transform\.foreach_match.*?(?=transform\.|$)" , re .DOTALL
333- )
334- for fm in fm_pattern .finditer (result ):
335- foreach_refs .update (re .findall (r"@(\w+)" , fm .group (0 )))
336- # Remove __transform_main and any non-library refs
337- foreach_refs .discard ("__transform_main" )
338- needed = {n for n in foreach_refs if n in full_sequences }
326+ if has_foreach_match or "foreach_match" in result :
327+ all_refs = set (re .findall (r"@(\w+)" , result ))
328+ all_refs .discard ("__transform_main" )
329+ # Transitively resolve dependencies
330+ needed = set ()
331+ worklist = [n for n in all_refs if n in full_sequences ]
332+ while worklist :
333+ name = worklist .pop ()
334+ if name in needed :
335+ continue
336+ needed .add (name )
337+ for dep in re .findall (r"@(\w+)" , full_sequences [name ]):
338+ if dep in full_sequences and dep not in needed :
339+ worklist .append (dep )
340+ # Inject definitions for all unresolved @name references
341+ # (matchers/actions referenced by foreach_match, plus their deps)
339342 if needed :
340343 module_marker = "module attributes {transform.with_named_sequence} {"
341344 idx = result .find (module_marker )
0 commit comments