Skip to content

Commit dce4bd7

Browse files
authored
Add ArrayDesc destructor to avoid possible stack overflow (#982)
1 parent ffff671 commit dce4bd7

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

mlx/array.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,36 @@ array::ArrayDesc::ArrayDesc(
192192
init();
193193
}
194194

195+
array::ArrayDesc::~ArrayDesc() {
196+
// When an array description is destroyed it will delete a bunch of arrays
197+
// that may also destory their corresponding descriptions and so on and so
198+
// forth.
199+
//
200+
// This calls recursively the destructor and can result in stack overflow, we
201+
// instead put them in a vector and destroy them one at a time resulting in a
202+
// max stack depth of 2.
203+
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
204+
205+
for (array& a : inputs) {
206+
if (a.array_desc_.use_count() == 1) {
207+
for_deletion.push_back(std::move(a.array_desc_));
208+
}
209+
}
210+
211+
while (!for_deletion.empty()) {
212+
// top is going to be deleted at the end of the block *after* the arrays
213+
// with inputs have been moved into the vector
214+
auto top = std::move(for_deletion.back());
215+
for_deletion.pop_back();
216+
217+
for (array& a : top->inputs) {
218+
if (a.array_desc_.use_count() == 1) {
219+
for_deletion.push_back(std::move(a.array_desc_));
220+
}
221+
}
222+
}
223+
}
224+
195225
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
196226
: arr(arr), idx(idx) {
197227
if (arr.ndim() == 0) {

mlx/array.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ class array {
404404
std::shared_ptr<Primitive> primitive,
405405
std::vector<array> inputs);
406406

407+
~ArrayDesc();
408+
407409
private:
408410
// Initialize size, strides, and other metadata
409411
void init();

0 commit comments

Comments
 (0)