@@ -19,12 +19,14 @@ class _LoweringContext:
1919 params: Input parameter names (all live in HBM).
2020 buffers: Variable name to buffer location string.
2121 aliases: Maps accumulation output names to canonical PSUM variable.
22+ alias_offsets: Maps alias names to their start offsets per axis.
2223 staging_counter: Monotonic counter for staging variable names.
2324 """
2425
2526 params : tuple [str , ...]
2627 buffers : dict [str , str ] = field (default_factory = dict )
2728 aliases : dict [str , str ] = field (default_factory = dict )
29+ alias_offsets : dict [str , tuple [int , ...]] = field (default_factory = dict )
2830 staging_counter : int = 0
2931
3032 def resolve (self , name : str ) -> str :
@@ -40,6 +42,26 @@ def resolve(self, name: str) -> str:
4042 name = self .aliases [name ]
4143 return name
4244
def _resolve_offsets(self, name: str) -> tuple[int, ...]:
    """Accumulate start offsets along the alias chain.

    Walks ``self.aliases`` starting at ``name`` and sums, axis by axis,
    the offsets recorded in ``self.alias_offsets`` for every alias
    visited. An alias without an entry contributes nothing. Entries of
    different lengths are tolerated: an axis missing from an entry is
    treated as offset 0, so the result is as long as the longest entry
    seen on the chain.

    Args:
        name: Variable name, possibly an accumulation alias.

    Returns:
        Tuple of accumulated start offsets per axis; empty when ``name``
        is not an alias or no alias on the chain has recorded offsets.
    """
    totals: list[int] = []
    while name in self.aliases:
        for axis, delta in enumerate(self.alias_offsets.get(name, ())):
            if axis < len(totals):
                totals[axis] += delta
            else:
                # First time this axis appears on the chain: earlier
                # aliases implicitly contributed 0 to it.
                totals.append(delta)
        name = self.aliases[name]
    return tuple(totals)
64+
4365 def buffer_of (self , name : str ) -> str :
4466 """Look up the buffer location of a variable, resolving aliases.
4567
@@ -57,8 +79,9 @@ def buffer_of(self, name: str) -> str:
def subscript(self, ref: TensorRef) -> str:
    """Render a TensorRef as ``name[s:e, s:e]`` on its resolved name.

    The ref name is resolved through the alias chain and the start
    offsets gathered along that chain are composed with the ref slices,
    so the subscript addresses the matching region of the resolved
    buffer.

    Args:
        ref: Tensor reference.

    Returns:
        Subscripted string, or the plain resolved name when the ref
        carries no slices.
    """
    canonical = self.resolve(ref.name)
    # Looked up before the slice check so both lookups always run.
    shift = self._resolve_offsets(ref.name)
    if not ref.slices:
        return canonical
    return f"{canonical}[{_compose_slices(ref.slices, shift)}]"
7599
76100
def _compose_slices(slices: tuple[tuple[int, int], ...], offsets: tuple[int, ...]) -> str:
    """Shift each ref slice by its alias offset and render the subscript.

    Axes beyond the end of ``offsets`` are shifted by 0; offsets beyond
    the end of ``slices`` are ignored.

    Args:
        slices: Per-axis (start, stop) bounds from the TensorRef.
        offsets: Per-axis start offsets from the alias chain.

    Returns:
        Comma-separated ``s:e`` subscript string.
    """
    # Zero-pad so every slice axis has a shift; zip() drops any surplus.
    padded = tuple(offsets) + (0,) * max(0, len(slices) - len(offsets))
    return ", ".join(
        f"{lo + shift}:{hi + shift}" for (lo, hi), shift in zip(slices, padded)
    )
116+
117+
def get_kwarg(stmt: GymStatement, key: str) -> object:
    """Look up a keyword argument value on a statement.

    The statement's kwargs are first checked for duplicate keys, since
    duplicates indicate an IR construction bug upstream.

    Args:
        stmt: GymStatement to search.
        key: Keyword argument name.

    Returns:
        The value of the first kwarg named ``key``, or None when the
        statement has no such kwarg.
    """
    _assert_no_duplicate_kwargs(stmt)
    matches = (value for name, value in stmt.kwargs if name == key)
    return next(matches, None)
138+
139+
140+ def _assert_no_duplicate_kwargs (stmt : GymStatement ) -> None :
141+ """Assert that a statement has no duplicate keyword argument names.
142+
143+ Args:
144+ stmt: GymStatement to check.
145+
146+ Raises:
147+ AssertionError: If duplicate kwarg keys are found.
148+ """
149+ keys = [k for k , _ in stmt .kwargs ]
150+ assert len (keys ) == len (set (keys )), f"Duplicate kwargs in { stmt .op } stmt '{ stmt .output .name } ': { keys } "
0 commit comments