We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d4f4ff3 commit 1a9f820Copy full SHA for 1a9f820
mlx/compile.cpp
@@ -727,7 +727,11 @@ void compile_fuse(
727
}
728
};
729
730
- if (arr.has_primitive()) {
+ // This will be the result of the fused operation so it needs
731
+ // a) to not be already computed ie have a primitive
732
+ // b) that primitive to not be a broadcast since it will unnecessarily
733
+ // cast to a contiguous array potentially blowing up memory
734
+ if (arr.has_primitive() && !is_broadcast(arr.primitive())) {
735
Stream s = arr.primitive().stream();
736
recurse(arr, 0, s, arr.shape());
737
0 commit comments