ak.argmin/argmax give incorrect results in the jax backend #3463

Open
@ikrommyd

Description

In [3]: import awkward as ak
   ...: ak.jax.register_and_check()
   ...:
   ...: x = ak.argmin(ak.Array([[1, 2], [], [6,4,5], [7]], backend="cpu"), axis=1)
   ...: y = ak.argmin(ak.Array([[1, 2], [], [6,4,5], [7]], backend="jax"), axis=1)

In [4]: x
Out[4]: <Array [0, None, 1, 0] type='4 * ?int64'>

In [5]: y
Out[5]: <Array [0, None, 3, 5] type='4 * ?int64'>

My understanding is that this happens because the jax backend does not apply positional corrections the way the cpu backend does:

apply_positional_corrections(result_array, parents, starts, shifts)

even though the output of the reduce kernel itself is correct.
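
To illustrate what the missing correction appears to be, here is a minimal pure-Python sketch. It is not Awkward's implementation: `global_argmin` and `to_local` are hypothetical helpers, and the sketch assumes the uncorrected result holds positions in the flattened content buffer, so the correction amounts to subtracting each sublist's start offset (it ignores the `parents` and `shifts` arguments entirely).

```python
def global_argmin(content, offsets):
    # Hypothetical stand-in for an uncorrected reducer: for each sublist,
    # report the position of its minimum in the flat content buffer.
    result = []
    for start, stop in zip(offsets[:-1], offsets[1:]):
        if start == stop:
            result.append(None)  # empty sublist has no argmin
        else:
            result.append(min(range(start, stop), key=content.__getitem__))
    return result


def to_local(global_positions, offsets):
    # Assumed form of the positional correction: subtract each sublist's
    # start offset so the position is relative to that sublist.
    return [
        None if pos is None else pos - start
        for pos, start in zip(global_positions, offsets[:-1])
    ]


# Flattened form of [[1, 2], [], [6, 4, 5], [7]]
content = [1, 2, 6, 4, 5, 7]
offsets = [0, 2, 2, 5, 6]

uncorrected = global_argmin(content, offsets)
corrected = to_local(uncorrected, offsets)

print(uncorrected)  # [0, None, 3, 5]
print(corrected)    # [0, None, 1, 0]
```

Under these assumptions the uncorrected values match the jax output above and the corrected values match the cpu output, which is consistent with the missing-correction explanation.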

Labels: bug (The problem described is something that must be fixed)
