Open
Labels: bug, pallas
Description
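Inside a Pallas kernel on TPU, `jnp.argmax` and `jnp.argmin` return the largest index when the extreme value occurs more than once, whereas regular `jnp.argmax` / `jnp.argmin` return the smallest. The documentation (linked below) states the expected behavior: "When the maximum value occurs more than once along a particular axis, the smallest index is returned."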
Potentially related issue pytorch/xla#2415
https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmin.html
https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmax.html
argmax
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, out_ref):
    out_ref[...] = jnp.argmax(x_ref[...], axis=-1)

x = jnp.ones((1, 8), dtype=jnp.float32)
pallas_result = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
)(x)
regular_result = jnp.argmax(x, axis=-1)
print(f"Input: {x}")
print(f"Pallas jnp.argmax: {pallas_result}")
print(f"Regular jnp.argmax: {regular_result}")

Output:
Input: [[1. 1. 1. 1. 1. 1. 1. 1.]]
Pallas jnp.argmax: [7]
Regular jnp.argmax: [0]
argmin
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, out_ref):
    out_ref[...] = jnp.argmin(x_ref[...], axis=-1)

x = jnp.ones((1, 8), dtype=jnp.float32)
pallas_result = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
)(x)
regular_result = jnp.argmin(x, axis=-1)
print(f"Input: {x}")
print(f"Pallas jnp.argmin: {pallas_result}")
print(f"Regular jnp.argmin: {regular_result}")

Output:
Input: [[1. 1. 1. 1. 1. 1. 1. 1.]]
Pallas jnp.argmin: [7]
Regular jnp.argmin: [0]
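For reference, the documented tie-breaking (smallest index wins) can be written out explicitly with a comparison against the extreme value plus an index reduction. The sketch below is only an illustration: `first_index_argmax` and `first_index_argmin` are hypothetical helper names, and the code runs as regular JAX. Whether the same formulation lowers correctly inside a Pallas TPU kernel has not been checked here, so it should not be read as a confirmed workaround.

```python
import jax
import jax.numpy as jnp


def _first_index_where(mask):
    """Smallest index along the last axis at which `mask` is True."""
    n = mask.shape[-1]
    # Position of every element along the last axis, broadcast to mask's shape.
    idx = jax.lax.broadcasted_iota(jnp.int32, mask.shape, mask.ndim - 1)
    # Positions that do not match get the out-of-range sentinel n,
    # so the minimum selects the first matching position.
    return jnp.min(jnp.where(mask, idx, n), axis=-1)


def first_index_argmax(x):
    """Hypothetical helper: argmax over the last axis, ties go to the smallest index.

    NaN inputs are not handled (the comparison below is False for NaN).
    """
    return _first_index_where(x == jnp.max(x, axis=-1, keepdims=True))


def first_index_argmin(x):
    """Hypothetical helper: argmin over the last axis, ties go to the smallest index."""
    return _first_index_where(x == jnp.min(x, axis=-1, keepdims=True))


x = jnp.ones((1, 8), dtype=jnp.float32)
print(first_index_argmax(x), jnp.argmax(x, axis=-1))
print(first_index_argmin(x), jnp.argmin(x, axis=-1))
```

On the all-ones input from the reproductions above, both helpers should agree with regular `jnp.argmax` / `jnp.argmin` and report index 0.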
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.9.0
jaxlib: 0.9.0
numpy: 2.3.5
python: 3.11.13 (main, Jun 4 2025, 08:57:29) [GCC 11.4.0]
device info: TPU v6 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-7620f7df-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep 3 16:11:52 UTC 2024', machine='x86_64')