@@ -49,47 +49,18 @@ struct ReductionPlan {
4949 ReductionPlan (ReductionOpType type_) : type(type_) {}
5050};
5151
52- namespace {
52+ ReductionPlan get_reduction_plan ( const array& x, const std::vector< int > axes);
5353
// Helper for the n-dimensional strided loop.
//
// Visits every index tuple of `shape` in row-major order (the last dimension
// varies fastest) and invokes `callback` with the linear offset
// sum(index[d] * strides[d]) of that tuple.
//
// - An empty `shape` is treated as a single element: `callback(0)` is invoked
//   exactly once.
// - If any dimension has extent 0 the callback is never invoked.
// - `strides` must provide at least `shape.size()` entries.
//
// Should this be in utils?
inline void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Handle rank 0 up front: `shape.size() - 1` is unsigned and would wrap
  // around for an empty shape, sending the recursion into out-of-bounds reads
  // of `shape[0]` / `strides[0]`.
  if (shape.empty()) {
    callback(0);
    return;
  }

  const int last_dim = static_cast<int>(shape.size()) - 1;

  std::function<void(int, int)> loop_inner;
  loop_inner = [&](int dim, int offset) {
    const int size = shape[dim];
    const size_t stride = strides[dim];
    for (int i = 0; i < size; i++) {
      // The sum is evaluated in size_t (because of `stride`) and narrowed
      // back to the int offset that the callback signature expects.
      const int next = static_cast<int>(offset + i * stride);
      if (dim < last_dim) {
        // More dimensions to the right: recurse into them.
        loop_inner(dim + 1, next);
      } else {
        // Innermost dimension: report the fully resolved offset.
        callback(next);
      }
    }
  };
  loop_inner(0, 0);
}
7860
7961std::pair<std::vector<int >, std::vector<size_t >> shapes_without_reduction_axes (
8062 const array& x,
81- const std::vector<int >& axes) {
82- std::vector<int > shape = x.shape ();
83- std::vector<size_t > strides = x.strides ();
84-
85- for (int i = axes.size () - 1 ; i >= 0 ; i--) {
86- int a = axes[i];
87- shape.erase (shape.begin () + a);
88- strides.erase (strides.begin () + a);
89- }
90-
91- return std::make_pair (shape, strides);
92- }
63+ const std::vector<int >& axes);
9364
9465template <typename T, typename U, typename Op>
9566struct DefaultStridedReduce {
@@ -123,102 +94,6 @@ struct DefaultContiguousReduce {
12394 }
12495};
12596
126- ReductionPlan get_reduction_plan (const array& x, const std::vector<int > axes) {
127- // The data is all there and we are reducing over everything
128- if (x.size () == x.data_size () && axes.size () == x.ndim () &&
129- x.flags ().contiguous ) {
130- return ContiguousAllReduce;
131- }
132-
133- // Row contiguous input so the output is row contiguous
134- if (x.flags ().row_contiguous ) {
135- // Merge consecutive axes
136- std::vector<int > shape = {x.shape (axes[0 ])};
137- std::vector<size_t > strides = {x.strides ()[axes[0 ]]};
138- for (int i = 1 ; i < axes.size (); i++) {
139- if (axes[i] - 1 == axes[i - 1 ]) {
140- shape.back () *= x.shape (axes[i]);
141- strides.back () = x.strides ()[axes[i]];
142- } else {
143- shape.push_back (x.shape (axes[i]));
144- strides.push_back (x.strides ()[axes[i]]);
145- }
146- }
147-
148- if (strides.back () == 1 ) {
149- return ReductionPlan (ContiguousReduce, shape, strides);
150- } else if (strides.back () > 1 ) {
151- return ReductionPlan (ContiguousStridedReduce, shape, strides);
152- }
153- }
154-
155- // Let's check if we can optimize our access patterns
156- //
157- // 1. We have a reduction axis with stride 1. Simply call
158- // GeneralContiguousReduce and be done with it.
159- // 2. We have transpositions and we are not reducing over the axis with
160- // stride 1. However, we are reducing over an axis where everything is
161- // contiguous in memory to the right of that axis. We can call strided
162- // reduce and be done with it.
163- // 2. We have weird transpositions and expands. Copy the strides to the
164- // output, then call strided reduce.
165-
166- // Sort reduction axes by stride in order to merge them and figure out if we
167- // have a contiguous reduction.
168- std::vector<std::pair<int , size_t >> reductions;
169- for (auto a : axes) {
170- reductions.push_back (std::make_pair (x.shape (a), x.strides ()[a]));
171- }
172- std::sort (reductions.begin (), reductions.end (), [](auto a, auto b) {
173- return a.second > b.second ;
174- });
175- // Extract the two smallest and try to merge them in case the contiguous
176- // reduction can be bigger than just the last axis.
177- for (int i = reductions.size () - 1 ; i >= 1 ; i--) {
178- auto a = reductions[i];
179- auto b = reductions[i - 1 ];
180-
181- // b.stride = a.shape * a.stride then a and b are contiguous
182- if (b.second == a.first * a.second ) {
183- reductions.erase (reductions.begin () + i);
184- reductions[i - 1 ] = std::make_pair (a.first * b.first , a.second );
185- }
186- }
187-
188- std::vector<int > shape;
189- std::vector<size_t > strides;
190- for (auto r : reductions) {
191- shape.push_back (r.first );
192- strides.push_back (r.second );
193- }
194-
195- // We can call the contiguous reduction op for every weird way the input is
196- // structured in the rest of the axes.
197- if (strides.back () == 1 ) {
198- return ReductionPlan (GeneralContiguousReduce, shape, strides);
199- }
200-
201- // Delegate to the general strided reduction op if the axes after
202- // strides.back() are contiguous.
203- if (strides.back () > 1 ) {
204- int size = 1 ;
205- for (int i = x.ndim () - 1 ; i >= 0 ; i--) {
206- if (axes.back () == i) {
207- continue ;
208- }
209- if (x.strides ()[i] != size) {
210- break ;
211- }
212- size *= x.shape (i);
213- }
214- if (size >= strides.back ()) {
215- return ReductionPlan (GeneralStridedReduce, shape, strides);
216- }
217- }
218-
219- return ReductionPlan (GeneralReduce, shape, strides);
220- }
221-
22297template <typename T, typename U, typename OpS, typename OpC, typename Op>
22398void reduction_op (
22499 const array& x,
@@ -361,6 +236,4 @@ void reduction_op(
361236 reduction_op<T, U>(x, out, axes, init, ops, opc, op);
362237}
363238
364- } // namespace
365-
366239} // namespace mlx::core
0 commit comments