[Pallas TPU] jnp.argmin/argmax returns last index instead of first index on ties #34620

@catswe

Description

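According to the jnp.argmin and jnp.argmax documentation, when the minimum (or maximum) value occurs more than once along the reduction axis, the smallest index (the first occurrence) is returned. Inside a Pallas kernel on TPU, both functions instead return the largest (last) index on ties, as the reproductions below show.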

Potentially related issue pytorch/xla#2415

https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmin.html
https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmax.html
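To make the two tie-breaking conventions concrete, here is a small sketch that runs outside Pallas. It computes the documented first-occurrence result with plain jnp.argmax/jnp.argmin, and a last-occurrence variant (which appears to match what the Pallas TPU kernel returns) by reversing the axis. The input values are illustrative only.

```python
import jax.numpy as jnp

# Rows with ties: row 0 has two maxima, row 1 is all equal.
x = jnp.array([[1.0, 3.0, 3.0, 0.0],
               [1.0, 1.0, 1.0, 1.0]], dtype=jnp.float32)

# Documented behaviour: index of the first occurrence of the extremum.
first_max = jnp.argmax(x, axis=-1)
first_min = jnp.argmin(x, axis=-1)

# Last-occurrence variant: reverse the axis, take the first occurrence
# there, then map the index back to the original orientation.
n = x.shape[-1]
last_max = n - 1 - jnp.argmax(x[..., ::-1], axis=-1)
last_min = n - 1 - jnp.argmin(x[..., ::-1], axis=-1)

print("first argmax:", first_max, "last argmax:", last_max)
print("first argmin:", first_min, "last argmin:", last_min)
```

For the all-ones (1, 8) input used in the reproductions, the first-occurrence rule gives 0 and the last-occurrence rule gives 7, which is exactly the discrepancy reported.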

argmax

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, out_ref):
    out_ref[...] = jnp.argmax(x_ref[...], axis=-1)

x = jnp.ones((1, 8), dtype=jnp.float32)
pallas_result = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
)(x)
regular_result = jnp.argmax(x, axis=-1)

print(f"Input: {x}")
print(f"Pallas jnp.argmax: {pallas_result}")
print(f"Regular jnp.argmax: {regular_result}")
```

Output:

```
Input: [[1. 1. 1. 1. 1. 1. 1. 1.]]
Pallas jnp.argmax: [7]
Regular jnp.argmax: [0]
```
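Until this is fixed, one possible in-kernel workaround is to avoid jnp.argmax and compute the first-occurrence index explicitly: take the row maximum, mask positions equal to it with their column index (from an iota), and take the minimum of those indices. This is a sketch under assumptions: it assumes jnp.max, jnp.where, and jax.lax.broadcasted_iota lower through Mosaic for the shapes you use (only interpret mode was exercised here), it returns a (rows, 1) output rather than (rows,), and it does not reproduce jnp.argmax's NaN handling. The names first_argmax_kernel and first_argmax are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def first_argmax_kernel(x_ref, out_ref):
    x = x_ref[...]
    # Row maximum, kept as (rows, 1) so it broadcasts against x.
    row_max = jnp.max(x, axis=-1, keepdims=True)
    # Column index of every element, same shape as x.
    idx = jax.lax.broadcasted_iota(jnp.int32, x.shape, x.ndim - 1)
    # Non-maximal positions get a sentinel larger than any valid index.
    candidates = jnp.where(x == row_max, idx, x.shape[-1])
    # The smallest surviving index is the first occurrence of the maximum.
    out_ref[...] = jnp.min(candidates, axis=-1, keepdims=True)

def first_argmax(x):
    return pl.pallas_call(
        first_argmax_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], 1), jnp.int32),
        interpret=True,  # Pallas interpreter for CPU testing; drop on a real TPU.
    )(x)

y = jnp.array([[0.0, 5.0, 5.0, 1.0],
               [2.0, 2.0, 2.0, 2.0]], dtype=jnp.float32)
print(first_argmax(y)[:, 0], jnp.argmax(y, axis=-1))
```

Because the row maximum is NaN whenever a row contains NaN, the equality mask is all false in that case and the kernel returns the sentinel x.shape[-1] rather than the NaN's index.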

argmin

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, out_ref):
    out_ref[...] = jnp.argmin(x_ref[...], axis=-1)

x = jnp.ones((1, 8), dtype=jnp.float32)
pallas_result = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
)(x)
regular_result = jnp.argmin(x, axis=-1)

print(f"Input: {x}")
print(f"Pallas jnp.argmin: {pallas_result}")
print(f"Regular jnp.argmin: {regular_result}")
```

Output:

```
Input: [[1. 1. 1. 1. 1. 1. 1. 1.]]
Pallas jnp.argmin: [7]
Regular jnp.argmin: [0]
```
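The same masking idea applies to argmin, and a tie-heavy random input is a convenient way to check any replacement against jnp.argmin. The sketch below draws values from {0, 1, 2} so that almost every row has repeated minima. As with any workaround of this kind, it assumes the reductions, jnp.where, and broadcasted_iota lower on TPU for these shapes (only interpret mode was exercised here); first_argmin_kernel and first_argmin are hypothetical names, and NaN inputs are not handled.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def first_argmin_kernel(x_ref, out_ref):
    x = x_ref[...]
    # Row minimum, kept as (rows, 1) for broadcasting.
    row_min = jnp.min(x, axis=-1, keepdims=True)
    # Column index of every element.
    idx = jax.lax.broadcasted_iota(jnp.int32, x.shape, x.ndim - 1)
    # Keep indices of minimal elements; others get an out-of-range sentinel.
    candidates = jnp.where(x == row_min, idx, x.shape[-1])
    # Smallest remaining index = first occurrence of the minimum.
    out_ref[...] = jnp.min(candidates, axis=-1, keepdims=True)

def first_argmin(x):
    return pl.pallas_call(
        first_argmin_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], 1), jnp.int32),
        interpret=True,  # Pallas interpreter for CPU testing; drop on a real TPU.
    )(x)

# Integer-valued floats in {0, 1, 2}: ties on the minimum are very likely.
key = jax.random.PRNGKey(0)
x = jax.random.randint(key, (16, 128), 0, 3).astype(jnp.float32)

matches = jnp.all(first_argmin(x)[:, 0] == jnp.argmin(x, axis=-1))
print("matches jnp.argmin:", bool(matches))
```

Running the original jnp.argmin kernel against the same tie-heavy input on a TPU (without interpret=True) should show mismatches if the last-index behaviour reported above is present.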

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.9.0
jaxlib: 0.9.0
numpy:  2.3.5
python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
device info: TPU v6 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-7620f7df-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep  3 16:11:52 UTC 2024', machine='x86_64')
```

Labels: bug, pallas