Skip to content

Commit 718aea3

Browse files
authored
allow take to work with integer index (#1440)
1 parent 5b6f38d commit 718aea3

File tree

4 files changed

+74
-17
lines changed

4 files changed

+74
-17
lines changed

mlx/ops.cpp

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,20 @@ array squeeze(
504504
shape.push_back(a.shape(i));
505505
}
506506
}
507-
return reshape(a, shape, s);
507+
return reshape(a, std::move(shape), s);
508+
}
509+
510+
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
511+
int ax = axis < 0 ? axis + a.ndim() : axis;
512+
if (ax < 0 || ax >= a.ndim()) {
513+
std::ostringstream msg;
514+
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
515+
<< " dimensions.";
516+
throw std::invalid_argument(msg.str());
517+
}
518+
auto shape = a.shape();
519+
shape.erase(shape.begin() + ax);
520+
return reshape(a, std::move(shape), s);
508521
}
509522

510523
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@@ -657,10 +670,15 @@ array slice(
657670

658671
array slice(
659672
const array& a,
660-
const std::vector<int>& start,
661-
const std::vector<int>& stop,
673+
std::vector<int> start,
674+
std::vector<int> stop,
662675
StreamOrDevice s /* = {} */) {
663-
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
676+
return slice(
677+
a,
678+
std::move(start),
679+
std::move(stop),
680+
std::vector<int>(a.ndim(), 1),
681+
to_stream(s));
664682
}
665683

666684
/** Update a slice from the source array */
@@ -2715,13 +2733,43 @@ array take(
27152733
// Squeeze the axis we take over
27162734
std::vector<int> out_shape = out.shape();
27172735
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
2718-
return reshape(out, out_shape, s);
2736+
return reshape(out, std::move(out_shape), s);
27192737
}
27202738

27212739
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
27222740
return take(reshape(a, {-1}, s), indices, 0, s);
27232741
}
27242742

2743+
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
2744+
// Check for valid axis
2745+
if (axis + static_cast<int>(a.ndim()) < 0 ||
2746+
axis >= static_cast<int>(a.ndim())) {
2747+
std::ostringstream msg;
2748+
msg << "[take] Received invalid axis " << axis << " for array with "
2749+
<< a.ndim() << " dimensions.";
2750+
throw std::invalid_argument(msg.str());
2751+
}
2752+
2753+
// Check for valid take
2754+
if (a.size() == 0) {
2755+
throw std::invalid_argument(
2756+
"[take] Cannot do a non-empty take from an array with zero elements.");
2757+
}
2758+
2759+
// Handle negative axis
2760+
axis = axis < 0 ? a.ndim() + axis : axis;
2761+
2762+
std::vector<int> starts(a.ndim(), 0);
2763+
std::vector<int> stops = a.shape();
2764+
starts[axis] = index;
2765+
stops[axis] = index + 1;
2766+
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
2767+
}
2768+
2769+
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
2770+
return take(reshape(a, {-1}, s), index, 0, s);
2771+
}
2772+
27252773
array take_along_axis(
27262774
const array& a,
27272775
const array& indices,
@@ -2764,7 +2812,7 @@ array take_along_axis(
27642812
// Squeeze out the slice shape
27652813
std::vector<int> out_shape(
27662814
out.shape().begin(), out.shape().begin() + a.ndim());
2767-
return reshape(out, out_shape, s);
2815+
return reshape(out, std::move(out_shape), s);
27682816
}
27692817

27702818
array put_along_axis(

mlx/ops.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ array squeeze(
144144
StreamOrDevice s = {});
145145

146146
/** Remove singleton dimensions at the given axis. */
147-
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
148-
return squeeze(a, std::vector<int>{axis}, s);
149-
}
147+
array squeeze(const array& a, int axis, StreamOrDevice s = {});
150148

151149
/** Remove all singleton dimensions. */
152150
array squeeze(const array& a, StreamOrDevice s = {});
@@ -171,8 +169,8 @@ array slice(
171169
/** Slice an array with a stride of 1 in each dimension. */
172170
array slice(
173171
const array& a,
174-
const std::vector<int>& start,
175-
const std::vector<int>& stop,
172+
std::vector<int> start,
173+
std::vector<int> stop,
176174
StreamOrDevice s = {});
177175

178176
/** Update a slice from the source array */
@@ -936,9 +934,11 @@ array take(
936934
const array& indices,
937935
int axis,
938936
StreamOrDevice s = {});
937+
array take(const array& a, int index, int axis, StreamOrDevice s = {});
939938

940939
/** Take array entries at the given indices treating the array as flattened. */
941940
array take(const array& a, const array& indices, StreamOrDevice s = {});
941+
array take(const array& a, int index, StreamOrDevice s = {});
942942

943943
/** Take array entries given indices along the axis */
944944
array take_along_axis(

python/src/ops.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,13 +1398,15 @@ void init_ops(nb::module_& m) {
13981398
m.def(
13991399
"take",
14001400
[](const array& a,
1401-
const array& indices,
1401+
const std::variant<int, array>& indices,
14021402
const std::optional<int>& axis,
14031403
StreamOrDevice s) {
1404-
if (axis.has_value()) {
1405-
return take(a, indices, axis.value(), s);
1404+
if (auto pv = std::get_if<int>(&indices); pv) {
1405+
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
14061406
} else {
1407-
return take(a, indices, s);
1407+
auto indices_ = std::get<array>(indices);
1408+
return axis ? take(a, indices_, axis.value(), s)
1409+
: take(a, indices_, s);
14081410
}
14091411
},
14101412
nb::arg(),
@@ -1413,7 +1415,7 @@ void init_ops(nb::module_& m) {
14131415
nb::kw_only(),
14141416
"stream"_a = nb::none(),
14151417
nb::sig(
1416-
"def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
1418+
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
14171419
R"pbdoc(
14181420
Take elements along an axis.
14191421
@@ -1425,7 +1427,7 @@ void init_ops(nb::module_& m) {
14251427
14261428
Args:
14271429
a (array): Input array.
1428-
indices (array): Input array with integral type.
1430+
indices (int or array): Integer index or input array with integral type.
14291431
axis (int, optional): Axis along which to perform the take. If unspecified
14301432
the array is treated as a flattened 1-D vector.
14311433

python/tests/test_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,13 @@ def test_take(self):
10591059
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
10601060
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
10611061

1062+
# Take with integer index
1063+
a = mx.arange(8).reshape(2, 4)
1064+
out = mx.take(a, 1, axis=0)
1065+
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
1066+
out = mx.take(a, 1, axis=1)
1067+
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
1068+
10621069
def test_take_along_axis(self):
10631070
a_np = np.arange(8).reshape(2, 2, 2)
10641071
a_mlx = mx.array(a_np)

0 commit comments

Comments
 (0)