Skip to content

Commit 57c6aa7

Browse files
authored
fix multi output leak (#1548)
1 parent cde5b4a commit 57c6aa7

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

mlx/array.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
271271
for (array& a : ad.inputs) {
272272
if (a.array_desc_) {
273273
input_map.insert({a.id(), a});
274+
for (auto& s : a.siblings()) {
275+
input_map.insert({s.id(), s});
276+
}
274277
}
275278
}
276279
ad.inputs.clear();

python/tests/test_array.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,19 @@ def fun():
17711771
peak_2 = mx.metal.get_peak_memory()
17721772
self.assertEqual(peak_1, peak_2)
17731773

1774+
def fun():
1775+
a = mx.array([1.0, 2.0, 3.0, 4.0])
1776+
b, _ = mx.divmod(a, a)
1777+
return mx.log(b)
1778+
1779+
fun()
1780+
mx.synchronize()
1781+
peak_1 = mx.metal.get_peak_memory()
1782+
fun()
1783+
mx.synchronize()
1784+
peak_2 = mx.metal.get_peak_memory()
1785+
self.assertEqual(peak_1, peak_2)
1786+
17741787
def test_add_numpy(self):
17751788
x = mx.array(1)
17761789
y = np.array(2, dtype=np.int32)

0 commit comments

Comments
 (0)