Hi, I found a bug in the saturation adjustment inside the `pixelwise` function of `ColorJitter`.
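In line 275, the `saturation` variable is meant to be updated by a call to `F.adjust_brightness`, but the return value is never assigned back to it. Because JAX arrays are immutable, the function returns a new array instead of modifying its argument in place, so this line currently has no effect.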
```python
# Current code
F.adjust_brightness(saturation, amount, invert=invert)
# Likely intended code
saturation = F.adjust_brightness(saturation, amount, invert=invert)
```
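To make the failure mode concrete, here is a minimal, self-contained sketch of the same pattern. `scale_channel`, `buggy`, and `fixed` are hypothetical helpers written for illustration only; `scale_channel` is not the library's `F.adjust_brightness`, just a pure function that, like any JAX operation, returns a new array rather than mutating its input.

```python
import jax.numpy as jnp


def scale_channel(x, amount):
    # Hypothetical pure helper (not the library's implementation):
    # builds and returns a NEW array; `x` itself is never modified.
    return jnp.clip(x * (1.0 + amount), 0.0, 1.0)


def buggy(saturation, amount):
    # Mirrors the current code: the result is computed, then discarded.
    scale_channel(saturation, amount)
    return saturation


def fixed(saturation, amount):
    # Mirrors the intended code: rebind the name to the returned array.
    saturation = scale_channel(saturation, amount)
    return saturation


s = jnp.array([0.2, 0.4, 0.6])
print(buggy(s, 0.5))  # same values as s: the call had no effect
print(fixed(s, 0.5))  # each element scaled by 1.5
```

Since the discarded call raises no error or warning, the only visible symptom is that the augmentation silently does nothing.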
I verified this experimentally: with brightness, contrast, and hue set to 0 (so only the saturation jitter is active) and the probability `p` set to 1.0, the output images are identical to the input images, meaning the saturation jitter is never applied.
Given that `ColorJitter` is used for data augmentation when training the Pi-series VLA models (OpenPi), this bug might be silently affecting training performance.
