Skip to content

Commit fe3167d

Browse files
authored
smaller CPU binary (#1203)
* smaller CPU binary
* fix no cpu build
1 parent 31e134b commit fe3167d

File tree

7 files changed

+168
-187
lines changed

7 files changed

+168
-187
lines changed

mlx/backend/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ target_sources(
4646
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
4747
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
4848
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
49+
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
4950
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
5051
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
5152
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp

mlx/backend/common/binary.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
196196
}
197197
}
198198

199+
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
200+
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
201+
auto& in1 = inputs[0];
202+
auto& in2 = inputs[1];
203+
binary(in1, in2, out, detail::LogicalAnd());
204+
}
205+
206+
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
207+
assert(inputs.size() == 2); // LogicalOr requires two input arrays
208+
auto& in1 = inputs[0];
209+
auto& in2 = inputs[1];
210+
binary(in1, in2, out, detail::LogicalOr());
211+
}
212+
199213
void Maximum::eval(const std::vector<array>& inputs, array& out) {
200214
assert(inputs.size() == 2);
201215
auto& a = inputs[0];

mlx/backend/common/primitives.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include "mlx/allocator.h"
1010
#include "mlx/backend/common/arange.h"
11-
#include "mlx/backend/common/binary.h"
1211
#include "mlx/backend/common/copy.h"
1312
#include "mlx/backend/common/ops.h"
1413
#include "mlx/backend/common/slicing.h"
@@ -314,20 +313,6 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
314313
unary(in, out, detail::LogicalNot());
315314
}
316315

