fix(batch_norm): allow string cross_replica_axis for vmap support #859
Fixes `NameError` when using `hk.BatchNorm` with a named `jax.vmap` axis.

Problem
`hk.BatchNorm` previously assumed `cross_replica_axis` was always bound (as it is under `pmap`). When used with `vmap`, this raised `NameError: unbound axis name` during initialization, because the named axis does not exist in the init context.

Fix
Added a `try`/`except NameError` block around the `jax.lax.pmean` calls.
When the axis is unbound (e.g. during `init`), it now skips the sync and uses local stats.
Under `vmap`, it correctly syncs statistics across the batch dimension.

Verification
`init` no longer crashes, and `vmap` statistics match expected values (std ~1.0).
Tests: `haiku/_src/batch_norm_test.py`

Closes #812
Closes #822
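
For reviewers, the fallback described under Fix can be sketched in plain JAX as follows. This is a minimal illustration and not the code in this PR: `maybe_pmean`, `normalize`, the axis name `"batch"`, and the sample data are hypothetical, and the sketch assumes that `jax.lax.pmean` raises `NameError` when the named axis is unbound, as the error message above reports.

```python
import jax
import jax.numpy as jnp


def maybe_pmean(x, axis_name):
    """Average x over a named axis; return x unchanged if the axis is unbound."""
    try:
        return jax.lax.pmean(x, axis_name=axis_name)
    except NameError:
        # Assumption: an unbound axis name surfaces as NameError,
        # e.g. when called outside of vmap/pmap (as during init).
        return x


def normalize(x, axis_name="batch", eps=1e-5):
    """Normalize x with statistics synced over axis_name when it is bound."""
    mean = maybe_pmean(jnp.mean(x), axis_name)
    mean_sq = maybe_pmean(jnp.mean(jnp.square(x)), axis_name)
    var = mean_sq - jnp.square(mean)
    return (x - mean) * jax.lax.rsqrt(var + eps)


# Hypothetical data: 8 examples of 16 features, shifted and scaled.
data = jax.random.normal(jax.random.PRNGKey(0), (8, 16)) * 3.0 + 2.0

# No named axis is bound here, so the call falls back to local statistics.
local = normalize(data[0])

# Under vmap the "batch" axis is bound, so statistics are averaged over all rows.
synced = jax.vmap(normalize, axis_name="batch")(data)
```

With this pattern the same function can be called both outside any named-axis context and under `jax.vmap(..., axis_name=...)`. In the second case the output is normalized with statistics shared across the mapped axis, which is the behaviour the Verification note checks.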