Skip to content

Commit e88e474

Browse files
authored
Reduce vmap + some fixes (#601)
1 parent 601c6d6 commit e88e474

File tree

5 files changed

+161
-33
lines changed

5 files changed

+161
-33
lines changed

mlx/ops.cpp

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ namespace {
1717

1818
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
1919
const std::vector<int>& axes,
20-
const std::vector<int>& shape,
21-
bool keepdims) {
20+
const std::vector<int>& shape) {
2221
std::set<int> axes_set;
2322
auto ndim = shape.size();
2423
for (auto ax : axes) {
@@ -38,7 +37,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
3837
for (int i = 0; i < ndim; ++i) {
3938
if (axes_set.count(i) == 0) {
4039
out_shape.push_back(shape[i]);
41-
} else if (keepdims) {
40+
} else {
4241
out_shape.push_back(1);
4342
}
4443
}
@@ -1217,13 +1216,16 @@ array all(
12171216
if (axes.empty()) {
12181217
return astype(a, bool_, s);
12191218
}
1220-
auto [out_shape, sorted_axes] =
1221-
compute_reduce_shape(axes, a.shape(), keepdims);
1222-
return array(
1219+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
1220+
auto out = array(
12231221
out_shape,
12241222
bool_,
12251223
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
12261224
{a});
1225+
if (!keepdims) {
1226+
out = squeeze(out, sorted_axes, s);
1227+
}
1228+
return out;
12271229
}
12281230

12291231
array all(
@@ -1248,13 +1250,16 @@ array any(
12481250
if (axes.empty()) {
12491251
return astype(a, bool_, s);
12501252
}
1251-
auto [out_shape, sorted_axes] =
1252-
compute_reduce_shape(axes, a.shape(), keepdims);
1253-
return array(
1253+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
1254+
auto out = array(
12541255
out_shape,
12551256
bool_,
12561257
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
12571258
{a});
1259+
if (!keepdims) {
1260+
out = squeeze(out, sorted_axes, s);
1261+
}
1262+
return out;
12581263
}
12591264

12601265
array any(
@@ -1279,14 +1284,17 @@ array sum(
12791284
if (axes.empty()) {
12801285
return a;
12811286
}
1282-
auto [out_shape, sorted_axes] =
1283-
compute_reduce_shape(axes, a.shape(), keepdims);
1287+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
12841288
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
1285-
return array(
1289+
auto out = array(
12861290
out_shape,
12871291
out_type,
12881292
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
12891293
{a});
1294+
if (!keepdims) {
1295+
out = squeeze(out, sorted_axes, s);
1296+
}
1297+
return out;
12901298
}
12911299

12921300
array sum(
@@ -1374,13 +1382,16 @@ array prod(
13741382
if (axes.empty()) {
13751383
return a;
13761384
}
1377-
auto [out_shape, sorted_axes] =
1378-
compute_reduce_shape(axes, a.shape(), keepdims);
1379-
return array(
1385+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
1386+
auto out = array(
13801387
out_shape,
13811388
a.dtype(),
13821389
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
13831390
{a});
1391+
if (!keepdims) {
1392+
out = squeeze(out, sorted_axes, s);
1393+
}
1394+
return out;
13841395
}
13851396

13861397
array prod(
@@ -1408,13 +1419,16 @@ array max(
14081419
if (axes.empty()) {
14091420
return a;
14101421
}
1411-
auto [out_shape, sorted_axes] =
1412-
compute_reduce_shape(axes, a.shape(), keepdims);
1413-
return array(
1422+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
1423+
auto out = array(
14141424
out_shape,
14151425
a.dtype(),
14161426
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
14171427
{a});
1428+
if (!keepdims) {
1429+
out = squeeze(out, sorted_axes, s);
1430+
}
1431+
return out;
14181432
}
14191433

14201434
array max(
@@ -1442,13 +1456,16 @@ array min(
14421456
if (axes.empty()) {
14431457
return a;
14441458
}
1445-
auto [out_shape, sorted_axes] =
1446-
compute_reduce_shape(axes, a.shape(), keepdims);
1447-
return array(
1459+
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
1460+
auto out = array(
14481461
out_shape,
14491462
a.dtype(),
14501463
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
14511464
{a});
1465+
if (!keepdims) {
1466+
out = squeeze(out, sorted_axes, s);
1467+
}
1468+
return out;
14521469
}
14531470

14541471
array min(
@@ -1477,14 +1494,17 @@ array argmin(
14771494
throw std::invalid_argument(
14781495
"[argmin] Cannot argmin reduce zero size array.");
14791496
}
1480-
auto [out_shape, sorted_axes] =
1481-
compute_reduce_shape({axis}, a.shape(), keepdims);
1482-
return array(
1497+
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
1498+
auto out = array(
14831499
out_shape,
14841500
uint32,
14851501
std::make_unique<ArgReduce>(
14861502
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
14871503
{a});
1504+
if (!keepdims) {
1505+
out = squeeze(out, sorted_axes, s);
1506+
}
1507+
return out;
14881508
}
14891509

14901510
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
@@ -1505,14 +1525,17 @@ array argmax(
15051525
throw std::invalid_argument(
15061526
"[argmax] Cannot argmax reduce zero size array.");
15071527
}
1508-
auto [out_shape, sorted_axes] =
1509-
compute_reduce_shape({axis}, a.shape(), keepdims);
1510-
return array(
1528+
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
1529+
auto out = array(
15111530
out_shape,
15121531
uint32,
15131532
std::make_unique<ArgReduce>(
15141533
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
15151534
{a});
1535+
if (!keepdims) {
1536+
out = squeeze(out, sorted_axes, s);
1537+
}
1538+
return out;
15161539
}
15171540

15181541
/** Returns a sorted copy of the flattened array. */

mlx/primitives.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023-2024 Apple Inc.
2-
32
#include <algorithm>
43
#include <cassert>
54
#include <cmath>
@@ -361,6 +360,20 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
361360
return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
362361
}
363362

363+
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
364+
const std::vector<array>& inputs,
365+
const std::vector<int>& axes) {
366+
int reduce_ax = axis_ + (axis_ >= axes[0]);
367+
auto& in = inputs[0];
368+
std::vector<array> out;
369+
if (reduce_type_ == ArgReduce::ArgMin) {
370+
out.push_back(argmin(in, reduce_ax, true, stream()));
371+
} else {
372+
out.push_back(argmax(in, reduce_ax, true, stream()));
373+
}
374+
return {out, axes};
375+
}
376+
364377
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
365378
const std::vector<array>& inputs,
366379
const std::vector<int>& axes) {
@@ -2153,7 +2166,36 @@ std::vector<array> Reduce::vjp(
21532166
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
21542167
const std::vector<array>& inputs,
21552168
const std::vector<int>& axes) {
2156-
throw std::runtime_error("Reduce::vmap not yet implemented.");
2169+
auto ax = axes[0];
2170+
auto reduce_axes = axes_;
2171+
for (auto& rax : reduce_axes) {
2172+
if (rax >= ax) {
2173+
rax++;
2174+
}
2175+
}
2176+
auto& in = inputs[0];
2177+
std::vector<array> out;
2178+
switch (reduce_type_) {
2179+
case Reduce::And:
2180+
out.push_back(all(in, reduce_axes, true, stream()));
2181+
break;
2182+
case Reduce::Or:
2183+
out.push_back(any(in, reduce_axes, true, stream()));
2184+
break;
2185+
case Reduce::Sum:
2186+
out.push_back(sum(in, reduce_axes, true, stream()));
2187+
break;
2188+
case Reduce::Prod:
2189+
out.push_back(prod(in, reduce_axes, true, stream()));
2190+
break;
2191+
case Reduce::Min:
2192+
out.push_back(min(in, reduce_axes, true, stream()));
2193+
break;
2194+
case Reduce::Max:
2195+
out.push_back(max(in, reduce_axes, true, stream()));
2196+
break;
2197+
}
2198+
return {out, axes};
21572199
}
21582200

21592201
bool Reduce::is_equivalent(const Primitive& other) const {

mlx/primitives.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ class ArgReduce : public UnaryPrimitive {
341341
void eval_cpu(const std::vector<array>& inputs, array& out) override;
342342
void eval_gpu(const std::vector<array>& inputs, array& out) override;
343343

344+
DEFINE_VMAP()
344345
DEFINE_PRINT(ArgReduce)
345346
bool is_equivalent(const Primitive& other) const override;
346347

mlx/transforms.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,8 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
548548
"[vmap] The number of in axes must match the number of inputs.");
549549
}
550550

551-
// Run the function on placeholder inputs
552-
// to get the original graph
553-
std::vector<array> s_inputs;
551+
// Some error checking and get the vmap axis size
552+
size_t vmap_ax_size;
554553
for (int i = 0; i < inputs.size(); ++i) {
555554
if (in_axes[i] != -1) {
556555
if (inputs[i].ndim() == 0) {
@@ -563,7 +562,26 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
563562
<< inputs[i].ndim() << " dimensions.";
564563
throw std::invalid_argument(msg.str());
565564
}
565+
vmap_ax_size = inputs[i].shape(in_axes[i]);
566+
}
567+
}
568+
// Check that all vmapped axes have the same size
569+
for (int i = 0; i < inputs.size(); ++i) {
570+
if (in_axes[i] != -1) {
571+
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
572+
std::ostringstream msg;
573+
msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
574+
<< vmap_ax_size << ".";
575+
throw std::invalid_argument(msg.str());
576+
}
577+
}
578+
}
566579

580+
// Run the function on placeholder inputs
581+
// to get the original graph
582+
std::vector<array> s_inputs;
583+
for (int i = 0; i < inputs.size(); ++i) {
584+
if (in_axes[i] != -1) {
567585
std::vector<int> shape = inputs[i].shape();
568586
shape.erase(shape.begin() + in_axes[i]);
569587
array in(shape, inputs[i].dtype(), nullptr, {});

python/tests/test_vmap.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2023 Apple Inc.
1+
# Copyright © 2023-2024 Apple Inc.
22

33
import unittest
44

@@ -220,6 +220,50 @@ def test_vmap_indexing(self):
220220
)
221221
self.assertTrue(mx.array_equal(out, expected))
222222

223+
def test_vmap_reduce(self):
224+
a = mx.ones((5, 5), mx.int32)
225+
out = mx.vmap(lambda x: x.sum())(a)
226+
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
227+
228+
out = mx.vmap(lambda x: x.sum(keepdims=True))(a)
229+
self.assertTrue(mx.array_equal(out, mx.full((5, 1), 5)))
230+
231+
out = mx.vmap(lambda x: x.sum(axis=0))(a)
232+
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
233+
234+
a = mx.ones((5, 3, 2), mx.int32)
235+
out = mx.vmap(lambda x: x.sum(axis=(0, 1)))(a)
236+
self.assertTrue(mx.array_equal(out, mx.full((5,), 6)))
237+
238+
a = mx.ones((5, 3, 2), mx.int32)
239+
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(1,))(a)
240+
self.assertTrue(mx.array_equal(out, mx.full((3,), 10)))
241+
242+
a = mx.ones((5, 3, 2), mx.int32)
243+
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a)
244+
self.assertTrue(mx.array_equal(out, mx.full((2,), 15)))
245+
246+
def test_vmap_argreduce(self):
247+
a = mx.array([[1, 2, 3], [2, 3, 1]])
248+
out = mx.vmap(lambda x: mx.argmin(x))(a)
249+
expected = mx.array([0, 2])
250+
self.assertTrue(mx.array_equal(out, expected))
251+
252+
out = mx.vmap(lambda x: mx.argmax(x))(a)
253+
expected = mx.array([2, 1])
254+
self.assertTrue(mx.array_equal(out, expected))
255+
256+
def test_mismatch_input_sizes(self):
257+
a = mx.ones((10, 1))
258+
b = mx.ones((1, 1, 1, 5))
259+
260+
with self.assertRaises(ValueError):
261+
out = mx.vmap(lambda x, y: x + y)(a, b)
262+
263+
b = mx.ones((10, 5))
264+
with self.assertRaises(ValueError):
265+
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
266+
223267

224268
if __name__ == "__main__":
225269
unittest.main()

0 commit comments

Comments
 (0)