From 0cac2a48710ea6a258a80ba731050536708869f3 Mon Sep 17 00:00:00 2001 From: George Polovets Date: Fri, 27 Feb 2026 20:15:42 +0000 Subject: [PATCH] Updated sharding scales to align with GMM_TP settings. Signed-off-by: George Polovets --- tpu_inference/layers/jax/quantization/fp8.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tpu_inference/layers/jax/quantization/fp8.py b/tpu_inference/layers/jax/quantization/fp8.py index 246331c871..0990d864dc 100644 --- a/tpu_inference/layers/jax/quantization/fp8.py +++ b/tpu_inference/layers/jax/quantization/fp8.py @@ -549,12 +549,15 @@ def process_weights_after_loading(self, layer: JaxMoE) -> bool: shard_put(weights.w13_weight, shardings=layer.edf_sharding)) layer.kernel_down_proj_EFD = nnx.Param( shard_put(weights.w2_weight, shardings=layer.efd_sharding)) + # gmm expects shape [num_groups, num_blocks, 1, n] - # TODO(gpolovets1): Make sure it works for gmm_v2 as well. - edf_scale_sharding = (layer.edf_sharding[0], ) + (None, ) * ( - weights.w13_weight_scale.ndim - 2) + (layer.edf_sharding[-1], ) - efd_scale_sharding = (layer.efd_sharding[0], ) + (None, ) * ( - weights.w2_weight_scale.ndim - 2) + (layer.efd_sharding[-1], ) + edf_scale_sharding = (layer.edf_sharding[0], None, None, + layer.edf_sharding[2]) + w2_scale_tp_axis = layer.efd_sharding[ + 1] if weights.w2_weight_scale.shape[1] > 1 else None + efd_scale_sharding = (layer.efd_sharding[0], w2_scale_tp_axis, + None, None) + setattr( layer, f"kernel_gating_upproj_EDF_{self.weight_scale_name}", nnx.Param(