Skip to content

Commit 62be331

Browse files
committed
format
1 parent 1e71b04 commit 62be331

3 files changed

Lines changed: 71 additions & 119 deletions

File tree

python/sgl_jax/srt/kernels/gmm/megablox_gmm_kernel/gmm_v2.py

Lines changed: 68 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from jax.experimental import pallas as pl
1111
from jax.experimental.pallas import tpu as pltpu
1212

13+
# Util.
14+
15+
16+
def align_to(x, a):
17+
return pl.cdiv(x, a) * a
18+
19+
1320
# Define data classes.
1421

1522

@@ -182,25 +189,30 @@ def inner_kernel(
182189
183190
Args:
184191
tiled_lhs_ref: Contains value lhs[m_start:m_end, k_start:k_end]
185-
tiled_rhs_ref: Contains value rhs[g_id, k_start:k_end, n_start:n_end].
186-
where g_id is the group associated with lhs[m_start:m_end, :]
192+
tiled_rhs_ref: Contains value rhs[g_id, k_start:k_end, n_start:n_end]. where
193+
g_id is the group associated with lhs[m_start:m_end, :]
187194
tiled_out_ref: Contains value out[m_start:m_end, n_start:n_end]
188-
partial_out_ref: Contains last size_lhs_sublane rows of the previous
189-
output. If this is the first tile for grid[n_id, :, :], it will be
190-
initialized to zeros.
195+
partial_out_ref: Contains last size_lhs_sublane rows of the previous output.
196+
Will be initialized to zero if this is first tile for grid[n_id, :, :].
191197
acc_ref: Reference to the accumulator.
192198
metadata_ref: Reference to the metadata.
193199
cfgs: GmmConfigs.
194200
"""
195201

196202
def _matmul(is_first_k_step: bool, is_last_k_step: bool):
197203
tiled_lhs = tiled_lhs_ref.reshape(-1, cfgs.tiles.tile_k)[...]
204+
tiled_rhs = tiled_rhs_ref.weight[...]
205+
206+
valid_k = cfgs.dims.size_k % cfgs.tiles.tile_k
207+
if is_last_k_step and valid_k != 0:
208+
mask_rhs = lax.broadcasted_iota(jnp.int32, tiled_rhs.shape, 0) < valid_k
209+
tiled_rhs = jnp.where(mask_rhs, tiled_rhs, 0)
198210

199211
if cfgs.lhs_cfgs.quant_dtype is None:
200212
# Unquantized matmul path.
201213
acc = jnp.matmul(
202214
tiled_lhs,
203-
tiled_rhs_ref.weight[...],
215+
tiled_rhs,
204216
preferred_element_type=jnp.float32,
205217
).astype(acc_ref.dtype)
206218
else:
@@ -232,7 +244,7 @@ def _matmul(is_first_k_step: bool, is_last_k_step: bool):
232244
end_k = min(cfgs.tiles.tile_k, start_k + q_block_size)
233245

234246
block_lhs = tiled_lhs[:, start_k:end_k]
235-
block_rhs = tiled_rhs_ref.weight[start_k:end_k, start_n:end_n]
247+
block_rhs = tiled_rhs[start_k:end_k, start_n:end_n]
236248

237249
# Perform lhs quantization. Note that for every block_lhs,
238250
# same computation will be performed tiles_n//mxu_size times.
@@ -303,8 +315,8 @@ def _matmul(is_first_k_step: bool, is_last_k_step: bool):
303315
# fill size_lhs_sublane rows and will be revisited at the next step. By
304316
# storing the partial rows into the partial_out_ref, the next step can
305317
# read them and accumulate to them. Additionally, for group id of 2,
306-
# since it completely fills the size_lhs_sublane rows, we need to
307-
# initialize the partial_out_ref to zeros.
318+
# since it completely fills the size_lhs_sublane rows, we need to zero out
319+
# partial_out_ref to avoid numeric error for group 3.
308320
last_row = m_end_local // cfgs.dims.size_lhs_sublane
309321
partial_out_ref[...] = jnp.where(
310322
m_end_local % cfgs.dims.size_lhs_sublane == 0,
@@ -370,8 +382,8 @@ def fill_metadata(
370382
Args:
371383
lhs_group_sizes_ref: The group sizes of lhs.
372384
group_offset_ref: Offset of the first group to process.
373-
metadata_ref: Metadata that is used to determine the group id and m
374-
offsets for each gmm tile.
385+
metadata_ref: Metadata that is used to determine the group id and m offsets
386+
for each gmm tile.
375387
cfgs: GmmConfigs.
376388
377389
Returns:
@@ -383,15 +395,15 @@ def fill_metadata(
383395
metadata_ref.gm_id_to_m_offset[0] = 0
384396

385397
@jax.named_scope("inner_tm_loop")
386-
def inner_tm_loop(tm_id, curr_m_offset, *, end_m_offset, group_id, num_gm):
398+
def inner_tm_loop(tm_id, curr_m_offset, *, end_m_offset, group_id):
387399
local_offset = curr_m_offset % cfgs.dims.size_lhs_sublane
388400
tm_size = jnp.minimum(cfgs.tiles.tile_m - local_offset, end_m_offset - curr_m_offset)
389401

390-
metadata_ref.gm_id_to_group_id[num_gm + tm_id] = group_id
402+
metadata_ref.gm_id_to_group_id[tm_id] = group_id
391403

392404
next_m_offset = curr_m_offset + tm_size
393-
metadata_ref.gm_id_to_m_offset[num_gm + tm_id] = curr_m_offset
394-
metadata_ref.gm_id_to_m_offset[num_gm + tm_id + 1] = next_m_offset
405+
metadata_ref.gm_id_to_m_offset[tm_id] = curr_m_offset
406+
metadata_ref.gm_id_to_m_offset[tm_id + 1] = next_m_offset
395407

396408
return next_m_offset
397409

@@ -426,16 +438,16 @@ def outer_group_loop(lhs_group_id, carry):
426438
# 2. If group comes before the group_offset, we should not process it.
427439
should_process = jnp.logical_and(group_size > 0, group_id >= 0)
428440
curr_num_gm = jnp.where(should_process, curr_num_gm, 0)
441+
next_num_gm = num_gm + curr_num_gm
429442

430443
tm_loop_fn = functools.partial(
431444
inner_tm_loop,
432445
end_m_offset=end_m_offset,
433446
group_id=group_id,
434-
num_gm=num_gm,
435447
)
436-
lax.fori_loop(0, curr_num_gm, tm_loop_fn, start_m_offset)
448+
lax.fori_loop(num_gm, next_num_gm, tm_loop_fn, start_m_offset)
437449

438-
return num_gm + curr_num_gm, end_m_offset
450+
return next_num_gm, end_m_offset
439451

440452
num_gm, _ = lax.fori_loop(0, max_num_group, outer_group_loop, (0, 0))
441453
return num_gm
@@ -457,7 +469,7 @@ def zero_out_start(
457469
zero_ref[...] = jnp.zeros_like(zero_ref)
458470

459471
zero_dma = zero_ref.reshape(-1, dims.size_lhs_sublane, num_lanes)
460-
out_dma = out_ref.reshape(-1, dims.size_lhs_sublane, dims.size_n)
472+
out_dma = out_ref.reshape(-1, dims.size_lhs_sublane, out_ref.shape[-1])
461473
row_size = zero_dma.shape[0]
462474

463475
compute_start = metadata_ref.gm_id_to_m_offset[0]
@@ -480,9 +492,9 @@ def fill_zero(i, zero_size, *, start, end):
480492

481493
# Static loop. Will be unrolled during compile time.
482494
for n_start in range(0, dims.size_n, num_lanes):
483-
n_end = min(n_start + num_lanes, dims.size_n)
495+
n_end = n_start + num_lanes
484496
pltpu.make_async_copy(
485-
src_ref=zero_dma.at[pl.ds(0, dma_size), :, : n_end - n_start],
497+
src_ref=zero_dma.at[pl.ds(0, dma_size)],
486498
dst_ref=out_dma.at[pl.ds(dma_start, dma_size), :, n_start:n_end],
487499
sem=semaphore_ref.at[0],
488500
).start(priority=1)
@@ -509,7 +521,7 @@ def zero_out_end(
509521
*,
510522
dims: Dimensions,
511523
):
512-
out_dma = out_ref.reshape(-1, dims.size_lhs_sublane, dims.size_n)
524+
out_dma = out_ref.reshape(-1, dims.size_lhs_sublane, out_ref.shape[-1])
513525
pltpu.make_async_copy(
514526
src_ref=out_dma.at[pl.ds(0, zero_size)],
515527
dst_ref=out_dma.at[pl.ds(0, zero_size)],
@@ -562,8 +574,8 @@ def kernel_main(
562574
cfgs: GmmConfigs.
563575
"""
564576

565-
num_k = cfgs.dims.size_k // cfgs.tiles.tile_k
566-
num_n = cfgs.dims.size_n // cfgs.tiles.tile_n
577+
num_k = pl.cdiv(cfgs.dims.size_k, cfgs.tiles.tile_k)
578+
num_n = pl.cdiv(cfgs.dims.size_n, cfgs.tiles.tile_n)
567579

568580
# Fill metadata buffer and return number of group & m iterations.
569581
num_gm = fill_metadata(
@@ -595,8 +607,8 @@ def kernel_main(
595607

596608
# Bounded slice requires second last dim to be aligned to the sublane size.
597609
# rhs_ref uses static tiling thus reshape is not needed.
598-
lhs_in = lhs_ref.reshape(-1, cfgs.dims.size_lhs_sublane, cfgs.dims.size_k)
599-
out_in = out_ref.reshape(-1, cfgs.dims.size_lhs_sublane, cfgs.dims.size_n)
610+
lhs_in = lhs_ref.reshape(-1, cfgs.dims.size_lhs_sublane, lhs_ref.shape[-1])
611+
out_in = out_ref.reshape(-1, cfgs.dims.size_lhs_sublane, out_ref.shape[-1])
600612
scratches = [partial_out_ref, acc_ref, metadata_ref]
601613
pipeline_fn(lhs_in, rhs_ref, out_in, scratches=scratches)
602614

@@ -615,13 +627,8 @@ def calculate_tiling(
615627
lhs_bits = jax.dtypes.itemsize_bits(lhs_dtype)
616628
rhs_bits = jax.dtypes.itemsize_bits(rhs_dtype)
617629

618-
# Otherwise we run into a VMEM OOM
619-
# TODO (jacobplatin/wenxindongwork): remove this hotfix once FP4 RHS bug is fixed
620-
if rhs_bits == 4:
621-
rhs_bits = 8
622-
623630
# When using bf16 for lhs and rhs, 128 is the largest tile_m value that is
624-
# safe to use for most scenarios. But if are using lower bitwidth, we need
631+
# safe to use for most scenarios. But if lower bitwidth is used, we need
625632
# to tweak tile_m to account for using faster hardware unit.
626633
# TODO(kyuyeunk): Account for different TPU hardware specs.
627634
bf16_bf16_tile_m = 128
@@ -633,60 +640,40 @@ def calculate_tiling(
633640
# Calculate vmem limit for a single rhs buffer when using triple buffers.
634641
num_rhs_buffers = 3
635642
rhs_vmem_target = vmem_limit_bytes // num_rhs_buffers
636-
rhs_base_size_bytes = rhs_size_bytes = dims.size_k * dims.size_n * rhs_bits // 8
637-
638-
num_lanes = pltpu.get_tpu_info().num_lanes
643+
base_rhs_size_bytes = dims.size_k * dims.size_n * rhs_bits // 8
639644

640645
# To avoid stalling MXU, we add some buffer room where tile_n cannot go
641646
# smaller than 2x of mxu_column_size.
642647
tile_n_limit = pltpu.get_tpu_info().mxu_column_size * 2
643648
tile_n_limit = min(tile_n_limit, dims.size_n)
644649

645650
# Initialize tile_k and tile_n to their maximum valid values.
646-
num_k_tiles = 1
647-
tile_k = dims.size_k
648-
k_rem = 0
649-
650-
# last_valid_n_tiles stores num_n_tiles that can evenly divide size_n but may
651-
# or may not be sufficient to fit rhs into vmem target without changing
652-
# tile_k.
653-
last_valid_n_tiles = num_n_tiles = 1
654-
tile_n = dims.size_n
655-
n_rem = 0
651+
num_k_tiles = num_n_tiles = 1
652+
num_lanes = pltpu.get_tpu_info().num_lanes
653+
tile_k = align_to(dims.size_k, num_lanes)
654+
tile_n = align_to(dims.size_n, num_lanes)
656655

657656
# Multiple k tiles will introduce accumulation overhead. Thus, we first try
658657
# to fit rhs into vmem by only adjusting tile_n.
659658

660659
# Decrease tile_n until rhs fits in vmem target.
661-
while rhs_size_bytes > rhs_vmem_target or n_rem or tile_n % num_lanes:
662-
if n_rem == 0 and tile_n % num_lanes == 0:
663-
last_valid_n_tiles = num_n_tiles
660+
while pl.cdiv(base_rhs_size_bytes, num_n_tiles) >= rhs_vmem_target and tile_n >= tile_n_limit:
664661
num_n_tiles += 1
665-
tile_n = dims.size_n // num_n_tiles
666-
n_rem = dims.size_n % num_n_tiles
667-
rhs_size_bytes = rhs_base_size_bytes // num_n_tiles
662+
tile_n = align_to(dims.size_n, num_n_tiles * num_lanes) // num_n_tiles
668663

669664
# If decreasing tile_n is no longer possible, we decrease tile_k instead.
670-
if tile_n % num_lanes or n_rem or tile_n < tile_n_limit:
671-
num_n_tiles = last_valid_n_tiles
672-
tile_n = dims.size_n // num_n_tiles
673-
rhs_size_bytes = rhs_base_size_bytes // num_n_tiles
665+
if tile_n < tile_n_limit:
666+
num_n_tiles -= 1
667+
tile_n = align_to(dims.size_n, num_n_tiles * num_lanes) // num_n_tiles
674668

675669
# Decrease tile_k until rhs fits in vmem target.
676-
while rhs_size_bytes > rhs_vmem_target or k_rem or tile_k % num_lanes:
670+
base_rhs_size_bytes = pl.cdiv(base_rhs_size_bytes, num_n_tiles)
671+
while pl.cdiv(base_rhs_size_bytes, num_k_tiles) >= rhs_vmem_target:
677672
num_k_tiles += 1
678-
tile_k = dims.size_k // num_k_tiles
679-
k_rem = dims.size_k % num_k_tiles
680-
rhs_size_bytes = rhs_base_size_bytes // (num_k_tiles * num_n_tiles)
681-
682-
is_tile_n_invalid = tile_n % num_lanes or n_rem or tile_n < tile_n_limit
683-
is_tile_k_invalid = tile_k % num_lanes or k_rem
673+
tile_k = align_to(dims.size_k, num_k_tiles * num_lanes) // num_k_tiles
684674

685-
if is_tile_n_invalid or is_tile_k_invalid:
686-
raise ValueError(
687-
f"Could not find valid tile sizes for {dims=} and"
688-
f" {rhs_vmem_target=}. Last tried tiles: {tile_m=} {tile_k=} {tile_n=}"
689-
)
675+
if tile_n == 0 or tile_k == 0:
676+
raise ValueError(f"Could not find valid tile sizes for {dims=} and {rhs_vmem_target=}.")
690677

691678
return TileSizes(tile_m=tile_m, tile_k=tile_k, tile_n=tile_n)
692679

@@ -718,14 +705,8 @@ def validate_inputs(
718705

719706
assert group_offset.shape == (1,)
720707

721-
# TODO(kyuyeunk): Add support for implicit padding along lane dimensions.
722-
num_lanes = pltpu.get_tpu_info().num_lanes
723-
if size_k % num_lanes != 0 or size_n % num_lanes != 0:
724-
raise NotImplementedError("Implicit padding along lane dimensions is not supported.")
725-
726-
bitwidth = jax.dtypes.itemsize_bits(lhs.dtype)
727-
packing = 32 // bitwidth
728-
size_lhs_sublane = pltpu.get_tpu_info().num_sublanes * packing
708+
size_lhs_sublane = pltpu.get_tpu_info().get_sublane_tiling(lhs.dtype)
709+
size_lhs_sublane = min(size_lhs_sublane, size_m)
729710

730711
return Dimensions(
731712
size_m=size_m,
@@ -737,20 +718,6 @@ def validate_inputs(
737718
)
738719

739720

740-
def validate_tiles(tiles: TileSizes, dims: Dimensions):
741-
"""Validates the tile sizes for the GMM kernel."""
742-
743-
def _validate(x: int, tx: int, name: str):
744-
if x % tx != 0:
745-
raise ValueError(f"size_{name}={x} is not divisible by tile_{name}={tx}.")
746-
if tx > x:
747-
raise ValueError(f"tile_{name}={tx} is larger than size_{name}={x}.")
748-
749-
_validate(dims.size_m, pltpu.get_tpu_info().num_sublanes, "m")
750-
_validate(dims.size_k, tiles.tile_k, "k")
751-
_validate(dims.size_n, tiles.tile_n, "n")
752-
753-
754721
def get_cost_estimate(
755722
lhs: jax.Array,
756723
rhs: WeightsRef,
@@ -764,8 +731,9 @@ def get_cost_estimate(
764731

765732
lhs_bytes = dims.size_m * dims.size_k * lhs.dtype.itemsize
766733

767-
# TODO(kyuyeunk): Handle 4-bit quantization case.
768-
rhs_bytes = dims.size_group * dims.size_k * dims.size_n * rhs.weight.dtype.itemsize
734+
rhs_bytes = (
735+
dims.size_group * dims.size_k * dims.size_n * jax.dtypes.itemsize_bits(rhs.weight)
736+
) // 8
769737
if rhs.scale is not None:
770738
rhs_bytes += dims.size_n * jnp.dtype(jnp.float32).itemsize
771739
if rhs.bias is not None:
@@ -860,8 +828,6 @@ def make_gmm_configs(
860828
lhs_q_dtype = lhs_q_dtype if lhs_q_dtype is not None else lhs.dtype
861829
tiles = tile_info(lhs_q_dtype, rhs.dtype, dims, vmem_limit_bytes)
862830

863-
validate_tiles(tiles, dims)
864-
865831
return GmmConfigs(
866832
dims=dims,
867833
tiles=tiles,
@@ -916,7 +882,7 @@ def gmm_v2(
916882
"""GMM kernel implemented with emit_pipeline.
917883
918884
Dynamically calculate offset lhs/out tiles to reduce redundant computations.
919-
Additionally, adjusting dma size based on number of valid rows and utilize
885+
Additionally, it adjusts dma size based on number of valid rows and utilize
920886
triple buffering on weights to better utilize memory.
921887
922888
Args:
@@ -976,7 +942,7 @@ def gmm_v2(
976942
rhs_bias_spec = pl.BlockSpec(memory_space=pltpu.HBM)
977943

978944
# Initialize scratch shapes.
979-
max_num_gm = dims.size_group + dims.size_m // tiles.tile_m - 1
945+
max_num_gm = dims.size_group + pl.cdiv(dims.size_m, tiles.tile_m) - 1
980946

981947
scratch_shapes = [
982948
# partial_out_ref
@@ -990,6 +956,7 @@ def gmm_v2(
990956
),
991957
]
992958

959+
num_lanes = pltpu.get_tpu_info().num_lanes
993960
if cfgs.zero_init:
994961
# TODO(kyuyeunk): Create better heuristics for determining this value.
995962
target_zero_ref_bytes = 2 * 1024 * 1024
@@ -1004,7 +971,6 @@ def gmm_v2(
1004971
# (which is smallest allowed column size for DMA) and reuse the buffer by
1005972
# size_n//num_lanes times in a single tile, we can significantly increase
1006973
# tile_zero_m without triggering OOM.
1007-
num_lanes = pltpu.get_tpu_info().num_lanes
1008974
out_bytes = jnp.dtype(cfgs.out_dtype).itemsize
1009975
tile_zero_m = target_zero_ref_bytes // num_lanes // out_bytes
1010976
tile_zero_m = min(tile_zero_m, dims.size_m)
@@ -1016,7 +982,8 @@ def gmm_v2(
1016982
else:
1017983
scratch_shapes += [None, None]
1018984

1019-
out_init = jax.ShapeDtypeStruct((dims.size_m, dims.size_n), cfgs.out_dtype)
985+
aligned_n = align_to(dims.size_n, num_lanes)
986+
out_init = jax.ShapeDtypeStruct((dims.size_m, aligned_n), cfgs.out_dtype)
1020987
rhs_weights = WeightsRef(weight=rhs, scale=rhs_scale, bias=rhs_bias)
1021988

1022989
return pl.pallas_call(
@@ -1042,22 +1009,9 @@ def gmm_v2(
10421009
name=get_scope_name(dims, tiles),
10431010
cost_estimate=get_cost_estimate(lhs, rhs_weights, out_init.dtype, dims),
10441011
metadata=get_metadata(cfgs),
1045-
)(group_sizes, group_offset, lhs, rhs_weights)
1012+
)(group_sizes, group_offset, lhs, rhs_weights)[:, : dims.size_n]
10461013

10471014

1048-
def is_supported_by_gmm_v2(lhs: jax.Array, rhs: jax.Array, rhs_scale: jax.Array | None) -> bool:
1049-
"""Return false if gmm_v2 does not support the inputs yet."""
1050-
1051-
if rhs_scale is not None and rhs_scale.shape[1] != 1:
1052-
# gmm_v2 does not support subchannel quantization.
1053-
return False
1054-
# gmm_v2 does not support implicit padding along lane dimension.
1055-
num_lanes = pltpu.get_tpu_info().num_lanes
1056-
if lhs.shape[-1] % num_lanes != 0 or rhs.shape[-1] % num_lanes != 0:
1057-
return False
1058-
# gmm_v2 does not support when lhs is not multiple of sublane size.
1059-
lhs_bytes = lhs.dtype.itemsize
1060-
if lhs.shape[0] % (pltpu.get_tpu_info().num_sublanes * lhs_bytes) != 0:
1061-
return False
1062-
# Handle weird edge cases where inputs are already quantized.
1063-
return lhs.dtype in [jnp.bfloat16, jnp.float32]
1015+
def is_supported_by_gmm_v2(rhs_scale: jax.Array | None) -> bool:
1016+
# gmm_v2 does not support subchannel quantization.
1017+
return rhs_scale is None or rhs_scale.shape[1] == 1

0 commit comments

Comments
 (0)