Skip to content

Conversation

@copybara-service
Copy link

Replace (Haiku internal) PRNGKey with jax.Array.

As noted in JAX docs [0] the correct type for PRNG keys in JAX is simply
jax.Array. In #857 there was some confusion that Haiku's internal type alias
might be causing typing issues for users code, so to make things simpler I
am removing the alias in favour of using jax.Array as recommended.

[0] https://docs.jax.dev/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

@copybara-service copybara-service bot force-pushed the test_852726459 branch 2 times, most recently from eb532bb to 335e9de Compare January 6, 2026 13:46
As noted in JAX docs [0] the correct type for PRNG keys in JAX is simply
`jax.Array`. In #857 there was some confusion that Haiku's internal type alias
might be causing typing issues for users code, so to make things simpler I
am removing the alias in favour of using `jax.Array` as recommended.

[0] https://docs.jax.dev/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

PiperOrigin-RevId: 852755180
@copybara-service copybara-service bot merged commit d45d6f5 into main Jan 6, 2026
@copybara-service copybara-service bot deleted the test_852726459 branch January 6, 2026 14:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant