Skip to content

Commit e1c9600

Browse files
authored
Add mx.random.permutation (#1471)
* random permutation * comment
1 parent 1fa0d20 commit e1c9600

File tree

5 files changed

+85
-0
lines changed

5 files changed

+85
-0
lines changed

docs/src/python/random.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
4545
truncated_normal
4646
uniform
4747
laplace
48+
permutation

mlx/random.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,4 +458,19 @@ array laplace(
458458
return samples;
459459
}
460460

461+
array permutation(
462+
const array& x,
463+
int axis /* = 0 */,
464+
const std::optional<array>& key /* = std::nullopt */,
465+
StreamOrDevice s /* = {} */) {
466+
return take(x, permutation(x.shape(axis), key, s), axis, s);
467+
}
468+
469+
array permutation(
470+
int x,
471+
const std::optional<array>& key /* = std::nullopt */,
472+
StreamOrDevice s /* = {} */) {
473+
return argsort(bits({x}, key, s), s);
474+
}
475+
461476
} // namespace mlx::core::random

mlx/random.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,17 @@ inline array laplace(
254254
return laplace(shape, float32, 0.0, 1.0, key, s);
255255
}
256256

257+
/* Randomly permute the elements of x along the given axis. */
258+
array permutation(
259+
const array& x,
260+
int axis = 0,
261+
const std::optional<array>& key = std::nullopt,
262+
StreamOrDevice s = {});
263+
264+
/* A random permutation of `arange(x)` */
265+
array permutation(
266+
int x,
267+
const std::optional<array>& key = std::nullopt,
268+
StreamOrDevice s = {});
269+
257270
} // namespace mlx::core::random

python/src/random.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) {
454454
Returns:
455455
array: The output array of random values.
456456
)pbdoc");
457+
m.def(
458+
"permuation",
459+
[](const std::variant<int, array>& x,
460+
int axis,
461+
const std::optional<array>& key_,
462+
StreamOrDevice s) {
463+
auto key = key_ ? key_.value() : default_key().next();
464+
if (auto pv = std::get_if<int>(&x); pv) {
465+
return permutation(*pv, key, s);
466+
} else {
467+
return permutation(std::get<array>(x), axis, key, s);
468+
}
469+
},
470+
"shape"_a = std::vector<int>{},
471+
"axis"_a = 0,
472+
"key"_a = nb::none(),
473+
"stream"_a = nb::none(),
474+
nb::sig(
475+
"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
476+
R"pbdoc(
477+
Generate a random permutation or permute the entries of an array.
478+
479+
Args:
480+
x (int or array, optional): If an integer is provided a random
481+
permtuation of ``mx.arange(x)`` is returned. Otherwise the entries
482+
of ``x`` along the given axis are randomly permuted.
483+
axis (int, optional): The axis to permute along. Default: ``0``.
484+
key (array, optional): A PRNG key. Default: ``None``.
485+
486+
Returns:
487+
array:
488+
The generated random permutation or randomly permuted input array.
489+
)pbdoc");
457490
// Register static Python object cleanup before the interpreter exits
458491
auto atexit = nb::module_::import_("atexit");
459492
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));

python/tests/test_random.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,29 @@ def test_categorical(self):
325325
with self.assertRaises(ValueError):
326326
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
327327

328+
def test_permutation(self):
329+
x = sorted(mx.random.permutation(4).tolist())
330+
self.assertEqual([0, 1, 2, 3], x)
331+
332+
x = mx.array([0, 1, 2, 3])
333+
x = sorted(mx.random.permutation(x).tolist())
334+
self.assertEqual([0, 1, 2, 3], x)
335+
336+
x = mx.array([0, 1, 2, 3])
337+
x = sorted(mx.random.permutation(x).tolist())
338+
339+
# 2-D
340+
x = mx.arange(16).reshape(4, 4)
341+
out = mx.sort(mx.random.permutation(x, axis=0), axis=0)
342+
self.assertTrue(mx.array_equal(x, out))
343+
out = mx.sort(mx.random.permutation(x, axis=1), axis=1)
344+
self.assertTrue(mx.array_equal(x, out))
345+
346+
# Basically 0 probability this should fail.
347+
sorted_x = mx.arange(16384)
348+
x = mx.random.permutation(16384)
349+
self.assertFalse(mx.array_equal(sorted_x, x))
350+
328351

329352
if __name__ == "__main__":
330353
unittest.main()

0 commit comments

Comments
 (0)