@@ -325,6 +325,29 @@ def test_categorical(self):
325325 with self .assertRaises (ValueError ):
326326 mx .random .categorical (logits , shape = [10 , 5 ], num_samples = 5 )
327327
328+ def test_permutation (self ):
329+ x = sorted (mx .random .permutation (4 ).tolist ())
330+ self .assertEqual ([0 , 1 , 2 , 3 ], x )
331+
332+ x = mx .array ([0 , 1 , 2 , 3 ])
333+ x = sorted (mx .random .permutation (x ).tolist ())
334+ self .assertEqual ([0 , 1 , 2 , 3 ], x )
335+
336+ x = mx .array ([0 , 1 , 2 , 3 ])
337+ x = sorted (mx .random .permutation (x ).tolist ())
338+
339+ # 2-D
340+ x = mx .arange (16 ).reshape (4 , 4 )
341+ out = mx .sort (mx .random .permutation (x , axis = 0 ), axis = 0 )
342+ self .assertTrue (mx .array_equal (x , out ))
343+ out = mx .sort (mx .random .permutation (x , axis = 1 ), axis = 1 )
344+ self .assertTrue (mx .array_equal (x , out ))
345+
346+ # Basically 0 probability this should fail.
347+ sorted_x = mx .arange (16384 )
348+ x = mx .random .permutation (16384 )
349+ self .assertFalse (mx .array_equal (sorted_x , x ))
350+
328351
329352if __name__ == "__main__" :
330353 unittest .main ()
0 commit comments