@@ -53,6 +53,13 @@ def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP) -> UOp:
5353 # if a range has a 1 src, it's the same as UOp.const(dtypes.weakint, 0)
5454 return UOp .range (s , next (self .range_idx ), axistype ) if resolve (s != 1 ) else UOp .const (dtypes .weakint , 0 )
5555
def tag_reduce_scope(ctx:IndexingContext, x:UOp, ret:UOp|None) -> UOp|None:
  """Tag a rewrite result with the REDUCE ranges recorded for the UOp it replaces.

  Args:
    ctx: indexing context; only `ctx.range_map` is read and written here.
    x: the original UOp, used as the lookup key into `ctx.range_map`.
    ret: the rewritten UOp, or None when the caller made no rewrite.

  Returns:
    `ret` unchanged when it is None, when `x` has no `range_map` entry, or when
    none of x's ranges is a REDUCE-axis RANGE. Otherwise `ret.replace(tag=...)`,
    where the tag is the tuple of `arg[0]` values of those ranges.
  """
  # nothing to do if the caller produced no rewrite
  if ret is None: return None
  # nothing to do if no ranges were recorded for x
  if x not in ctx.range_map: return ret
  x_entry = ctx.range_map[x]
  # arg[0] is presumably the range index (new_range passes next(self.range_idx)) -- TODO confirm
  reduce_ids = tuple(rng.arg[0] for rng in x_entry[0] if rng.op is Ops.RANGE and rng.arg[-1] is AxisType.REDUCE)
  if not reduce_ids: return ret
  tagged = ret.replace(tag=reduce_ids)
  # register the tagged result under x's entry, since it may be a different key than the untagged ret
  ctx.range_map[tagged] = x_entry
  return tagged
62+
5663def create_bufferize_and_index_based_on_ranges (ctx :IndexingContext , x :UOp ):
5764 if x .op in {Ops .STAGE , Ops .INDEX }: return None
5865 new_srcs = []
@@ -78,22 +85,22 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
7885 if x in ctx .range_map : new_src = new_src .index (* [r for i ,r in enumerate (ctx .range_map [x ][0 ]) if i in realized_ranges ])
7986 new_srcs .append (new_src )
8087 # NOTE: do we need this?
81- return x .replace (src = tns ) if x .src != (tns := tuple (new_srcs )) else None
88+ return tag_reduce_scope ( ctx , x , x .replace (src = tns ) if x .src != (tns := tuple (new_srcs )) else None )
8289
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
  """Rewrite `x` as a WHERE that selects `x.src[0]` when all of x's ranges are valid, else 0.

  Args:
    ctx: indexing context; `ctx.range_map` is read for `x` and written for the result.
    x: the UOp to rewrite (presumably a PAD, going by the function name -- confirm at the call site).

  Returns:
    None when `x` has no `range_map` entry. Otherwise the new WHERE UOp, registered
    under x's `range_map` entry and passed through `tag_reduce_scope`.
  """
  if x not in ctx.range_map: return None
  x_entry = ctx.range_map[x]
  # combine the validity of every range recorded for x, starting from a constant True
  per_range_valid = [rng.get_valid() for rng in x_entry[0]]
  in_bounds: UOp = UOp.const(dtypes.bool, True).uprod(per_range_valid)
  inner = x.src[0]
  zero_fill = UOp.const(x.dtype, 0)
  masked = in_bounds.where(inner, zero_fill)
  # the replacement shares x's range_map entry
  ctx.range_map[masked] = x_entry
  return tag_reduce_scope(ctx, x, masked)
8996
def convert_reduce_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
  """Rebuild a reduce UOp so that its reduced ranges are carried as sources.

  Args:
    ctx: indexing context; `ctx.range_map[x]` is read and the result is registered in it.
    x: the reduce UOp; `x.arg[0]` is copied to the new UOp and `x.arg[1]` selects positions
      in x's range list.

  Returns:
    None when `x.arg[1]` is empty. Otherwise a new `Ops.REDUCE` UOp with
    `src = (x.src[0], *selected_ranges)` and `arg = (x.arg[0], ())`, registered under
    x's `range_map` entry and passed through `tag_reduce_scope`.
  """
  reduce_axes = x.arg[1]
  if len(reduce_axes) == 0: return None
  x_entry = ctx.range_map[x]
  # input ranges: keep only those whose position is listed in the reduce axes
  picked = tuple(rng for pos, rng in enumerate(x_entry[0]) if pos in reduce_axes)
  # the second arg slot becomes an empty tuple because the ranges now live in src
  reduced = UOp(Ops.REDUCE, x.dtype, src=(x.src[0], *picked), arg=(x.arg[0], ()))
  ctx.range_map[reduced] = x_entry
  return tag_reduce_scope(ctx, x, reduced)
97104
98105def remove_movement_op_after_rangeify (ctx :IndexingContext , x :UOp ):
99106 if x in ctx .range_map or x .src [0 ].op is Ops .INDEX : return x .src [0 ]
0 commit comments