Skip to content

Commit 1a9f820

Browse files
authored
Compiled should not end in broadcast (#2622)
1 parent d4f4ff3 commit 1a9f820

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

mlx/compile.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,11 @@ void compile_fuse(
727727
}
728728
};
729729

730-
if (arr.has_primitive()) {
730+
// This will be the result of the fused operation so it needs
731+
// a) to not be already computed ie have a primitive
732+
// b) that primitive to not be a broadcast since it will unnecessarily
733+
// cast to a contiguous array potentially blowing up memory
734+
if (arr.has_primitive() && !is_broadcast(arr.primitive())) {
731735
Stream s = arr.primitive().stream();
732736
recurse(arr, 0, s, arr.shape());
733737
}

0 commit comments

Comments
 (0)