@@ -35,6 +35,28 @@ def _get_block_reshape_sharding(
3535 return NamedSharding (input_sharding .mesh , P (* blocked_spec ))
3636
3737
38+ def _get_safe_block_quant_input_sharding (
39+ tensor : jax .Array ,
40+ quantized_axes : list [int ],
41+ ) -> NamedSharding | None :
42+ """Drop sharding on axes that are about to be split into (num_blocks, block)."""
43+ input_sharding = getattr (tensor , "sharding" , None )
44+ if not isinstance (input_sharding , NamedSharding ):
45+ return None
46+
47+ adjusted_spec = list (input_sharding .spec )
48+ changed = False
49+ for axis_idx in quantized_axes :
50+ if axis_idx < len (adjusted_spec ) and adjusted_spec [axis_idx ] is not None :
51+ adjusted_spec [axis_idx ] = None
52+ changed = True
53+
54+ if not changed :
55+ return None
56+
57+ return NamedSharding (input_sharding .mesh , P (* adjusted_spec ))
58+
59+
3860def apply_linear_quantization (
3961 model_config : ModelConfig , model : nnx .Module , is_static_input : bool = False
4062) -> nnx .Module :
@@ -76,6 +98,11 @@ def apply_linear_quantization(
7698 # Accept both sglang-jax style and Qwix-style field names.
7799 weight_dtype_str = rule .get ("weight_dtype" , rule .get ("weight_qtype" ))
78100 activation_dtype_str = rule .get ("activation_dtype" , rule .get ("act_qtype" ))
101+ weight_block_size = (
102+ rule ["weight_block_size" ]
103+ if "weight_block_size" in rule
104+ else getattr (quant_config , "weight_block_size" , None )
105+ )
79106
80107 # Convert string dtypes to jnp dtypes
81108 weight_dtype = DTYPE_MAP .get (weight_dtype_str )
@@ -89,6 +116,7 @@ def apply_linear_quantization(
89116 "pattern" : pattern ,
90117 "weight_dtype" : weight_dtype ,
91118 "activation_dtype" : activation_dtype ,
119+ "weight_block_size" : weight_block_size ,
92120 }
93121 )
94122
@@ -140,7 +168,7 @@ def _replace_linear_recursive(obj, path: str = "", visited: set | None = None):
140168 weight_dtype = rule ["weight_dtype" ],
141169 activation_dtype = rule ["activation_dtype" ],
142170 is_static_input = is_static_input ,
143- weight_block_size = getattr ( quant_config , "weight_block_size" , None ) ,
171+ weight_block_size = rule [ "weight_block_size" ] ,
144172 )
145173 # Replace the attribute and free old weights
146174 setattr (obj , attr_name , quantized_linear )
@@ -323,6 +351,7 @@ def quantize_tensor(
323351 axis = [axis ]
324352
325353 orig_shape = tensor .shape
354+ original_input_sharding = getattr (tensor , "sharding" , None )
326355 mask = None
327356
328357 if block_size is not None :
@@ -356,14 +385,19 @@ def quantize_tensor(
356385
357386 orig_shape = tensor .shape
358387 # Convert all axis into positive values.
359- axis = sorted ([i % tensor .ndim for i in axis ])
388+ quantized_axes = sorted ([i % tensor .ndim for i in axis ])
389+ safe_input_sharding = _get_safe_block_quant_input_sharding (tensor , quantized_axes )
390+ if safe_input_sharding is not None :
391+ tensor = jax .sharding .reshard (tensor , safe_input_sharding )
392+ if mask is not None :
393+ mask = jax .sharding .reshard (mask , safe_input_sharding )
394+
360395 # Shift axis by 1 since its original position is now occupied by
361396 # num_blocks dim. Also, if n axes before an axis was also quantized,
362397 # shift its position by n.
363- axis = [1 + n + i for n , i in enumerate (axis )]
398+ axis = [1 + n + i for n , i in enumerate (quantized_axes )]
364399
365- input_sharding = getattr (tensor , "sharding" , None )
366- blocked_out_sharding = _get_block_reshape_sharding (tensor , axis )
400+ blocked_out_sharding = _get_block_reshape_sharding (tensor , quantized_axes )
367401
368402 # Flatten list of lists that contains (num_blocks, block).
369403 blocked_shape = list (itertools .chain (* blocked_shape ))
@@ -383,8 +417,8 @@ def quantize_tensor(
383417 # Guard all-zero blocks/tensors: scale==0 would produce 0/0 -> NaN.
384418 scale_safe = scale + (scale == 0 ).astype (scale .dtype )
385419 tensor_q = jnp .clip (tensor / scale_safe , dtype_min , dtype_max )
386- if block_size is not None and isinstance (input_sharding , NamedSharding ):
387- tensor_q = jax .lax .reshape (tensor_q , orig_shape , out_sharding = input_sharding )
420+ if block_size is not None and isinstance (original_input_sharding , NamedSharding ):
421+ tensor_q = jax .lax .reshape (tensor_q , orig_shape , out_sharding = original_input_sharding )
388422 else :
389423 tensor_q = tensor_q .reshape (orig_shape )
390424 tensor_q = tensor_q .astype (dtype )
0 commit comments