|
1 | 1 | // Copyright © 2023-2024 Apple Inc. |
2 | | - |
3 | 2 | #include <algorithm> |
4 | 3 | #include <cassert> |
5 | 4 | #include <cmath> |
@@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap( |
1683 | 1682 | auto gather_axes = axes_; |
1684 | 1683 | auto slice_sizes = slice_sizes_; |
1685 | 1684 | auto src_vmapped = axes[0] >= 0; |
1686 | | - auto indices_vmapped = |
1687 | | - std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); |
1688 | | - auto out_ax = |
1689 | | - *std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; }); |
| 1685 | + auto ind_vmap_ax_ptr = |
| 1686 | + std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); |
| 1687 | + int out_ax = -1; |
| 1688 | + bool indices_vmapped = (ind_vmap_ax_ptr != axes.end()); |
| 1689 | + if (indices_vmapped) { |
| 1690 | + out_ax = *ind_vmap_ax_ptr; |
| 1691 | + } else if (src_vmapped) { |
| 1692 | + out_ax = axes[0]; |
| 1693 | + } |
1690 | 1694 |
|
1691 | 1695 | // Reorder all the index arrays so the vmap axis is in the same spot. |
1692 | | - for (int i = 1; i < axes.size(); ++i) { |
1693 | | - if (out_ax != axes[i] && axes[i] >= 0) { |
1694 | | - indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); |
| 1696 | + if (indices_vmapped) { |
| 1697 | + for (int i = 1; i < axes.size(); ++i) { |
| 1698 | + if (out_ax != axes[i] && axes[i] >= 0) { |
| 1699 | + indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); |
| 1700 | + } else if (axes[i] < 0) { |
| 1701 | + indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream()); |
| 1702 | + } |
1695 | 1703 | } |
1696 | 1704 | } |
1697 | 1705 |
|
| 1706 | + int idx_dims = indices.empty() ? 0 : indices[0].ndim(); |
| 1707 | + |
1698 | 1708 | if (src_vmapped) { |
1699 | | - int max_dims = 0; |
1700 | | - for (auto& idx : indices) { |
1701 | | - max_dims = std::max(static_cast<int>(idx.ndim()), max_dims); |
1702 | | - } |
1703 | | - auto new_ax_loc = |
1704 | | - std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) { |
1705 | | - return a >= out_ax; |
1706 | | - }); |
1707 | | - for (; new_ax_loc < gather_axes.end(); new_ax_loc++) { |
1708 | | - (*new_ax_loc)++; |
| 1709 | + for (auto& ax : gather_axes) { |
| 1710 | + if (ax >= axes[0]) { |
| 1711 | + ax++; |
| 1712 | + } |
1709 | 1713 | } |
1710 | 1714 | if (indices_vmapped) { |
1711 | 1715 | // Make a new index array for the vmapped dimension |
| 1716 | + auto vmap_inds = arange(0, src.shape(axes[0]), stream()); |
1712 | 1717 | // Reshape it so it broadcasts with other index arrays |
| 1718 | + { |
| 1719 | + auto shape = std::vector<int>(idx_dims, 1); |
| 1720 | + shape[out_ax] = vmap_inds.size(); |
| 1721 | + vmap_inds = reshape(vmap_inds, std::move(shape), stream()); |
| 1722 | + } |
1713 | 1723 | // Update gather axes and slice sizes accordingly |
1714 | | - auto shape = std::vector<int>(max_dims - out_ax, 1); |
1715 | | - auto vmap_inds = arange(0, src.shape(out_ax), stream()); |
1716 | | - shape[0] = vmap_inds.shape(0); |
1717 | | - vmap_inds = reshape(vmap_inds, shape, stream()); |
1718 | | - slice_sizes.insert(slice_sizes.begin() + out_ax, 1); |
1719 | | - auto new_ax_idx = new_ax_loc - gather_axes.begin(); |
1720 | | - gather_axes.insert(new_ax_loc, out_ax); |
1721 | | - indices.insert(indices.begin() + new_ax_idx, vmap_inds); |
| 1724 | + slice_sizes.insert(slice_sizes.begin() + axes[0], 1); |
| 1725 | + gather_axes.push_back(axes[0]); |
| 1726 | + indices.push_back(vmap_inds); |
1722 | 1727 | } else { |
1723 | | - slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0])); |
1724 | | - out_ax = max_dims + axes[0]; |
| 1728 | + slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax)); |
| 1729 | + out_ax += idx_dims; |
1725 | 1730 | } |
1726 | 1731 | } |
1727 | | - return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}}; |
| 1732 | + auto out = gather(src, indices, gather_axes, slice_sizes, stream()); |
| 1733 | + if (src_vmapped && indices_vmapped) { |
| 1734 | + out = squeeze(out, idx_dims + axes[0], stream()); |
| 1735 | + } |
| 1736 | + return {{out}, {out_ax}}; |
1728 | 1737 | } |
1729 | 1738 |
|
1730 | 1739 | std::vector<array> Gather::vjp( |
|
0 commit comments