Skip to content

Commit 252dc29

Browse files
committed
IMAGE hack for tinygrad#16335 (from tinygrad#16343)
Scheduling an IMAGE=1 FLOAT16=1 model fails with 'END src[0] should be KERNEL, not Ops.STORE' (tinygrad#16335). Carry chenyu's closed fix (tinygrad#16343, 'really a bug in pm_apply_rangeify') until a proper fix lands upstream.
1 parent a359644 commit 252dc29

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

tinygrad/schedule/indexing.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,13 @@ def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP) -> UOp:
5353
# if a range's src is 1, it's the same as UOp.const(dtypes.weakint, 0)
5454
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.weakint, 0)
5555

56+
def tag_reduce_scope(ctx:IndexingContext, x:UOp, ret:UOp|None) -> UOp|None:
57+
if ret is None or x not in ctx.range_map: return ret
58+
if scope:=tuple(r.arg[0] for r in ctx.range_map[x][0] if r.op is Ops.RANGE and r.arg[-1] is AxisType.REDUCE):
59+
ret = ret.replace(tag=scope)
60+
ctx.range_map[ret] = ctx.range_map[x]
61+
return ret
62+
5663
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
5764
if x.op in {Ops.STAGE, Ops.INDEX}: return None
5865
new_srcs = []
@@ -78,22 +85,22 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
7885
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
7986
new_srcs.append(new_src)
8087
# NOTE: do we need this?
81-
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
88+
return tag_reduce_scope(ctx, x, x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None)
8289

8390
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
8491
if x not in ctx.range_map: return None
8592
valid: UOp = UOp.const(dtypes.bool, True).uprod([r.get_valid() for r in ctx.range_map[x][0]])
8693
ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
8794
ctx.range_map[ret] = ctx.range_map[x]
88-
return ret
95+
return tag_reduce_scope(ctx, x, ret)
8996

9097
def convert_reduce_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
9198
if len(x.arg[1]) == 0: return None
9299
# input ranges
93100
new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
94101
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=(x.arg[0], ()))
95102
ctx.range_map[ret] = ctx.range_map[x]
96-
return ret
103+
return tag_reduce_scope(ctx, x, ret)
97104

98105
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
99106
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]

0 commit comments

Comments
 (0)