@@ -16,6 +16,7 @@
 
 import functools
 import json
+import qwix.pallas as qpl
 import re
 from typing import Tuple, Sequence, Callable
 from dataclasses import dataclass
@@ -629,13 +630,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
 def configure_quantization(config: Config, quant_mode_str: str = "train"):
   """Configure quantization based on user config and quant mode."""
   if config.use_batch_split_schedule and config.quantization:
-    if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
-      raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
-    return QwixQuantization(
-        weight_calibration_method=config.weight_quantization_calibration_method,
-        act_calibration_method=config.act_quantization_calibration_method,
-        bwd_calibration_method=config.bwd_quantization_calibration_method,
-    )
+    # The older version of batch-split that fully uses qwix quantization.
+    if config.quantization == "fp8_full" and not config.use_manual_quantization:
+      return QwixQuantization(
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
+      )
+    # The pure JAX version of batch-split that uses manual quantization.
+    return None
 
   if config.use_qwix_quantization:
     return None
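
A minimal sketch (not part of this commit) of how the two batch-split paths resolve; the config values here are hypothetical, and only the fields tested above matter:

    # Hypothetical Config values; field names match the checks above.
    config.use_batch_split_schedule = True
    config.quantization = "fp8_full"

    config.use_manual_quantization = False   # older path: qwix handles everything
    assert isinstance(configure_quantization(config), QwixQuantization)

    config.use_manual_quantization = True    # pure-JAX path: manual quantization
    assert configure_quantization(config) is None
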
@@ -764,8 +767,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int4,
             act_qtype=jnp.int4,
             bwd_qtype=jnp.int4,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -776,8 +778,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.int8,
             act_qtype=jnp.int8,
             bwd_qtype=jnp.int8,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -788,8 +789,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -802,8 +802,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -814,8 +813,7 @@ def get_quantization_rule(config: Config):
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
-            bwd_weight_grad_tile_size=1
-            / config.quantization_local_shard_count,
+            bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
             op_names=("dot_general",),
         )
     ]
@@ -851,6 +849,72 @@ def maybe_quantize_model(model, config):
   return model
 
 
+def _cast_reduced_from(arr, reduced_arr):
+  aval = jax.typeof(reduced_arr)
+  # Inside shard_map: replay the reduced axes one by one.
+  if aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual:
+    for axis in aval.mat.reduced:
+      arr = jax.lax.pcast(arr, axis, to="reduced")
+    return arr
+  # Outside shard_map: reshard to the reference array's sharding.
+  return jax.reshard(arr, aval.sharding)
+
+
+def _make_scale_tensor(scale, arr):
+  scale_tensor = jnp.full_like(arr, scale, dtype=jnp.bfloat16)
+  return _cast_reduced_from(scale_tensor, arr)
+
+
+def _get_max_min(target_dtype):
+  if target_dtype in (jnp.int4, jnp.int8):
+    return jnp.iinfo(target_dtype).max, jnp.iinfo(target_dtype).min
+  else:
+    return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)
+
+
+def manual_quantize(tensor, calibration_method):
+  """Manually quantizes a tensor based on a fixed calibration method.
+
+  Args:
+    tensor: The tensor to quantize.
+    calibration_method: A string specifying the calibration method. Expected
+      format is "fixed,{scale},{max_val}".
+
+  Returns:
+    A qwix.QArray containing the quantized value and the scale.
+
+  Raises:
+    ValueError: If calibration_method is None or has an unexpected format.
+  """
+  calib_method = calibration_method
+  if calib_method is None:
+    raise ValueError("calibration_method cannot be None for manual quantization")
+  if not calib_method.startswith("fixed"):
+    raise ValueError(f"Only static weight/activation quantization is supported, but got {calib_method}")
+
+  parts = calib_method.split(",")
+  if len(parts) != 3:
+    raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")
+
+  fwd_dtype = jnp.float8_e4m3fn
+  dtype_max, dtype_min = _get_max_min(fwd_dtype)
+  max_val = float(parts[2])
+  scale = max_val / dtype_max
+  scale = jnp.where(scale == 0, 1.0, scale)
+  # The scale must be materialized as a tensor because the gradient carries reduced axes.
+  scale_tensor = _make_scale_tensor(scale, tensor)
+  min_bound = _make_scale_tensor(dtype_min, tensor)
+  max_bound = _make_scale_tensor(dtype_max, tensor)
+  q_tensor = jnp.clip(tensor / scale_tensor, min_bound, max_bound).astype(fwd_dtype)
+
+  # Build the per-tensor scale expected by QArray.
+  scale_shape = [1] * tensor.ndim
+  # It must stay fully replicated for the backward pass and Pallas.
+  scale_tensor_qpl = jnp.full(scale_shape, scale, dtype=tensor.dtype)
+  # Wrap the quantized values and the scale in a QArray.
+  return qpl.QArray(qvalue=q_tensor, scale=scale_tensor_qpl)
+
+
 class TransformerEngineQuantization(Quantization):
   """Class for TransformerEngine quantization recipes."""
 
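
A minimal usage sketch of the new helpers (not part of this commit). It assumes the input already carries an explicit sharding with at least one mesh axis, as it would inside the batch-split schedule, and dequantizing by multiplying `qvalue` back with `scale` is an illustration rather than qwix's own API:

    import jax.numpy as jnp

    print(_get_max_min(jnp.int8))           # (127, -128)
    print(_get_max_min(jnp.float8_e4m3fn))  # (448.0, -448.0), as bfloat16

    x = jnp.ones((4, 4), jnp.bfloat16)  # placeholder; real inputs arrive sharded
    # float8_e4m3fn has max 448, so "fixed,1.0,448.0" yields scale = 448 / 448 = 1.0.
    # Note only the {max_val} field is used; the middle {scale} field is accepted but ignored.
    q = manual_quantize(x, "fixed,1.0,448.0")
    x_hat = q.qvalue.astype(jnp.bfloat16) * q.scale.astype(jnp.bfloat16)
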