diff --git a/tpu_inference/layers/jax/quantization/fp8.py b/tpu_inference/layers/jax/quantization/fp8.py index 246331c871..0990d864dc 100644 --- a/tpu_inference/layers/jax/quantization/fp8.py +++ b/tpu_inference/layers/jax/quantization/fp8.py @@ -549,12 +549,15 @@ def process_weights_after_loading(self, layer: JaxMoE) -> bool: shard_put(weights.w13_weight, shardings=layer.edf_sharding)) layer.kernel_down_proj_EFD = nnx.Param( shard_put(weights.w2_weight, shardings=layer.efd_sharding)) + # gmm expects shape [num_groups, num_blocks, 1, n] - # TODO(gpolovets1): Make sure it works for gmm_v2 as well. - edf_scale_sharding = (layer.edf_sharding[0], ) + (None, ) * ( - weights.w13_weight_scale.ndim - 2) + (layer.edf_sharding[-1], ) - efd_scale_sharding = (layer.efd_sharding[0], ) + (None, ) * ( - weights.w2_weight_scale.ndim - 2) + (layer.efd_sharding[-1], ) + edf_scale_sharding = (layer.edf_sharding[0], None, None, + layer.edf_sharding[2]) + w2_scale_tp_axis = layer.efd_sharding[ + 1] if weights.w2_weight_scale.shape[1] > 1 else None + efd_scale_sharding = (layer.efd_sharding[0], w2_scale_tp_axis, + None, None) + setattr( layer, f"kernel_gating_upproj_EDF_{self.weight_scale_name}", nnx.Param(