diff --git a/haiku/_src/batch_norm.py b/haiku/_src/batch_norm.py index 4a92fc8e5..460c914e2 100644 --- a/haiku/_src/batch_norm.py +++ b/haiku/_src/batch_norm.py @@ -170,14 +170,19 @@ def __call__( mean = jnp.mean(inputs, axis, keepdims=True) mean_of_squares = jnp.mean(jnp.square(inputs), axis, keepdims=True) if self.cross_replica_axis: - mean = jax.lax.pmean( - mean, - axis_name=self.cross_replica_axis, - axis_index_groups=self.cross_replica_axis_index_groups) - mean_of_squares = jax.lax.pmean( - mean_of_squares, - axis_name=self.cross_replica_axis, - axis_index_groups=self.cross_replica_axis_index_groups) + try: + mean = jax.lax.pmean( + mean, + axis_name=self.cross_replica_axis, + axis_index_groups=self.cross_replica_axis_index_groups) + mean_of_squares = jax.lax.pmean( + mean_of_squares, + axis_name=self.cross_replica_axis, + axis_index_groups=self.cross_replica_axis_index_groups) + except NameError: + # If the axis is not bound (e.g. during init or non-mapped execution), + # we skip the sync and fall back to local statistics. + pass var = mean_of_squares - jnp.square(mean) else: mean = self.mean_ema.average.astype(inputs.dtype)