55
66import jax
77import jax .numpy as jnp
8- from jax ._src import dtypes
98
109from .tuned_block_sizes import TunedValue
1110
1211
12+ def _dtype_bits (dtype : jnp .dtype ) -> int :
13+ return jnp .dtype (dtype ).itemsize * 8
14+
15+
1316def unfold_args (
1417 conditions : tuple [jax .Array | bool , ...],
1518 fn_conditions : tuple [bool , ...],
@@ -191,7 +194,8 @@ def quantize_array(
191194
192195 # TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
193196 scale = jnp .transpose (x_abs_max / dtype_max )
194- scale_inv = jnp .nan_to_num (1 / scale , dtype_max )
197+ scale = jnp .where (scale == 0 , 1.0 , scale )
198+ scale_inv = jnp .nan_to_num (1 / scale , nan = dtype_max , posinf = dtype_max , neginf = - dtype_max )
195199 return (x * scale_inv ).astype (quant_dtype ), scale .astype (jnp .float32 )
196200
197201
@@ -215,13 +219,11 @@ def get_vmem_limit(
215219 """Calculate VMEM limit for the kernel."""
216220
217221 # Calculate in/out VMEM size.
218- x_size = (batch_block_size * in_block_size * dtypes .itemsize_bits (x_dtype ))
219- x_abs_max_size = batch_block_size * dtypes .itemsize_bits (scale_dtype )
220- w_q_size = (out_block_size * in_block_size *
221- dtypes .itemsize_bits (w_q_dtype ))
222- w_scale_size = out_block_size * dtypes .itemsize_bits (scale_dtype )
223- out_size = (batch_block_size * out_block_size *
224- dtypes .itemsize_bits (out_dtype ))
222+ x_size = batch_block_size * in_block_size * _dtype_bits (x_dtype )
223+ x_abs_max_size = batch_block_size * _dtype_bits (scale_dtype )
224+ w_q_size = out_block_size * in_block_size * _dtype_bits (w_q_dtype )
225+ w_scale_size = out_block_size * _dtype_bits (scale_dtype )
226+ out_size = batch_block_size * out_block_size * _dtype_bits (out_dtype )
225227
226228 vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
227229 vmem_in_out *= 2 # Account for compute and vreg spills.
@@ -235,11 +237,9 @@ def get_vmem_limit(
235237 vmem_in_out += out_size if (n_batch > 1 or n_out > 1 ) else 0
236238
237239 # Calculate scratch VMEM size.
238- acc_size = (batch_block_size * out_block_size *
239- dtypes .itemsize_bits (acc_dtype ))
240- x_q_size = (batch_block_size * in_block_size *
241- dtypes .itemsize_bits (x_q_dtype ))
242- x_scale_size = batch_block_size * dtypes .itemsize_bits (scale_dtype )
240+ acc_size = batch_block_size * out_block_size * _dtype_bits (acc_dtype )
241+ x_q_size = batch_block_size * in_block_size * _dtype_bits (x_q_dtype )
242+ x_scale_size = batch_block_size * _dtype_bits (scale_dtype )
243243
244244 vmem_scratch = acc_size if save_acc else 0
245245 vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
@@ -277,10 +277,14 @@ def validate_inputs(
277277 # Verify input shapes.
278278 if x .shape [1 ] != w_q .shape [1 ]:
279279 raise ValueError (f'{ x .shape [1 ]= } must be equal to { w_q .shape [1 ]= } ' )
280- if w_q .shape [0 ] != w_scale .shape [1 ] and (w_scale .ndim == 3 and w_q .shape [0 ]
281- != w_scale .shape [2 ]):
282- raise ValueError (
283- f"{ w_q .shape [0 ]= } must be equal to { w_scale .shape [1 ]= } " )
280+ if w_scale .ndim == 2 :
281+ if w_q .shape [0 ] != w_scale .shape [1 ]:
282+ raise ValueError (f"{ w_q .shape [0 ]= } must be equal to { w_scale .shape [1 ]= } " )
283+ elif w_scale .ndim == 3 :
284+ if w_q .shape [0 ] != w_scale .shape [2 ]:
285+ raise ValueError (f"{ w_q .shape [0 ]= } must be equal to { w_scale .shape [2 ]= } " )
286+ else :
287+ raise ValueError (f"Unsupported { w_scale .ndim = } for quantized weight scale." )
284288 if x_abs_max is not None and x_abs_max .shape != (1 , x .shape [0 ]):
285289 raise ValueError (
286290 f"{ x_abs_max .shape = } must be equal to (1, { x .shape [0 ]= } )" )
@@ -317,5 +321,5 @@ def quantize_block(data, axis, target_dtype):
317321 if jnp .issubdtype (target_dtype , jnp .floating ):
318322 data_q = (data / scale ).clip (dtype_min , dtype_max ).astype (target_dtype )
319323 else :
320- data_q = jnp .round (data / scale ).astype (target_dtype )
324+ data_q = jnp .clip ( jnp . round (data / scale ), dtype_min , dtype_max ).astype (target_dtype )
321325 return data_q , scale
0 commit comments