Commit 7392aec

committed
fix: route per-channel FP8 MoE to CompressedTensorsFp8MoEMethod
Per-channel (per_Token) FP8 quantization needs per-channel weight scale allocation [E, N, 1], which CompressedTensorsFp8MoEMethod provides; Fp8MoEMethod only allocates scalar-per-expert scales [E, 2]/[E].

- Add a dispatch case for quant_dtype == fp8 + quant_type == per_Token that uses CompressedTensorsFp8MoEMethod
- Fix _load_per_channel_weight_scale to unsqueeze 1D checkpoint scales to match the 2D [N, 1] buffer shape
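The shape mismatch the second bullet fixes can be reproduced in isolation. This is a minimal sketch using numpy arrays as stand-ins for the torch tensors in moe.py; the buffer/variable names here are hypothetical, not the real ones:

```python
import numpy as np

N = 4  # output channels of one expert's weight (illustrative size)
expert_scale_buf = np.zeros((N, 1))            # per-channel scale buffer, shape [N, 1]
checkpoint_scale = np.arange(N, dtype=float)   # checkpoint stores scales flat, shape [N]

# Assigning a [N] array into a [N, 1] buffer broadcasts [N] as [1, N] and
# fails for N > 1; the fix mirrors the diff: add the trailing dim first.
if checkpoint_scale.ndim == 1 and expert_scale_buf.ndim == 2:
    checkpoint_scale = checkpoint_scale[:, None]  # [N] -> [N, 1], like unsqueeze(-1)

expert_scale_buf[:] = checkpoint_scale
```

After the reshape, the copy is an exact element-for-element match instead of a broadcasting error.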
1 parent 097b7a8 commit 7392aec

File tree

1 file changed

+11
-0
lines changed

atom/model_ops/moe.py

Lines changed: 11 additions & 0 deletions
@@ -1984,6 +1984,13 @@ def __init__(
         ):
             # Use CompressedTensorsFp8MoEMethod for compressed-tensors format
             self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe)
+        elif (
+            quant_config["quant_dtype"] == dtypes.fp8
+            and quant_config["quant_type"] == QuantType.per_Token
+        ):
+            # Per-channel FP8 (e.g., Quark per_Token override for MTP layers)
+            # needs CompressedTensors-style weight scale handling
+            self.quant_method = CompressedTensorsFp8MoEMethod(quant_config, moe)
         elif quant_config["quant_dtype"] == dtypes.fp8:
             self.quant_method = Fp8MoEMethod(quant_config, moe)
         elif quant_config["quant_dtype"] == dtypes.fp4x2:

@@ -2100,6 +2107,10 @@ def _load_per_channel_weight_scale(
         tp_rank: int,
     ):
         # for per channel weight quantization
+        # When scales are stored as [N, 1] (CompressedTensors per-channel)
+        # but loaded from checkpoint as [N], reshape to match.
+        if loaded_weight.dim() == 1 and expert_data.dim() == 2:
+            loaded_weight = loaded_weight.unsqueeze(-1)
         if shard_id == "w2":
             expert_data.copy_(loaded_weight)
         elif shard_id in ("w1", "w3"):
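The first hunk's ordering is the essential part: the new per_Token branch must precede the generic fp8 branch, or per-channel configs would fall through to Fp8MoEMethod. A minimal self-contained sketch of that dispatch order, with QuantType, dtypes, and the method names mocked as plain stand-ins for the real classes in moe.py:

```python
from enum import Enum, auto

class QuantType(Enum):
    per_Token = auto()
    per_Tensor = auto()

class dtypes:
    fp8 = "fp8"
    fp4x2 = "fp4x2"

def select_moe_method(quant_config: dict) -> str:
    # Order matters: the per_Token check must come before the generic fp8
    # branch, otherwise per-channel FP8 would be routed to Fp8MoEMethod,
    # whose scalar-per-expert buffers [E, 2]/[E] cannot hold [E, N, 1] scales.
    if (quant_config["quant_dtype"] == dtypes.fp8
            and quant_config["quant_type"] == QuantType.per_Token):
        return "CompressedTensorsFp8MoEMethod"
    elif quant_config["quant_dtype"] == dtypes.fp8:
        return "Fp8MoEMethod"
    elif quant_config["quant_dtype"] == dtypes.fp4x2:
        return "Fp4MoEMethod"  # hypothetical name for the fp4x2 branch
    raise ValueError("unsupported quant config")
```

With the branches in this order, an fp8 + per_Token config selects the compressed-tensors method while fp8 + per_Tensor still reaches Fp8MoEMethod.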