317-
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
318-
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
319-
auto& in1 = inputs[0];
320-
auto& in2 = inputs[1];
321-
binary(in1, in2, out, detail::LogicalAnd());
322-
}
323-
324-
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
325-
assert(inputs.size() == 2); // LogicalOr requires two input arrays
326-
auto& in1 = inputs[0];
327-
auto& in2 = inputs[1];
328-
binary(in1, in2, out, detail::LogicalOr());
329-
}
330-
331316
void Negative::eval(const std::vector<array>& inputs, array& out) {
332317
assert(inputs.size() == 1);
333318
auto& in = inputs[0];

mlx/backend/common/reduce.cpp

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -104,48 +104,14 @@ void reduce_dispatch_out(
104104
}
105105
case Reduce::Sum: {
106106
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
107-
switch (out.dtype()) {
108-
case bool_:
109-
reduction_op<InT, bool>(in, out, axes, false, op);
110-
break;
111-
case uint8:
112-
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
113-
break;
114-
case uint16:
115-
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
116-
break;
117-
case uint32:
118-
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
119-
break;
120-
case uint64:
121-
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
122-
break;
123-
case int8:
124-
reduction_op<InT, int8_t>(in, out, axes, 0, op);
125-
break;
126-
case int16:
127-
reduction_op<InT, int16_t>(in, out, axes, 0, op);
128-
break;
129-
case int32:
130-
reduction_op<InT, int32_t>(in, out, axes, 0, op);
131-
break;
132-
case int64:
133-
reduction_op<InT, int64_t>(in, out, axes, 0, op);
134-
break;
135-
case float16:
136-
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
137-
break;
138-
case float32:
139-
reduction_op<InT, float>(in, out, axes, 0.0f, op);
140-
break;
141-
case bfloat16:
142-
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
143-
break;
144-
case complex64:
145-
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
146-
break;
107+
if (out.dtype() == int32) {
108+
// special case since the input type can be bool
109+
reduction_op<InT, int32_t>(in, out, axes, 0, op);
110+
} else {
111+
reduction_op<InT, InT>(in, out, axes, 0, op);
147112
}
148-
} break;
113+
break;
114+
}
149115
case Reduce::Prod: {
150116
auto op = [](auto y, auto x) { (*y) *= x; };
151117
reduction_op<InT, InT>(in, out, axes, 1, op);
@@ -168,6 +134,29 @@ void reduce_dispatch_out(
168134

169135
} // namespace
170136

137+
void nd_loop(
138+
std::function<void(int)> callback,
139+
const std::vector<int>& shape,
140+
const std::vector<size_t>& strides) {
141+
std::function<void(int, int)> loop_inner;
142+
loop_inner = [&](int dim, int offset) {
143+
if (dim < shape.size() - 1) {
144+
int size = shape[dim];
145+
size_t stride = strides[dim];
146+
for (int i = 0; i < size; i++) {
147+
loop_inner(dim + 1, offset + i * stride);
148+
}
149+
} else {
150+
int size = shape[dim];
151+
size_t stride = strides[dim];
152+
for (int i = 0; i < size; i++) {
153+
callback(offset + i * stride);
154+
}
155+
}
156+
};
157+
loop_inner(0, 0);
158+
}
159+
171160
void Reduce::eval(const std::vector<array>& inputs, array& out) {
172161
assert(inputs.size() == 1);
173162
auto& in = inputs[0];

mlx/backend/common/reduce.h

Lines changed: 4 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -49,47 +49,18 @@ struct ReductionPlan {
4949
ReductionPlan(ReductionOpType type_) : type(type_) {}
5050
};
5151

52-
namespace {
52+
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
5353

5454
// Helper for the ndimensional strided loop
5555
// Should this be in utils?
56-
inline void nd_loop(
56+
void nd_loop(
5757
std::function<void(int)> callback,
5858
const std::vector<int>& shape,
59-
const std::vector<size_t>& strides) {
60-
std::function<void(int, int)> loop_inner;
61-
loop_inner = [&](int dim, int offset) {
62-
if (dim < shape.size() - 1) {
63-
int size = shape[dim];
64-
size_t stride = strides[dim];
65-
for (int i = 0; i < size; i++) {
66-
loop_inner(dim + 1, offset + i * stride);
67-
}
68-
} else {
69-
int size = shape[dim];
70-
size_t stride = strides[dim];
71-
for (int i = 0; i < size; i++) {
72-
callback(offset + i * stride);
73-
}
74-
}
75-
};
76-
loop_inner(0, 0);
77-
}
59+
const std::vector<size_t>& strides);
7860

7961
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
8062
const array& x,
81-
const std::vector<int>& axes) {
82-
std::vector<int> shape = x.shape();
83-
std::vector<size_t> strides = x.strides();
84-
85-
for (int i = axes.size() - 1; i >= 0; i--) {
86-
int a = axes[i];
87-
shape.erase(shape.begin() + a);
88-
strides.erase(strides.begin() + a);
89-
}
90-
91-
return std::make_pair(shape, strides);
92-
}
63+
const std::vector<int>& axes);
9364

9465
template <typename T, typename U, typename Op>
9566
struct DefaultStridedReduce {
@@ -123,102 +94,6 @@ struct DefaultContiguousReduce {
12394
}
12495
};
12596

126-
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
127-
// The data is all there and we are reducing over everything
128-
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
129-
x.flags().contiguous) {
130-
return ContiguousAllReduce;
131-
}
132-
133-
// Row contiguous input so the output is row contiguous
134-
if (x.flags().row_contiguous) {
135-
// Merge consecutive axes
136-
std::vector<int> shape = {x.shape(axes[0])};
137-
std::vector<size_t> strides = {x.strides()[axes[0]]};
138-
for (int i = 1; i < axes.size(); i++) {
139-
if (axes[i] - 1 == axes[i - 1]) {
140-
shape.back() *= x.shape(axes[i]);
141-
strides.back() = x.strides()[axes[i]];
142-
} else {
143-
shape.push_back(x.shape(axes[i]));
144-
strides.push_back(x.strides()[axes[i]]);
145-
}
146-
}
147-
148-
if (strides.back() == 1) {
149-
return ReductionPlan(ContiguousReduce, shape, strides);
150-
} else if (strides.back() > 1) {
151-
return ReductionPlan(ContiguousStridedReduce, shape, strides);
152-
}
153-
}
154-
155-
// Let's check if we can optimize our access patterns
156-
//
157-
// 1. We have a reduction axis with stride 1. Simply call
158-
// GeneralContiguousReduce and be done with it.
159-
// 2. We have transpositions and we are not reducing over the axis with
160-
// stride 1. However, we are reducing over an axis where everything is
161-
// contiguous in memory to the right of that axis. We can call strided
162-
// reduce and be done with it.
163-
// 2. We have weird transpositions and expands. Copy the strides to the
164-
// output, then call strided reduce.
165-
166-
// Sort reduction axes by stride in order to merge them and figure out if we
167-
// have a contiguous reduction.
168-
std::vector<std::pair<int, size_t>> reductions;
169-
for (auto a : axes) {
170-
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
171-
}
172-
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
173-
return a.second > b.second;
174-
});
175-
// Extract the two smallest and try to merge them in case the contiguous
176-
// reduction can be bigger than just the last axis.
177-
for (int i = reductions.size() - 1; i >= 1; i--) {
178-
auto a = reductions[i];
179-
auto b = reductions[i - 1];
180-
181-
// b.stride = a.shape * a.stride then a and b are contiguous
182-
if (b.second == a.first * a.second) {
183-
reductions.erase(reductions.begin() + i);
184-
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
185-
}
186-
}
187-
188-
std::vector<int> shape;
189-
std::vector<size_t> strides;
190-
for (auto r : reductions) {
191-
shape.push_back(r.first);
192-
strides.push_back(r.second);
193-
}
194-
195-
// We can call the contiguous reduction op for every weird way the input is
196-
// structured in the rest of the axes.
197-
if (strides.back() == 1) {
198-
return ReductionPlan(GeneralContiguousReduce, shape, strides);
199-
}
200-
201-
// Delegate to the general strided reduction op if the axes after
202-
// strides.back() are contiguous.
203-
if (strides.back() > 1) {
204-
int size = 1;
205-
for (int i = x.ndim() - 1; i >= 0; i--) {
206-
if (axes.back() == i) {
207-
continue;
208-
}
209-
if (x.strides()[i] != size) {
210-
break;
211-
}
212-
size *= x.shape(i);
213-
}
214-
if (size >= strides.back()) {
215-
return ReductionPlan(GeneralStridedReduce, shape, strides);
216-
}
217-
}
218-
219-
return ReductionPlan(GeneralReduce, shape, strides);
220-
}
221-
22297
template <typename T, typename U, typename OpS, typename OpC, typename Op>
22398
void reduction_op(
22499
const array& x,
@@ -361,6 +236,4 @@ void reduction_op(
361236
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
362237
}
363238

364-
} // namespace
365-
366239
} // namespace mlx::core

0 commit comments

Comments (0)