Skip to content

Commit 105d236

Browse files
authored
Add vmap for SVD and inverse (#849)
1 parent 53e6a93 commit 105d236

File tree

7 files changed

+116
-5
lines changed

7 files changed

+116
-5
lines changed

mlx/backend/common/inverse.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "mlx/allocator.h"
44
#include "mlx/backend/common/copy.h"
5+
#include "mlx/linalg.h"
56
#include "mlx/primitives.h"
67

78
#ifdef ACCELERATE_NEW_LAPACK
@@ -92,4 +93,12 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
9293
inverse_impl(inputs[0], output);
9394
}
9495

96+
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
97+
const std::vector<array>& inputs,
98+
const std::vector<int>& axes) {
99+
auto ax = axes[0] >= 0 ? 0 : -1;
100+
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
101+
return {{linalg::inv(a, stream())}, {ax}};
102+
}
103+
95104
} // namespace mlx::core

mlx/backend/common/svd.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlx/allocator.h"
44
#include "mlx/backend/common/copy.h"
55
#include "mlx/backend/common/lapack_helper.h"
6+
#include "mlx/linalg.h"
67
#include "mlx/primitives.h"
78

89
namespace mlx::core {
@@ -144,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
144145
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
145146
}
146147

148+
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
149+
const std::vector<array>& inputs,
150+
const std::vector<int>& axes) {
151+
auto ax = axes[0] >= 0 ? 0 : -1;
152+
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
153+
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
154+
}
155+
147156
} // namespace mlx::core

mlx/primitives.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
11271127
const std::vector<array>& inputs,
11281128
const std::vector<int>& axes) {
11291129
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
1130-
return {{equal(a, b, stream())}, axes};
1130+
return {{equal(a, b, stream())}, {to_ax}};
11311131
}
11321132

11331133
std::vector<array> Equal::vjp(
@@ -1468,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
14681468
const std::vector<array>& inputs,
14691469
const std::vector<int>& axes) {
14701470
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
1471-
return {{greater(a, b, stream())}, axes};
1471+
return {{greater(a, b, stream())}, {to_ax}};
14721472
}
14731473

14741474
std::vector<array> Greater::vjp(
@@ -1495,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
14951495
const std::vector<array>& inputs,
14961496
const std::vector<int>& axes) {
14971497
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
1498-
return {{greater_equal(a, b, stream())}, axes};
1498+
return {{greater_equal(a, b, stream())}, {to_ax}};
14991499
}
15001500

15011501
std::vector<array> GreaterEqual::vjp(
@@ -1522,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
15221522
const std::vector<array>& inputs,
15231523
const std::vector<int>& axes) {
15241524
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
1525-
return {{less(a, b, stream())}, axes};
1525+
return {{less(a, b, stream())}, {to_ax}};
15261526
}
15271527

15281528
std::vector<array> Less::vjp(
@@ -1549,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
15491549
const std::vector<array>& inputs,
15501550
const std::vector<int>& axes) {
15511551
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
1552-
return {{less_equal(a, b, stream())}, axes};
1552+
return {{less_equal(a, b, stream())}, {to_ax}};
15531553
}
15541554

15551555
std::vector<array> LessEqual::vjp(

mlx/primitives.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,6 +1929,7 @@ class SVD : public Primitive {
19291929
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
19301930
override;
19311931

1932+
DEFINE_VMAP()
19321933
DEFINE_PRINT(SVD)
19331934

19341935
private:
@@ -1943,6 +1944,7 @@ class Inverse : public UnaryPrimitive {
19431944
void eval_cpu(const std::vector<array>& inputs, array& output) override;
19441945
void eval_gpu(const std::vector<array>& inputs, array& output) override;
19451946

1947+
DEFINE_VMAP()
19461948
DEFINE_PRINT(Inverse)
19471949

19481950
private:

mlx/transforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ std::vector<array> vmap_replace(
655655
}
656656

657657
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
658+
658659
// For each primitive's outputs add its id, the vout id and the vax
659660
auto outputs = a.outputs();
660661
for (int i = 0; i < v_outputs.size(); ++i) {

python/tests/test_vmap.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,64 @@ def test_vmap_matmul(self):
314314
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
315315
self.assertTrue(mx.allclose(out, expected))
316316

317+
def test_vmap_svd(self):
318+
a = mx.random.uniform(shape=(3, 4, 2))
319+
320+
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
321+
322+
# Vmap over the first axis (this is already supported natively by the primitive).
323+
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
324+
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
325+
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
326+
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
327+
328+
for i in range(a.shape[0]):
329+
M = a[i]
330+
U, S, Vt = Us[i], Ss[i], Vts[i]
331+
self.assertTrue(
332+
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
333+
)
334+
335+
# Vmap over the second axis.
336+
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
337+
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
338+
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
339+
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
340+
341+
for i in range(a.shape[1]):
342+
M = a[:, i, :]
343+
U, S, Vt = Us[i], Ss[i], Vts[i]
344+
self.assertTrue(
345+
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
346+
)
347+
348+
def test_vmap_inverse(self):
349+
a = mx.random.uniform(shape=(3, 4, 4))
350+
351+
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
352+
353+
# Vmap over the first axis (this is already supported natively by the primitive).
354+
invs = mx.vmap(cpu_inv, in_axes=(0,))(a)
355+
356+
for i in range(a.shape[0]):
357+
self.assertTrue(
358+
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
359+
)
360+
361+
a = mx.random.uniform(shape=(4, 3, 4))
362+
363+
# Without vmapping, each input matrix is not square.
364+
with self.assertRaises(ValueError):
365+
mx.eval(cpu_inv(a))
366+
367+
# Vmap over the second axis.
368+
invs = mx.vmap(cpu_inv, in_axes=(1,))(a)
369+
370+
for i in range(a.shape[1]):
371+
self.assertTrue(
372+
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
373+
)
374+
317375

318376
if __name__ == "__main__":
319377
unittest.main()

tests/vmap_tests.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
413413
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
414414
}
415415
}
416+
417+
TEST_CASE("test vmap SVD") {
418+
auto fun = [](std::vector<array> inputs) {
419+
return linalg::svd(inputs.at(0), Device::cpu);
420+
};
421+
422+
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
423+
424+
// vmap over the second axis.
425+
{
426+
auto out = vmap(fun, /* in_axes = */ {1})({a});
427+
const auto& U = out.at(0);
428+
const auto& S = out.at(1);
429+
const auto& Vt = out.at(2);
430+
431+
CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
432+
CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
433+
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
434+
}
435+
436+
// vmap over the third axis.
437+
{
438+
auto out = vmap(fun, /* in_axes = */ {2})({a});
439+
const auto& U = out.at(0);
440+
const auto& S = out.at(1);
441+
const auto& Vt = out.at(2);
442+
443+
CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
444+
CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
445+
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
446+
}
447+
}

0 commit comments

Comments
 (0)