Commit 5992420

FP8 Megablox for batch split
1 parent 2d739e9 commit 5992420

7 files changed

Lines changed: 220 additions & 61 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@ save_quantized_params_path: ""
 # accepted values are "inference"
 model_call_mode: ""
 use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
+use_manual_quantization: false # whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.
 # quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
 weight_quantization_calibration_method: "absmax"
 act_quantization_calibration_method: "absmax"
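
For orientation, a hedged sketch of how these options might be combined for the manual-quantization batch-split path. Only use_manual_quantization is added by this commit; the other keys already exist in the config, and the specific "fixed,..." calibration values are illustrative, not taken from the commit.

use_batch_split_schedule: true   # the new flag only has an effect when batch split is on
quantization: "fp8_full"         # the only quantization the batch-split config check accepts
use_manual_quantization: true    # new in this commit: manual fp8 quantization instead of qwix-only
weight_quantization_calibration_method: "fixed,1.0,448.0"  # illustrative; manual_quantize expects a "fixed,{scale},{max_val}" string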

src/maxtext/configs/types.py

Lines changed: 6 additions & 6 deletions
@@ -422,6 +422,10 @@ class Quantization(BaseModel):
   kv_quant_dtype: Literal["int8", "int4"] = Field("int8", description="Data type for KV cache quantization.")
   quantization_local_shard_count: int = Field(-1, description="Shards the range finding operation for quantization.")
   use_qwix_quantization: bool = Field(False, description="Whether to use qwix for quantization.")
+  use_manual_quantization: bool = Field(
+      False,
+      description="Whether to use manual quantization for batch split. Only used if use_batch_split_schedule is True.",
+  )
   weight_quantization_calibration_method: str = Field(
       "absmax",
       description="Quantization calibration method used for weights.",

@@ -2727,8 +2731,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "
           f"Supported options are: {list(supported_decoders)}."
       )
-      if self.quantization:
-        raise ValueError("Quantization is not supported with 'explicit' sharding.")
     if self.context_sharding not in ("context", "expert"):
       raise ValueError(f"Assigned context_sharding f{self.context_sharding} is not supported.")
     if (

@@ -2835,10 +2837,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.use_grpo = False

     if self.use_batch_split_schedule:
-      if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
-        raise ValueError(
-            "Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
-        )
+      if self.quantization and not self.quantization == "fp8_full":
+        raise ValueError("Batch split quantization only supports `quantization=fp8_full`")

     if self.opt_type == "muon" and self.decoder_block not in [
         DecoderBlockType.DEEPSEEK,
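
A minimal, hedged sketch of what the relaxed check now enforces: batch split still requires quantization="fp8_full", but no longer insists on use_qwix_quantization=True, which is what lets the manual-quantization path through. This is a standalone approximation for illustration, not the Pydantic validator itself.

def check_batch_split_quantization(use_batch_split_schedule: bool, quantization: str | None) -> None:
  # Mirrors the new condition: only the quantization recipe is constrained.
  if use_batch_split_schedule and quantization and quantization != "fp8_full":
    raise ValueError("Batch split quantization only supports `quantization=fp8_full`")

check_batch_split_quantization(True, "fp8_full")  # passes (qwix or manual fp8)
check_batch_split_quantization(True, None)        # passes (unquantized bf16 path)
# check_batch_split_quantization(True, "int8")    # would raise ValueError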

src/maxtext/kernels/megablox/ops.py

Lines changed: 19 additions & 11 deletions
@@ -22,6 +22,7 @@
 import jax
 import jax.numpy as jnp
 from maxtext.kernels.megablox import backend
+from maxtext.layers import quantizations
 import qwix
 import qwix.pallas as qpl
 import tokamax

@@ -61,6 +62,7 @@ def gmm(
     weight_gather_axes: List[Tuple[str, int]] | None = None,
     # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
     qwix_rule: qwix.QtRule | None = None,
+    use_manual_quantization: bool = False,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None

@@ -80,7 +82,7 @@
   )

   gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0]  # pylint: disable=C3001
-  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
+  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12))
   gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
   return gmm_fwd_bwd(
       lhs,

@@ -95,6 +97,7 @@
       quantization_rule,
       use_tokamax_backend,
       weight_gather_axes,
+      use_manual_quantization,
   )


@@ -121,6 +124,7 @@ def _gmm_fwd(
     quantization_rule: qwix.QtRule | None = None,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
+    use_manual_quantization: bool = False,
 ) -> tuple[
     jnp.ndarray,
     tuple[

@@ -140,15 +144,18 @@
         calibration_method=quantization_rule.act_calibration_method,
     )
   if quantization_rule.weight_qtype and not isinstance(rhs, qpl.QArray):
-    rhs = qpl.quantize(
-        rhs,
-        quantization_rule.weight_qtype,
-        # If only considering the fwd pass, we could also enable channelwise
-        # axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
-        # bwd pass unable to reuse the scale easily.
-        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2]),
-        calibration_method=quantization_rule.weight_calibration_method,
-    )
+    if not use_manual_quantization:
+      rhs = qpl.quantize(
+          rhs,
+          quantization_rule.weight_qtype,
+          # If only considering the fwd pass, we could also enable channelwise
+          # axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
+          # bwd pass unable to reuse the scale easily.
+          channelwise_axes=([] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2])),
+          calibration_method=quantization_rule.weight_calibration_method,
+      )
+    else:
+      rhs = quantizations.manual_quantize(rhs, quantization_rule.weight_calibration_method)
   # QAG is only supported for following conditions
   if use_tokamax_backend:
     if quantization_rule and quantization_rule.bwd_qtype:

@@ -195,6 +202,7 @@ def _gmm_bwd(
     quantization_rule: qwix.QtRule | None,
     use_tokamax_backend: bool,
     weight_gather_axes: List[Tuple[str, int]] | None,
+    use_manual_quantization: bool,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,

@@ -204,7 +212,7 @@
     grad: jnp.ndarray,
 ) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
   """Backward function for throughput GMM VJP."""
-  del preferred_element_type
+  del preferred_element_type, use_manual_quantization
   lhs, rhs, group_sizes, group_offset = residual
   num_actual_groups = rhs.shape[0]
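
The nondiff_argnums bump from (3, 4, 7, 8, 9, 10, 11) to (..., 12) follows from how jax.custom_vjp handles static arguments: use_manual_quantization becomes the thirteenth positional argument of the forward call (index 12), so it must be marked non-differentiable and is then handed to _gmm_bwd ahead of the residuals, where it is immediately deleted. A self-contained toy example of that mechanism, unrelated to the MaxText code itself:

import jax
import jax.numpy as jnp

def scale_fn(x, use_double):  # use_double plays the role of use_manual_quantization
  return x * 2.0 if use_double else x * 3.0

# Argument index 1 is a static flag, so it goes in nondiff_argnums.
scale_vjp = jax.custom_vjp(scale_fn, nondiff_argnums=(1,))

def scale_fwd(x, use_double):
  return scale_fn(x, use_double), None  # no residuals needed here

def scale_bwd(use_double, residuals, g):  # static args arrive first in the bwd rule
  del residuals
  return (g * (2.0 if use_double else 3.0),)

scale_vjp.defvjp(scale_fwd, scale_bwd)
print(jax.grad(lambda x: scale_vjp(x, True))(jnp.float32(1.0)))  # 2.0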

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 2 deletions
@@ -922,7 +922,8 @@ def __call__(
     # as detected by immutable params, use deepseek_batchsplit custom
     # scan with initialized parameters.
     if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
-      if cfg.use_qwix_quantization:
+      # The older version of batch-split that fully uses qwix quantization.
+      if cfg.use_qwix_quantization and not cfg.use_manual_quantization:
         y = deepseek_batchsplit_fp8.scan_batch_split_layers(
             y,
             self.variables["params"]["moe_layers"],

@@ -935,7 +936,9 @@
             policy=policy,
         )
       else:
-        # bf16 code path
+        # bf16 and fp8 code path for pure-JAX batch-split.
+        # The fp8 code path supports both manual quantization and qwix
+        # quantization.
         y = deepseek_batchsplit.scan_batch_split_layers(
             y,
             self.variables["params"]["moe_layers"],
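
Put together, once batch split is active the branch above selects the scan implementation from the two quantization flags alone. A hedged summary of that dispatch as a standalone helper (illustrative only, not a function in the commit):

def batch_split_impl(use_qwix_quantization: bool, use_manual_quantization: bool) -> str:
  # Qwix-only fp8 batch-split keeps the dedicated deepseek_batchsplit_fp8 scan.
  if use_qwix_quantization and not use_manual_quantization:
    return "deepseek_batchsplit_fp8.scan_batch_split_layers"
  # Everything else (bf16, or fp8 via manual quantization) uses the pure-JAX scan.
  return "deepseek_batchsplit.scan_batch_split_layers"

assert batch_split_impl(True, False) == "deepseek_batchsplit_fp8.scan_batch_split_layers"
assert batch_split_impl(True, True) == "deepseek_batchsplit.scan_batch_split_layers"
assert batch_split_impl(False, False) == "deepseek_batchsplit.scan_batch_split_layers"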

src/maxtext/layers/quantizations.py

Lines changed: 81 additions & 17 deletions
@@ -16,6 +16,7 @@
 
 import functools
 import json
+import qwix.pallas as qpl
 import re
 from typing import Tuple, Sequence, Callable
 from dataclasses import dataclass

@@ -629,13 +630,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
 def configure_quantization(config: Config, quant_mode_str: str = "train"):
   """Configure quantization based on user config and quant mode."""
   if config.use_batch_split_schedule and config.quantization:
-    if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
-      raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
-    return QwixQuantization(
-        weight_calibration_method=config.weight_quantization_calibration_method,
-        act_calibration_method=config.act_quantization_calibration_method,
-        bwd_calibration_method=config.bwd_quantization_calibration_method,
-    )
+    # The older version of batch-split that fully uses qwix quantization.
+    if config.quantization == "fp8_full" and not config.use_manual_quantization:
+      return QwixQuantization(
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
+      )
+    # The pure JAX version of batch-split that uses manual quantization.
+    return None

   if config.use_qwix_quantization:
     return None

@@ -764,8 +767,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int4,
             act_qtype=jnp.int4,
             bwd_qtype=jnp.int4,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]

@@ -776,8 +778,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int8,
             act_qtype=jnp.int8,
             bwd_qtype=jnp.int8,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]

@@ -788,8 +789,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]

@@ -802,8 +802,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]

@@ -814,8 +813,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]

@@ -851,6 +849,72 @@ def maybe_quantize_model(model, config):
   return model


+def _cast_reduced_from(arr, reduced_arr):
+  aval = jax.typeof(reduced_arr)
+  # In shard map
+  if aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual:
+    for axis in aval.mat.reduced:
+      arr = jax.lax.pcast(arr, axis, to="reduced")
+    return arr
+  # Outside shard map
+  return jax.reshard(arr, aval.sharding)
+
+
+def _make_scale_tensor(scale, arr):
+  scale_tensor = jnp.full_like(arr, scale, dtype=jnp.bfloat16)
+  return _cast_reduced_from(scale_tensor, arr)
+
+
+def _get_max_min(target_dtype):
+  if target_dtype in (jnp.int4, jnp.int8):
+    return jnp.iinfo(target_dtype).max, jnp.iinfo(target_dtype).min
+  else:
+    return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)
+
+
+def manual_quantize(tensor, calibration_method):
+  """Manually quantizes a tensor based on a fixed calibration method.
+
+  Args:
+    tensor: The tensor to quantize.
+    calibration_method: A string specifying the calibration method. Expected
+      format is "fixed,{scale},{max_val}".
+
+  Returns:
+    A qwix.QArray containing the quantized value and the scale.
+
+  Raises:
+    ValueError: If calibration_method is None or has an unexpected format.
+  """
+  calib_method = calibration_method
+  if calib_method is None:
+    raise ValueError("calibration_method cannot be None for manual quantization")
+  if not calib_method.startswith("fixed"):
+    raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}")
+
+  parts = calib_method.split(",")
+  if len(parts) != 3:
+    raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")
+
+  fwd_dtype = jnp.float8_e4m3fn
+  dtype_max, dtype_min = _get_max_min(fwd_dtype)
+  max_val = float(parts[2])
+  scale = max_val / dtype_max
+  scale = jnp.where(scale == 0, 1.0, scale)
+  # scale must be converted to a tensor because grad has reduced axes.
+  scale_tensor = _make_scale_tensor(scale, tensor)
+  min_bound = _make_scale_tensor(dtype_min, tensor)
+  max_bound = _make_scale_tensor(dtype_max, tensor)
+  q_tensor = jnp.clip(tensor / scale_tensor, min_bound, max_bound).astype(fwd_dtype)
+
+  # get scale for QArray
+  scale_shape = [1] * tensor.ndim
+  # It must stay fully replicated for the backward pass and Pallas.
+  scale_tensor_qpl = jnp.full(scale_shape, scale, dtype=tensor.dtype)
+  # wrap in QArray
+  return qpl.QArray(qvalue=q_tensor, scale=scale_tensor_qpl)
+
+
 class TransformerEngineQuantization(Quantization):
   """Class for TransformerEngine quantization recipes."""
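
A hedged worked example of the "fixed" calibration string that manual_quantize expects. The values are illustrative: for "fixed,1.0,448.0" the third field is max_val, and since the representable maximum of float8_e4m3fn is 448, the resulting scale is 448.0 / 448 = 1.0, so values are clipped to [-448, 448], divided by the scale, cast to fp8, and wrapped in a qpl.QArray carrying a single fully replicated scale.

import jax.numpy as jnp

# Replays the scale arithmetic from manual_quantize for an illustrative
# calibration string; the string itself is an assumption, not from the commit.
calibration_method = "fixed,1.0,448.0"
parts = calibration_method.split(",")        # ["fixed", "1.0", "448.0"]
fwd_dtype = jnp.float8_e4m3fn
dtype_max = float(jnp.finfo(fwd_dtype).max)  # 448.0
max_val = float(parts[2])                    # 448.0
scale = max_val / dtype_max                  # 1.0
print(dtype_max, scale)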

src/maxtext/models/deepseek.py

Lines changed: 5 additions & 2 deletions
@@ -451,7 +451,8 @@ def __call__(
     # That is also why we can split/merge activations here as well as
     # in `Decoder`, since they will never be executed together.
     if self.config.use_batch_split_schedule:
-      if self.config.use_qwix_quantization:
+      # The older version of batch-split that fully uses qwix quantization.
+      if self.config.use_qwix_quantization and not self.config.use_manual_quantization:
         activation_pspec = jax.sharding.PartitionSpec(
             ("data", "fsdp", "fsdp_transpose", "expert", "context"),
             None,

@@ -490,7 +491,9 @@
       )(outputs)
       return outputs, None

-    # bf16 code path
+    # bf16 and fp8 code path for pure-JAX batch-split.
+    # The fp8 code path supports both manual quantization and qwix
+    # quantization.
     input_sharding = jax.typeof(inputs).sharding
     activation_pspec = jax.sharding.PartitionSpec(
         ("data", "fsdp", "expert"),