Description
The gather conversion is fundamentally broken. For example, x[:, 0] converts correctly, but x[:, 1] does not (although this particular case could be fixed with a few extra lines of code).
In addition, most of the arguments passed to the gather operation are ignored, so the converter always assumes the simplest, most general form of gather.
As a result, most JAX slicing that gets lowered to lax.gather is not converted correctly.
Other examples:
This works: x[:, 0, :]
This does not: x[:, :, 0], which is actually compiled into the same operation as x[:, 0, :].
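For reference, here is a minimal NumPy sketch of the XLA gather semantics that a correct conversion has to reproduce. It is a simplified illustration under stated assumptions, not jax2onnx code: the function name gather_reference is hypothetical, index_vector_dim is assumed to be the last dimension of start_indices, offset_dims must be sorted, and batching dimensions plus the mode, indices_are_sorted and unique_indices parameters are ignored. Start indices are clamped so the slice stays in bounds, as the XLA gather specification describes. The dimension numbers in the usage lines are written by hand to express each slice; they are not necessarily the exact parameters JAX emits.

```python
import itertools

import numpy as np


def gather_reference(operand, start_indices, offset_dims,
                     collapsed_slice_dims, start_index_map, slice_sizes):
    """Simplified XLA-style gather (assumes index_vector_dim is the last dim)."""
    operand = np.asarray(operand)
    start_indices = np.asarray(start_indices)
    batch_shape = start_indices.shape[:-1]
    # Operand dims that are not collapsed become offset (window) dims in the output.
    window_dims = [d for d in range(operand.ndim) if d not in collapsed_slice_dims]
    out_rank = len(batch_shape) + len(offset_dims)
    # Output positions not listed in offset_dims are batch positions, in order.
    batch_positions = [i for i in range(out_rank) if i not in offset_dims]
    out_shape = [0] * out_rank
    for pos, size in zip(batch_positions, batch_shape):
        out_shape[pos] = size
    for pos, d in zip(offset_dims, window_dims):
        out_shape[pos] = slice_sizes[d]
    out = np.empty(out_shape, dtype=operand.dtype)
    for out_idx in itertools.product(*(range(n) for n in out_shape)):
        # Pick the index vector for this batch position.
        batch_idx = tuple(out_idx[p] for p in batch_positions)
        index_vector = start_indices[batch_idx]
        operand_idx = [0] * operand.ndim
        # start_index_map says which operand dim each index component addresses.
        for k, d in enumerate(start_index_map):
            upper = operand.shape[d] - slice_sizes[d]
            operand_idx[d] = int(np.clip(index_vector[k], 0, upper))
        # Add the offset within the slice; collapsed dims keep offset 0.
        for pos, d in zip(offset_dims, window_dims):
            operand_idx[d] += out_idx[pos]
        out[out_idx] = operand[tuple(operand_idx)]
    return out


x = np.arange(24).reshape(2, 3, 4)
# x[:, 0, :]: index and collapse axis 1, keep axes 0 and 2 as offset dims.
a = gather_reference(x, np.array([0]), offset_dims=(0, 1),
                     collapsed_slice_dims=(1,), start_index_map=(1,),
                     slice_sizes=(2, 1, 4))
# x[:, :, 0]: same offset_dims, but index and collapse axis 2 instead.
b = gather_reference(x, np.array([0]), offset_dims=(0, 1),
                     collapsed_slice_dims=(2,), start_index_map=(2,),
                     slice_sizes=(2, 3, 1))
assert np.array_equal(a, x[:, 0, :]) and np.array_equal(b, x[:, :, 0])
```

The two slices share the same offset_dims and output rank and differ only in collapsed_slice_dims, start_index_map and slice_sizes. A converter that ignores those arguments would therefore plausibly lower both to the same operation, which is consistent with the behavior reported above.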
We need jax2onnx for our production release soon(ish), so if time allows I will implement the correct behavior in gather.py myself. I am opening this issue so you know I have started working on it. If a good implementation already exists somewhere, please let me know so I don't duplicate the work.