1010from jax .experimental import pallas as pl
1111from jax .experimental .pallas import tpu as pltpu
1212
13+ # Util.
14+
15+
16+ def align_to (x , a ):
17+ return pl .cdiv (x , a ) * a
18+
19+
1320# Define data classes.
1421
1522
@@ -182,25 +189,30 @@ def inner_kernel(
182189
183190 Args:
184191 tiled_lhs_ref: Contains value lhs[m_start:m_end, k_start:k_end]
185- tiled_rhs_ref: Contains value rhs[g_id, k_start:k_end, n_start:n_end].
186- where g_id is the group associated with lhs[m_start:m_end, :]
192+ tiled_rhs_ref: Contains value rhs[g_id, k_start:k_end, n_start:n_end]. where
193+ g_id is the group associated with lhs[m_start:m_end, :]
187194 tiled_out_ref: Contains value out[m_start:m_end, n_start:n_end]
188- partial_out_ref: Contains last size_lhs_sublane rows of the previous
189- output. If this is the first tile for grid[n_id, :, :], it will be
190- initialized to zeros.
195+ partial_out_ref: Contains last size_lhs_sublane rows of the previous output.
196+ Will be initialized to zero if this is first tile for grid[n_id, :, :].
191197 acc_ref: Reference to the accumulator.
192198 metadata_ref: Reference to the metadata.
193199 cfgs: GmmConfigs.
194200 """
195201
196202 def _matmul (is_first_k_step : bool , is_last_k_step : bool ):
197203 tiled_lhs = tiled_lhs_ref .reshape (- 1 , cfgs .tiles .tile_k )[...]
204+ tiled_rhs = tiled_rhs_ref .weight [...]
205+
206+ valid_k = cfgs .dims .size_k % cfgs .tiles .tile_k
207+ if is_last_k_step and valid_k != 0 :
208+ mask_rhs = lax .broadcasted_iota (jnp .int32 , tiled_rhs .shape , 0 ) < valid_k
209+ tiled_rhs = jnp .where (mask_rhs , tiled_rhs , 0 )
198210
199211 if cfgs .lhs_cfgs .quant_dtype is None :
200212 # Unquantized matmul path.
201213 acc = jnp .matmul (
202214 tiled_lhs ,
203- tiled_rhs_ref . weight [...] ,
215+ tiled_rhs ,
204216 preferred_element_type = jnp .float32 ,
205217 ).astype (acc_ref .dtype )
206218 else :
@@ -232,7 +244,7 @@ def _matmul(is_first_k_step: bool, is_last_k_step: bool):
232244 end_k = min (cfgs .tiles .tile_k , start_k + q_block_size )
233245
234246 block_lhs = tiled_lhs [:, start_k :end_k ]
235- block_rhs = tiled_rhs_ref . weight [start_k :end_k , start_n :end_n ]
247+ block_rhs = tiled_rhs [start_k :end_k , start_n :end_n ]
236248
237249 # Perform lhs quantization. Note that for every block_lhs,
238250 # same computation will be performed tiles_n//mxu_size times.
@@ -303,8 +315,8 @@ def _matmul(is_first_k_step: bool, is_last_k_step: bool):
303315 # fill size_lhs_sublane rows and will be revisited at the next step. By
304316 # storing the partial rows into the partial_out_ref, the next step can
305317 # read them and accumulate to them. Additionally, for group id of 2,
306- # since it completely fills the size_lhs_sublane rows, we need to
307- # initialize the partial_out_ref to zeros .
318+ # since it completely fills the size_lhs_sublane rows, we need to zero out
319+ # partial_out_ref to avoid numeric error for group 3 .
308320 last_row = m_end_local // cfgs .dims .size_lhs_sublane
309321 partial_out_ref [...] = jnp .where (
310322 m_end_local % cfgs .dims .size_lhs_sublane == 0 ,
@@ -370,8 +382,8 @@ def fill_metadata(
370382 Args:
371383 lhs_group_sizes_ref: The group sizes of lhs.
372384 group_offset_ref: Offset of the first group to process.
373- metadata_ref: Metadata that is used to determine the group id and m
374- offsets for each gmm tile.
385+ metadata_ref: Metadata that is used to determine the group id and m offsets
386+ for each gmm tile.
375387 cfgs: GmmConfigs.
376388
377389 Returns:
@@ -383,15 +395,15 @@ def fill_metadata(
383395 metadata_ref .gm_id_to_m_offset [0 ] = 0
384396
385397 @jax .named_scope ("inner_tm_loop" )
386- def inner_tm_loop (tm_id , curr_m_offset , * , end_m_offset , group_id , num_gm ):
398+ def inner_tm_loop (tm_id , curr_m_offset , * , end_m_offset , group_id ):
387399 local_offset = curr_m_offset % cfgs .dims .size_lhs_sublane
388400 tm_size = jnp .minimum (cfgs .tiles .tile_m - local_offset , end_m_offset - curr_m_offset )
389401
390- metadata_ref .gm_id_to_group_id [num_gm + tm_id ] = group_id
402+ metadata_ref .gm_id_to_group_id [tm_id ] = group_id
391403
392404 next_m_offset = curr_m_offset + tm_size
393- metadata_ref .gm_id_to_m_offset [num_gm + tm_id ] = curr_m_offset
394- metadata_ref .gm_id_to_m_offset [num_gm + tm_id + 1 ] = next_m_offset
405+ metadata_ref .gm_id_to_m_offset [tm_id ] = curr_m_offset
406+ metadata_ref .gm_id_to_m_offset [tm_id + 1 ] = next_m_offset
395407
396408 return next_m_offset
397409
@@ -426,16 +438,16 @@ def outer_group_loop(lhs_group_id, carry):
426438 # 2. If group comes before the group_offset, we should not process it.
427439 should_process = jnp .logical_and (group_size > 0 , group_id >= 0 )
428440 curr_num_gm = jnp .where (should_process , curr_num_gm , 0 )
441+ next_num_gm = num_gm + curr_num_gm
429442
430443 tm_loop_fn = functools .partial (
431444 inner_tm_loop ,
432445 end_m_offset = end_m_offset ,
433446 group_id = group_id ,
434- num_gm = num_gm ,
435447 )
436- lax .fori_loop (0 , curr_num_gm , tm_loop_fn , start_m_offset )
448+ lax .fori_loop (num_gm , next_num_gm , tm_loop_fn , start_m_offset )
437449
438- return num_gm + curr_num_gm , end_m_offset
450+ return next_num_gm , end_m_offset
439451
440452 num_gm , _ = lax .fori_loop (0 , max_num_group , outer_group_loop , (0 , 0 ))
441453 return num_gm
@@ -457,7 +469,7 @@ def zero_out_start(
457469 zero_ref [...] = jnp .zeros_like (zero_ref )
458470
459471 zero_dma = zero_ref .reshape (- 1 , dims .size_lhs_sublane , num_lanes )
460- out_dma = out_ref .reshape (- 1 , dims .size_lhs_sublane , dims . size_n )
472+ out_dma = out_ref .reshape (- 1 , dims .size_lhs_sublane , out_ref . shape [ - 1 ] )
461473 row_size = zero_dma .shape [0 ]
462474
463475 compute_start = metadata_ref .gm_id_to_m_offset [0 ]
@@ -480,9 +492,9 @@ def fill_zero(i, zero_size, *, start, end):
480492
481493 # Static loop. Will be unrolled during compile time.
482494 for n_start in range (0 , dims .size_n , num_lanes ):
483- n_end = min ( n_start + num_lanes , dims . size_n )
495+ n_end = n_start + num_lanes
484496 pltpu .make_async_copy (
485- src_ref = zero_dma .at [pl .ds (0 , dma_size ), :, : n_end - n_start ],
497+ src_ref = zero_dma .at [pl .ds (0 , dma_size )],
486498 dst_ref = out_dma .at [pl .ds (dma_start , dma_size ), :, n_start :n_end ],
487499 sem = semaphore_ref .at [0 ],
488500 ).start (priority = 1 )
@@ -509,7 +521,7 @@ def zero_out_end(
509521 * ,
510522 dims : Dimensions ,
511523):
512- out_dma = out_ref .reshape (- 1 , dims .size_lhs_sublane , dims . size_n )
524+ out_dma = out_ref .reshape (- 1 , dims .size_lhs_sublane , out_ref . shape [ - 1 ] )
513525 pltpu .make_async_copy (
514526 src_ref = out_dma .at [pl .ds (0 , zero_size )],
515527 dst_ref = out_dma .at [pl .ds (0 , zero_size )],
@@ -562,8 +574,8 @@ def kernel_main(
562574 cfgs: GmmConfigs.
563575 """
564576
565- num_k = cfgs .dims .size_k // cfgs .tiles .tile_k
566- num_n = cfgs .dims .size_n // cfgs .tiles .tile_n
577+ num_k = pl . cdiv ( cfgs .dims .size_k , cfgs .tiles .tile_k )
578+ num_n = pl . cdiv ( cfgs .dims .size_n , cfgs .tiles .tile_n )
567579
568580 # Fill metadata buffer and return number of group & m iterations.
569581 num_gm = fill_metadata (
@@ -595,8 +607,8 @@ def kernel_main(
595607
596608 # Bounded slice requires second last dim to be aligned to the sublane size.
597609 # rhs_ref uses static tiling thus reshape is not needed.
598- lhs_in = lhs_ref .reshape (- 1 , cfgs .dims .size_lhs_sublane , cfgs . dims . size_k )
599- out_in = out_ref .reshape (- 1 , cfgs .dims .size_lhs_sublane , cfgs . dims . size_n )
610+ lhs_in = lhs_ref .reshape (- 1 , cfgs .dims .size_lhs_sublane , lhs_ref . shape [ - 1 ] )
611+ out_in = out_ref .reshape (- 1 , cfgs .dims .size_lhs_sublane , out_ref . shape [ - 1 ] )
600612 scratches = [partial_out_ref , acc_ref , metadata_ref ]
601613 pipeline_fn (lhs_in , rhs_ref , out_in , scratches = scratches )
602614
@@ -615,13 +627,8 @@ def calculate_tiling(
615627 lhs_bits = jax .dtypes .itemsize_bits (lhs_dtype )
616628 rhs_bits = jax .dtypes .itemsize_bits (rhs_dtype )
617629
618- # Otherwise we run into a VMEM OOM
619- # TODO (jacobplatin/wenxindongwork): remove this hotfix once FP4 RHS bug is fixed
620- if rhs_bits == 4 :
621- rhs_bits = 8
622-
623630 # When using bf16 for lhs and rhs, 128 is the largest tile_m value that is
624- # safe to use for most scenarios. But if are using lower bitwidth, we need
631+ # safe to use for most scenarios. But if lower bitwidth is used , we need
625632 # to tweak tile_m to account for using faster hardware unit.
626633 # TODO(kyuyeunk): Account for different TPU hardware specs.
627634 bf16_bf16_tile_m = 128
@@ -633,60 +640,40 @@ def calculate_tiling(
633640 # Calculate vmem limit for a single rhs buffer when using triple buffers.
634641 num_rhs_buffers = 3
635642 rhs_vmem_target = vmem_limit_bytes // num_rhs_buffers
636- rhs_base_size_bytes = rhs_size_bytes = dims .size_k * dims .size_n * rhs_bits // 8
637-
638- num_lanes = pltpu .get_tpu_info ().num_lanes
643+ base_rhs_size_bytes = dims .size_k * dims .size_n * rhs_bits // 8
639644
640645 # To avoid stalling MXU, we add some buffer room where tile_n cannot go
641646 # smaller than 2x of mxu_column_size.
642647 tile_n_limit = pltpu .get_tpu_info ().mxu_column_size * 2
643648 tile_n_limit = min (tile_n_limit , dims .size_n )
644649
645650 # Initialize tile_k and tile_n to their maximum valid values.
646- num_k_tiles = 1
647- tile_k = dims .size_k
648- k_rem = 0
649-
650- # last_valid_n_tiles stores num_n_tiles that can evenly divide size_n but may
651- # or may not be sufficient to fit rhs into vmem target without changing
652- # tile_k.
653- last_valid_n_tiles = num_n_tiles = 1
654- tile_n = dims .size_n
655- n_rem = 0
651+ num_k_tiles = num_n_tiles = 1
652+ num_lanes = pltpu .get_tpu_info ().num_lanes
653+ tile_k = align_to (dims .size_k , num_lanes )
654+ tile_n = align_to (dims .size_n , num_lanes )
656655
657656 # Multiple k tiles will introduce accumulation overhead. Thus, we first try
658657 # to fit rhs into vmem by only adjusting tile_n.
659658
660659 # Decrease tile_n until rhs fits in vmem target.
661- while rhs_size_bytes > rhs_vmem_target or n_rem or tile_n % num_lanes :
662- if n_rem == 0 and tile_n % num_lanes == 0 :
663- last_valid_n_tiles = num_n_tiles
660+ while pl .cdiv (base_rhs_size_bytes , num_n_tiles ) >= rhs_vmem_target and tile_n >= tile_n_limit :
664661 num_n_tiles += 1
665- tile_n = dims .size_n // num_n_tiles
666- n_rem = dims .size_n % num_n_tiles
667- rhs_size_bytes = rhs_base_size_bytes // num_n_tiles
662+ tile_n = align_to (dims .size_n , num_n_tiles * num_lanes ) // num_n_tiles
668663
669664 # If decreasing tile_n is no longer possible, we decrease tile_k instead.
670- if tile_n % num_lanes or n_rem or tile_n < tile_n_limit :
671- num_n_tiles = last_valid_n_tiles
672- tile_n = dims .size_n // num_n_tiles
673- rhs_size_bytes = rhs_base_size_bytes // num_n_tiles
665+ if tile_n < tile_n_limit :
666+ num_n_tiles -= 1
667+ tile_n = align_to (dims .size_n , num_n_tiles * num_lanes ) // num_n_tiles
674668
675669 # Decrease tile_k until rhs fits in vmem target.
676- while rhs_size_bytes > rhs_vmem_target or k_rem or tile_k % num_lanes :
670+ base_rhs_size_bytes = pl .cdiv (base_rhs_size_bytes , num_n_tiles )
671+ while pl .cdiv (base_rhs_size_bytes , num_k_tiles ) >= rhs_vmem_target :
677672 num_k_tiles += 1
678- tile_k = dims .size_k // num_k_tiles
679- k_rem = dims .size_k % num_k_tiles
680- rhs_size_bytes = rhs_base_size_bytes // (num_k_tiles * num_n_tiles )
681-
682- is_tile_n_invalid = tile_n % num_lanes or n_rem or tile_n < tile_n_limit
683- is_tile_k_invalid = tile_k % num_lanes or k_rem
673+ tile_k = align_to (dims .size_k , num_k_tiles * num_lanes ) // num_k_tiles
684674
685- if is_tile_n_invalid or is_tile_k_invalid :
686- raise ValueError (
687- f"Could not find valid tile sizes for { dims = } and"
688- f" { rhs_vmem_target = } . Last tried tiles: { tile_m = } { tile_k = } { tile_n = } "
689- )
675+ if tile_n == 0 or tile_k == 0 :
676+ raise ValueError (f"Could not find valid tile sizes for { dims = } and { rhs_vmem_target = } ." )
690677
691678 return TileSizes (tile_m = tile_m , tile_k = tile_k , tile_n = tile_n )
692679
@@ -718,14 +705,8 @@ def validate_inputs(
718705
719706 assert group_offset .shape == (1 ,)
720707
721- # TODO(kyuyeunk): Add support for implicit padding along lane dimensions.
722- num_lanes = pltpu .get_tpu_info ().num_lanes
723- if size_k % num_lanes != 0 or size_n % num_lanes != 0 :
724- raise NotImplementedError ("Implicit padding along lane dimensions is not supported." )
725-
726- bitwidth = jax .dtypes .itemsize_bits (lhs .dtype )
727- packing = 32 // bitwidth
728- size_lhs_sublane = pltpu .get_tpu_info ().num_sublanes * packing
708+ size_lhs_sublane = pltpu .get_tpu_info ().get_sublane_tiling (lhs .dtype )
709+ size_lhs_sublane = min (size_lhs_sublane , size_m )
729710
730711 return Dimensions (
731712 size_m = size_m ,
@@ -737,20 +718,6 @@ def validate_inputs(
737718 )
738719
739720
740- def validate_tiles (tiles : TileSizes , dims : Dimensions ):
741- """Validates the tile sizes for the GMM kernel."""
742-
743- def _validate (x : int , tx : int , name : str ):
744- if x % tx != 0 :
745- raise ValueError (f"size_{ name } ={ x } is not divisible by tile_{ name } ={ tx } ." )
746- if tx > x :
747- raise ValueError (f"tile_{ name } ={ tx } is larger than size_{ name } ={ x } ." )
748-
749- _validate (dims .size_m , pltpu .get_tpu_info ().num_sublanes , "m" )
750- _validate (dims .size_k , tiles .tile_k , "k" )
751- _validate (dims .size_n , tiles .tile_n , "n" )
752-
753-
754721def get_cost_estimate (
755722 lhs : jax .Array ,
756723 rhs : WeightsRef ,
@@ -764,8 +731,9 @@ def get_cost_estimate(
764731
765732 lhs_bytes = dims .size_m * dims .size_k * lhs .dtype .itemsize
766733
767- # TODO(kyuyeunk): Handle 4-bit quantization case.
768- rhs_bytes = dims .size_group * dims .size_k * dims .size_n * rhs .weight .dtype .itemsize
734+ rhs_bytes = (
735+ dims .size_group * dims .size_k * dims .size_n * jax .dtypes .itemsize_bits (rhs .weight )
736+ ) // 8
769737 if rhs .scale is not None :
770738 rhs_bytes += dims .size_n * jnp .dtype (jnp .float32 ).itemsize
771739 if rhs .bias is not None :
@@ -860,8 +828,6 @@ def make_gmm_configs(
860828 lhs_q_dtype = lhs_q_dtype if lhs_q_dtype is not None else lhs .dtype
861829 tiles = tile_info (lhs_q_dtype , rhs .dtype , dims , vmem_limit_bytes )
862830
863- validate_tiles (tiles , dims )
864-
865831 return GmmConfigs (
866832 dims = dims ,
867833 tiles = tiles ,
@@ -916,7 +882,7 @@ def gmm_v2(
916882 """GMM kernel implemented with emit_pipeline.
917883
918884 Dynamically calculate offset lhs/out tiles to reduce redundant computations.
919- Additionally, adjusting dma size based on number of valid rows and utilize
885+ Additionally, it adjusts dma size based on number of valid rows and utilize
920886 triple buffering on weights to better utilize memory.
921887
922888 Args:
@@ -976,7 +942,7 @@ def gmm_v2(
976942 rhs_bias_spec = pl .BlockSpec (memory_space = pltpu .HBM )
977943
978944 # Initialize scratch shapes.
979- max_num_gm = dims .size_group + dims .size_m // tiles .tile_m - 1
945+ max_num_gm = dims .size_group + pl . cdiv ( dims .size_m , tiles .tile_m ) - 1
980946
981947 scratch_shapes = [
982948 # partial_out_ref
@@ -990,6 +956,7 @@ def gmm_v2(
990956 ),
991957 ]
992958
959+ num_lanes = pltpu .get_tpu_info ().num_lanes
993960 if cfgs .zero_init :
994961 # TODO(kyuyeunk): Create better heuristics for determining this value.
995962 target_zero_ref_bytes = 2 * 1024 * 1024
@@ -1004,7 +971,6 @@ def gmm_v2(
1004971 # (which is smallest allowed column size for DMA) and reuse the buffer by
1005972 # size_n//num_lanes times in a single tile, we can significantly increase
1006973 # tile_zero_m without triggering OOM.
1007- num_lanes = pltpu .get_tpu_info ().num_lanes
1008974 out_bytes = jnp .dtype (cfgs .out_dtype ).itemsize
1009975 tile_zero_m = target_zero_ref_bytes // num_lanes // out_bytes
1010976 tile_zero_m = min (tile_zero_m , dims .size_m )
@@ -1016,7 +982,8 @@ def gmm_v2(
1016982 else :
1017983 scratch_shapes += [None , None ]
1018984
1019- out_init = jax .ShapeDtypeStruct ((dims .size_m , dims .size_n ), cfgs .out_dtype )
985+ aligned_n = align_to (dims .size_n , num_lanes )
986+ out_init = jax .ShapeDtypeStruct ((dims .size_m , aligned_n ), cfgs .out_dtype )
1020987 rhs_weights = WeightsRef (weight = rhs , scale = rhs_scale , bias = rhs_bias )
1021988
1022989 return pl .pallas_call (
@@ -1042,22 +1009,9 @@ def gmm_v2(
10421009 name = get_scope_name (dims , tiles ),
10431010 cost_estimate = get_cost_estimate (lhs , rhs_weights , out_init .dtype , dims ),
10441011 metadata = get_metadata (cfgs ),
1045- )(group_sizes , group_offset , lhs , rhs_weights )
1012+ )(group_sizes , group_offset , lhs , rhs_weights )[:, : dims . size_n ]
10461013
10471014
1048- def is_supported_by_gmm_v2 (lhs : jax .Array , rhs : jax .Array , rhs_scale : jax .Array | None ) -> bool :
1049- """Return false if gmm_v2 does not support the inputs yet."""
1050-
1051- if rhs_scale is not None and rhs_scale .shape [1 ] != 1 :
1052- # gmm_v2 does not support subchannel quantization.
1053- return False
1054- # gmm_v2 does not support implicit padding along lane dimension.
1055- num_lanes = pltpu .get_tpu_info ().num_lanes
1056- if lhs .shape [- 1 ] % num_lanes != 0 or rhs .shape [- 1 ] % num_lanes != 0 :
1057- return False
1058- # gmm_v2 does not support when lhs is not multiple of sublane size.
1059- lhs_bytes = lhs .dtype .itemsize
1060- if lhs .shape [0 ] % (pltpu .get_tpu_info ().num_sublanes * lhs_bytes ) != 0 :
1061- return False
1062- # Handle weird edge cases where inputs are already quantized.
1063- return lhs .dtype in [jnp .bfloat16 , jnp .float32 ]
1015+ def is_supported_by_gmm_v2 (rhs_scale : jax .Array | None ) -> bool :
1016+ # gmm_v2 does not support subchannel quantization.
1017+ return rhs_scale is None or rhs_scale .shape [1 ] == 1
0 commit comments