Skip to content

Commit f30e633

Browse files
authored
Minor updates to address a few issues (#537)
* docs on arg indices return type * arange with nan * undo isort
1 parent 4fe2fa2 commit f30e633

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

mlx/ops.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@ array arange(
7979
msg << bool_ << " not supported for arange.";
8080
throw std::invalid_argument(msg.str());
8181
}
82-
int size = std::max(static_cast<int>(std::ceil((stop - start) / step)), 0);
82+
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
83+
throw std::invalid_argument("[arange] Cannot compute length.");
84+
}
85+
double real_size = std::ceil((stop - start) / step);
86+
if (std::isnan(real_size)) {
87+
throw std::invalid_argument("[arange] Cannot compute length.");
88+
}
89+
int size = std::max(static_cast<int>(real_size), 0);
8390
return array(
8491
{size},
8592
dtype,

python/src/ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,7 @@ void init_ops(py::module_& m) {
22542254
singleton dimensions, defaults to `False`.
22552255
22562256
Returns:
2257-
array: The output array with the indices of the minimum values.
2257+
array: The ``uint32`` array with the indices of the minimum values.
22582258
)pbdoc");
22592259
m.def(
22602260
"argmax",
@@ -2287,7 +2287,7 @@ void init_ops(py::module_& m) {
22872287
singleton dimensions, defaults to `False`.
22882288
22892289
Returns:
2290-
array: The output array with the indices of the maximum values.
2290+
array: The ``uint32`` array with the indices of the maximum values.
22912291
)pbdoc");
22922292
m.def(
22932293
"sort",
@@ -2343,7 +2343,7 @@ void init_ops(py::module_& m) {
23432343
If unspecified, it defaults to -1 (sorting over the last axis).
23442344
23452345
Returns:
2346-
array: The indices that sort the input array.
2346+
array: The ``uint32`` array containing indices that sort the input.
23472347
)pbdoc");
23482348
m.def(
23492349
"partition",
@@ -2416,7 +2416,7 @@ void init_ops(py::module_& m) {
24162416
If unspecified, it defaults to ``-1``.
24172417
24182418
Returns:
2419-
array: The indices that partition the input array.
2419+
array: The `uint32`` array containing indices that partition the input.
24202420
)pbdoc");
24212421
m.def(
24222422
"topk",

python/tests/test_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,17 @@ def test_split(self):
980980
self.assertEqual(z.tolist(), [5, 6, 7])
981981

982982
def test_arange_overload_dispatch(self):
983+
with self.assertRaises(ValueError):
984+
a = mx.arange(float("nan"), 1, 5)
985+
with self.assertRaises(ValueError):
986+
a = mx.arange(0, float("nan"), 5)
987+
with self.assertRaises(ValueError):
988+
a = mx.arange(0, 2, float("nan"))
989+
with self.assertRaises(ValueError):
990+
a = mx.arange(0, float("inf"), float("inf"))
991+
with self.assertRaises(ValueError):
992+
a = mx.arange(float("inf"), 1, float("inf"))
993+
983994
a = mx.arange(5)
984995
expected = [0, 1, 2, 3, 4]
985996
self.assertListEqual(a.tolist(), expected)

0 commit comments

Comments
 (0)