Same as https://github.com/google-deepmind/jax_privacy/issues/86, but let's use Flax NNX instead of raw JAX.