Skip to content

Commit 2b55bd5

Browse files
Merge pull request jax-ml#24657 from dymil:patch-1
PiperOrigin-RevId: 694536516
2 parents ce3826d + 9763044 commit 2b55bd5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10185,18 +10185,18 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
1018510185
keepdims: bool | None = None) -> Array:
1018610186
"""Return the index of the minimum value of an array.
1018710187
10188-
JAX implementation of :func:`numpy.argmax`.
10188+
JAX implementation of :func:`numpy.argmin`.
1018910189
1019010190
Args:
1019110191
a: input array
10192-
axis: optional integer specifying the axis along which to find the maximum
10192+
axis: optional integer specifying the axis along which to find the minimum
1019310193
value. If ``axis`` is not specified, ``a`` will be flattened.
1019410194
out: unused by JAX
1019510195
keepdims: if True, then return an array with the same number of dimensions
1019610196
as ``a``.
1019710197
1019810198
Returns:
10199-
an array containing the index of the maximum value along the specified axis.
10199+
an array containing the index of the minimum value along the specified axis.
1020010200
1020110201
See also:
1020210202
- :func:`jax.numpy.argmax`: return the index of the maximum value.

0 commit comments

Comments
 (0)