
Gather operation does not handle most cases #95

@rakadam

Description


The gather operation is basically broken. For example, x[:, 0] works correctly, but x[:, 1] does not (although this particular bug can be fixed with a few extra lines of code).
Also, most of the arguments passed to the gather operation are ignored, so the converter always falls back to assuming the simplest, most general version.
As a result, most JAX slicing that gets lowered to lax.gather is not converted correctly.
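For reference, the gather behind x[:, k] on a 2-D array can be written out explicitly. This is a sketch, assuming a JAX version where lax.GatherDimensionNumbers accepts these keyword fields; column_via_gather is a hypothetical name, not a jax2onnx function. The column index appears only in start_indices, and the indexed axis only in collapsed_slice_dims/start_index_map, so a converter has to read both instead of assuming defaults. One plausible (unconfirmed) explanation for x[:, 0] passing while x[:, 1] fails is that start_indices is dropped, since index 0 is exactly what a zero default would produce.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def column_via_gather(x, k):
    """Hypothetical helper: reproduce x[:, k] for a 2-D array with lax.gather."""
    n, _ = x.shape
    dnums = lax.GatherDimensionNumbers(
        offset_dims=(0,),           # output dim 0 is the kept (row) window dim
        collapsed_slice_dims=(1,),  # operand dim 1 is sliced to size 1 and dropped
        start_index_map=(1,),       # the single index component addresses operand dim 1
    )
    # The column index lives only here; ignoring it behaves like k == 0.
    start_indices = jnp.array([k], dtype=jnp.int32)
    return lax.gather(x, start_indices, dnums, slice_sizes=(n, 1))

x = jnp.arange(12.0).reshape(3, 4)
for k in range(4):
    np.testing.assert_array_equal(column_via_gather(x, k), x[:, k])
```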

Other examples:
This works: x[:, 0, :]
This does not: x[:, :, 0], which actually gets compiled into the same thing as x[:, 0, :].
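To illustrate how x[:, 0, :] and x[:, :, 0] can end up identical, the sketch below builds both as lax.gather calls with the same start_indices; they differ only in the dimension numbers and slice_sizes, so a converter that ignores those cannot tell them apart. take_dnums and onnx_gather_axis are hypothetical helpers, not jax2onnx functions. The axis check assumes a single static index (start_indices of shape (1,)), which maps onto ONNX Gather with a scalar index and np.take semantics.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def take_dnums(shape, axis):
    """Hypothetical helper: gather parameters for indexing one position along `axis`."""
    dnums = lax.GatherDimensionNumbers(
        offset_dims=tuple(range(len(shape) - 1)),  # all non-collapsed dims stay, in order
        collapsed_slice_dims=(axis,),              # the indexed axis is dropped
        start_index_map=(axis,),                   # the index component addresses `axis`
    )
    slice_sizes = tuple(1 if d == axis else s for d, s in enumerate(shape))
    return dnums, slice_sizes

def onnx_gather_axis(dnums, slice_sizes, operand_shape):
    """Hypothetical converter check (assumes start_indices has shape (1,)).
    Returns the ONNX Gather axis for a plain single-axis take, else None."""
    if len(dnums.start_index_map) != 1:
        return None
    axis = dnums.start_index_map[0]
    full_window = all(
        size == (1 if d == axis else operand_shape[d])
        for d, size in enumerate(slice_sizes)
    )
    if (tuple(dnums.collapsed_slice_dims) == (axis,)
            and tuple(dnums.offset_dims) == tuple(range(len(operand_shape) - 1))
            and full_window):
        return axis
    return None

x = jnp.arange(24.0).reshape(2, 3, 4)
idx = jnp.array([0], dtype=jnp.int32)  # identical start_indices for both cases
for axis, expected in ((1, x[:, 0, :]), (2, x[:, :, 0])):
    dnums, sizes = take_dnums(x.shape, axis)
    np.testing.assert_array_equal(lax.gather(x, idx, dnums, slice_sizes=sizes), expected)
    # ONNX Gather(data, indices=0, axis=a) behaves like np.take(data, 0, axis=a).
    a = onnx_gather_axis(dnums, sizes, x.shape)
    np.testing.assert_array_equal(np.take(np.asarray(x), 0, axis=a), expected)
```

Gathers that fail this check (several indexed axes, batched indices, partial windows) would need a more general lowering, possibly via ONNX GatherND; the sketch does not cover those.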

We need jax2onnx for our production release soon(ish), so if time allows I will implement the correct behavior in gather.py myself. I am opening this issue so you know I am starting to work on it. If a good implementation already exists somewhere, please let me know so I don't duplicate the work.

Metadata


Assignees

No one assigned

    Labels

    bug: Something isn't working

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
