Skip to content

Commit 5b0af4c

Browse files
authored
fix donation condition for compilation (#1237)
1 parent 8c2e15e commit 5b0af4c

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

mlx/backend/common/compiled.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
205205
// - Donatable
206206
// - Correct size
207207
// - Not a constant
208-
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
209-
in.is_donatable() &&
208+
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
209+
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
210210
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
211211
if (move_buffers) {
212212
outputs[o].move_shared_buffer(

python/tests/test_compile.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,18 @@ def fn(x):
707707
x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
708708
self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
709709

710+
def test_max_into_equal(self):
711+
x = mx.random.uniform(shape=(1, 2, 2))
712+
mx.eval(x)
713+
714+
def fn():
715+
maxes = mx.max(x, axis=(1, 2), keepdims=True)
716+
return x == maxes
717+
718+
out = mx.compile(fn)()
719+
expected = fn()
720+
self.assertTrue(mx.array_equal(expected, out))
721+
710722

711723
if __name__ == "__main__":
712724
unittest.main()

0 commit comments

Comments
 (0)