Skip to content

Commit 8081df7

Browse files
authored
Fix boolean all reduce bug (#1355)
1 parent 64bec4f commit 8081df7

File tree

4 files changed

+14
-3
lines changed

4 files changed

+14
-3
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
2424
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
2525

2626
if(NOT MLX_VERSION)
27-
set(MLX_VERSION 0.17.0)
27+
set(MLX_VERSION 0.17.1)
2828
endif()
2929

3030
# --------------------- Processor tests -------------------------

mlx/backend/metal/reduce.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,11 @@ void all_reduce_dispatch(
308308
compute_encoder.dispatchThreads(grid_dims, group_dims);
309309

310310
// 2nd pass
311-
compute_encoder->setComputePipelineState(kernel);
311+
std::ostringstream kname_2nd_pass;
312+
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
313+
auto kernel_2nd_pass =
314+
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
315+
compute_encoder->setComputePipelineState(kernel_2nd_pass);
312316
size_t intermediate_size = n_rows;
313317
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
314318
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);

python/tests/test_reduce.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,13 @@ def test_edge_case(self):
124124
z = np.array(x).sum((0, 2, 3))
125125
self.assertTrue(np.all(z == y))
126126

127+
def test_sum_bool(self):
128+
x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5
129+
y = mx.array(x)
130+
npsum = x.sum().item()
131+
mxsum = y.sum().item()
132+
self.assertEqual(npsum, mxsum)
133+
127134

128135
if __name__ == "__main__":
129136
unittest.main(failfast=True)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def run(self) -> None:
163163

164164
setup(
165165
name="mlx",
166-
version=get_version("0.17.0"),
166+
version=get_version("0.17.1"),
167167
author="MLX Contributors",
168168
author_email="[email protected]",
169169
description="A framework for machine learning on Apple silicon.",

0 commit comments

Comments
 (0)