Skip to content

Commit 54f05e7

Browse files
authored
Fix gather vmap (#1563)
* fix gather * fix
1 parent 26be608 commit 54f05e7

File tree

2 files changed

+83
-29
lines changed

2 files changed

+83
-29
lines changed

mlx/primitives.cpp

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023-2024 Apple Inc.
2-
32
#include <algorithm>
43
#include <cassert>
54
#include <cmath>
@@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
16831682
auto gather_axes = axes_;
16841683
auto slice_sizes = slice_sizes_;
16851684
auto src_vmapped = axes[0] >= 0;
1686-
auto indices_vmapped =
1687-
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
1688-
auto out_ax =
1689-
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
1685+
auto ind_vmap_ax_ptr =
1686+
std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
1687+
int out_ax = -1;
1688+
bool indices_vmapped = (ind_vmap_ax_ptr != axes.end());
1689+
if (indices_vmapped) {
1690+
out_ax = *ind_vmap_ax_ptr;
1691+
} else if (src_vmapped) {
1692+
out_ax = axes[0];
1693+
}
16901694

16911695
// Reorder all the index arrays so the vmap axis is in the same spot.
1692-
for (int i = 1; i < axes.size(); ++i) {
1693-
if (out_ax != axes[i] && axes[i] >= 0) {
1694-
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
1696+
if (indices_vmapped) {
1697+
for (int i = 1; i < axes.size(); ++i) {
1698+
if (out_ax != axes[i] && axes[i] >= 0) {
1699+
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
1700+
} else if (axes[i] < 0) {
1701+
indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream());
1702+
}
16951703
}
16961704
}
16971705

1706+
int idx_dims = indices.empty() ? 0 : indices[0].ndim();
1707+
16981708
if (src_vmapped) {
1699-
int max_dims = 0;
1700-
for (auto& idx : indices) {
1701-
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
1702-
}
1703-
auto new_ax_loc =
1704-
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
1705-
return a >= out_ax;
1706-
});
1707-
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
1708-
(*new_ax_loc)++;
1709+
for (auto& ax : gather_axes) {
1710+
if (ax >= axes[0]) {
1711+
ax++;
1712+
}
17091713
}
17101714
if (indices_vmapped) {
17111715
// Make a new index array for the vmapped dimension
1716+
auto vmap_inds = arange(0, src.shape(axes[0]), stream());
17121717
// Reshape it so it broadcasts with other index arrays
1718+
{
1719+
auto shape = std::vector<int>(idx_dims, 1);
1720+
shape[out_ax] = vmap_inds.size();
1721+
vmap_inds = reshape(vmap_inds, std::move(shape), stream());
1722+
}
17131723
// Update gather axes and slice sizes accordingly
1714-
auto shape = std::vector<int>(max_dims - out_ax, 1);
1715-
auto vmap_inds = arange(0, src.shape(out_ax), stream());
1716-
shape[0] = vmap_inds.shape(0);
1717-
vmap_inds = reshape(vmap_inds, shape, stream());
1718-
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
1719-
auto new_ax_idx = new_ax_loc - gather_axes.begin();
1720-
gather_axes.insert(new_ax_loc, out_ax);
1721-
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
1724+
slice_sizes.insert(slice_sizes.begin() + axes[0], 1);
1725+
gather_axes.push_back(axes[0]);
1726+
indices.push_back(vmap_inds);
17221727
} else {
1723-
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
1724-
out_ax = max_dims + axes[0];
1728+
slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax));
1729+
out_ax += idx_dims;
17251730
}
17261731
}
1727-
return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}};
1732+
auto out = gather(src, indices, gather_axes, slice_sizes, stream());
1733+
if (src_vmapped && indices_vmapped) {
1734+
out = squeeze(out, idx_dims + axes[0], stream());
1735+
}
1736+
return {{out}, {out_ax}};
17281737
}
17291738

17301739
std::vector<array> Gather::vjp(

python/tests/test_vmap.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,51 @@ def test_vmap_inverse(self):
370370
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
371371
)
372372

373+
def test_vmap_gather(self):
    """Exercise mx.vmap over array indexing with various in_axes combinations.

    Covers a single index array (source mapped, index mapped, or both) and
    two index arrays where only one of them is mapped.
    """

    def gather(a, idx):
        return a[idx]

    # Source mapped, index not mapped (in_axes entry is None).
    a = mx.array([[1, 2], [3, 4]])
    idx = mx.array(0)
    out = mx.vmap(gather, (0, None))(a, idx)
    self.assertTrue(mx.array_equal(out, mx.array([1, 3])))

    out = mx.vmap(gather, (1, None))(a, idx)
    self.assertTrue(mx.array_equal(out, mx.array([1, 2])))

    # Source and index both mapped along axis 0.
    idx = mx.array([0, 1])
    out = mx.vmap(gather, (0, 0))(a, idx)
    self.assertTrue(mx.array_equal(out, mx.array([1, 4])))

    # Source mapped along axis 2, index mapped along axis 0.
    a = mx.ones((2, 3, 4))
    idx = mx.zeros(4, mx.int32)
    out = mx.vmap(gather, (2, 0))(a, idx)
    self.assertEqual(out.shape, (4, 3))

    # Both mapped along axis 0 on a 3-D source.
    f = mx.vmap(gather, (0, 0))
    out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
    self.assertEqual(out.shape, (2, 4))

    # Separate name so the one-index helper above is not shadowed.
    def gather_two(a, idxa, idxb):
        return a[idxa, idxb]

    # Source and first index mapped, second index not mapped.
    a = mx.ones((2, 3, 4))
    idxa = mx.zeros((2, 3), mx.int32)
    idxb = mx.zeros(3, mx.int32)
    out = mx.vmap(gather_two, (0, 0, None))(a, idxa, idxb)
    self.assertEqual(out.shape, (2, 3))

    # Source and second index mapped along axis 0, first index not mapped.
    idxa = mx.zeros((3, 1, 2), mx.int32)
    idxb = mx.zeros((2, 3, 1, 2), mx.int32)
    out = mx.vmap(gather_two, (0, None, 0))(a, idxa, idxb)
    self.assertEqual(out.shape, (2, 3, 1, 2))

    # Same, but the second index is mapped along its last axis (3).
    idxa = mx.zeros((3, 1, 2), mx.int32)
    idxb = mx.zeros((3, 1, 2, 2), mx.int32)
    out = mx.vmap(gather_two, (0, None, 3))(a, idxa, idxb)
    self.assertEqual(out.shape, (2, 3, 1, 2))
417+
373418
def test_vmap_scatter(self):
374419
def scatter(a):
375420
a[mx.array(0)] = mx.array(0.0)

0 commit comments

Comments
 (0)