Commit

linter

vanbasten23 committed Feb 10, 2025
1 parent c4bc5a0 commit 57e1dc4

Showing 3 changed files with 62 additions and 41 deletions.
58 changes: 30 additions & 28 deletions test/benchmarks/test_ragged_paged_attention_benchmark.py
@@ -76,11 +76,13 @@ def _ref_ragged_paged_attention(

   return jnp.concatenate(outputs, axis=0)
 
+
 def _get_closest_power_of_two(x):
   if x <= 0:
     raise ValueError(f"x must be positive. Got {x}")
   return 2**int(np.ceil(np.log2(x)))
 
+
 def benchmark(args):
   seq_lens = [
       (1, 1328),
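
Aside: _get_closest_power_of_two rounds its argument up to the next power of two; the benchmark uses it below to pad max_num_pages_per_seq. A quick standalone check:

import numpy as np

def _get_closest_power_of_two(x):
  if x <= 0:
    raise ValueError(f"x must be positive. Got {x}")
  return 2**int(np.ceil(np.log2(x)))

assert _get_closest_power_of_two(83) == 128   # rounded up
assert _get_closest_power_of_two(128) == 128  # exact powers pass through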
@@ -100,7 +102,7 @@ def benchmark(args):
   dtype = jnp.float32
   page_size = 16
   num_pages = 32768
-  num_queries_per_block=128
+  num_queries_per_block = 128
 
   num_seqs = len(seq_lens)
   # Make sure the q_len is no longer than the kv_len. For example,
@@ -138,16 +140,12 @@ def benchmark(args):
   max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
   # The reason why we need to pad max_num_pages_per_seq is that
   # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-  max_num_pages_per_seq = _get_closest_power_of_two(
-      max_num_pages_per_seq)
+  max_num_pages_per_seq = _get_closest_power_of_two(max_num_pages_per_seq)
   # The assert below mimics the reality that each page get a unique index.
   # But for testing, the assert could be omitted.
   # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
   page_indices = jax.random.randint(
-      k4, (num_q_tokens, max_num_pages_per_seq),
-      0,
-      num_pages,
-      dtype=jnp.int32)
+      k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32)
 
   # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1]
   q_lens_with_paddings = [0] * num_q_tokens
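
The padding comment above encodes the kernel's layout requirement: page_indices.shape[1] (max_num_pages_per_seq) must divide evenly into compute blocks of num_kv_pages_per_compute_block pages. A minimal sketch of that shape setup (the block size of 128 is an assumption borrowed from the tests below, not a value fixed by this file):

import jax
import jax.numpy as jnp

num_q_tokens = 4
num_pages = 32768
num_kv_pages_per_compute_block = 128  # assumed; matches the tests below

max_num_pages_per_seq = 2**7  # already padded up to a power of two
# The tiled axis must divide evenly into compute blocks.
assert max_num_pages_per_seq % num_kv_pages_per_compute_block == 0

page_indices = jax.random.randint(
    jax.random.PRNGKey(0), (num_q_tokens, max_num_pages_per_seq), 0, num_pages,
    dtype=jnp.int32)
assert page_indices.shape == (4, 128)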
Expand All @@ -174,28 +172,28 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
     if profile:
       jax.profiler.start_trace(profile_path)
 
-    actual_output=None
+    actual_output = None
     for _ in range(num_iters):
       if args.kernel == "ragged-paged-attention":
         err, actual_output = ragged_paged_attention(
             queries,
             k_pages,
             v_pages,
             kv_lens_np,
             page_indices,
             cu_q_lens,
             num_seqs,
         )
         err.throw()
       elif args.kernel == "ragged-paged-attention-ref-impl":
         actual_output = _ref_ragged_paged_attention(
             queries,
             k_pages,
             v_pages,
             kv_lens_np,
             page_indices,
             cu_q_lens,
             num_seqs,
         )
       else:
         assert False, f"Invalid kernel name {args.kernel}"
@@ -206,7 +204,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
     if profile:
       jax.profiler.stop_trace()
     return (end_time - start_time) / num_iters
 
   # Warmup.
   print("Warming up...")
   run_benchmark(num_iters=3, profile=False)
@@ -217,14 +215,18 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
   else:
     latency = run_benchmark(num_iters=10, profile=False)
   print(f"Kernel running time: {latency * 1000000:.3f} us")
 
+
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
-  parser.add_argument("--kernel",
-                      type=str,
-                      choices=["ragged-paged-attention","ragged-paged-attention-ref-impl",],
-                      default="multi-queries-paged-attn")
+  parser.add_argument(
+      "--kernel",
+      type=str,
+      choices=[
+          "ragged-paged-attention",
+          "ragged-paged-attention-ref-impl",
+      ],
+      default="multi-queries-paged-attn")
   parser.add_argument("--profile", action="store_true")
   args = parser.parse_args()
   benchmark(args)
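
For readers skimming the diff, run_benchmark follows the standard warmup-then-measure pattern for JAX: compile during a few warmup iterations, then time steady-state runs and block on the result so async dispatch does not stop the clock early. A self-contained sketch of that harness (dummy_kernel and the sizes are placeholders, not the real attention kernel):

import time

import jax
import jax.numpy as jnp


def dummy_kernel(x):
  # Placeholder for the attention kernel being benchmarked.
  return jnp.sin(x).sum()


def run_benchmark(num_iters: int) -> float:
  x = jnp.ones((1024, 1024))
  out = None
  start = time.perf_counter()
  for _ in range(num_iters):
    out = dummy_kernel(x)
  jax.block_until_ready(out)  # wait for async dispatch before stopping the clock
  return (time.perf_counter() - start) / num_iters


run_benchmark(num_iters=3)  # warmup: compilation happens outside the timed run
latency = run_benchmark(num_iters=10)
print(f"Kernel running time: {latency * 1e6:.3f} us")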
11 changes: 7 additions & 4 deletions test/test_ragged_paged_attention_kernel.py
@@ -125,7 +125,8 @@ def _verify_ragged_paged_attention(
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
     num_kv_pages_per_block = 128
-    max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(
+        max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -230,7 +231,8 @@ def test_paged_attention_varlen_comprehensive(
       num_pages: int,
   ):
     if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
-      self.skipTest("TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
+      self.skipTest(
+          "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
     self._verify_ragged_paged_attention(
         seq_lens,
         num_heads,
@@ -368,8 +370,9 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
     max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    num_kv_pages_per_block=128
-    max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
+    num_kv_pages_per_block = 128
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(
+        max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
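
The helper _round_up_closest_multiple_of is not itself part of this diff; a plausible implementation, mirroring the _round_up_to_multiple_of_tm helper that appears in the kernel file below, would be:

def _round_up_closest_multiple_of(x: int, multiple: int) -> int:
  # Assumed behavior (the helper's body is not shown in this diff):
  # round x up to the nearest multiple of `multiple`.
  return (x + multiple - 1) // multiple * multiple

assert _round_up_closest_multiple_of(83, 128) == 128
assert _round_up_closest_multiple_of(256, 128) == 256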
Third changed file: the ragged paged attention Pallas kernel.
@@ -522,6 +522,7 @@ def store_to_output(): # pylint: disable=unused-variable
     m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype(
         m_ref.dtype)
 
+
 # grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks)
 def _compute_next_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx,
                                 num_logical_q_blks, kv_blk_size, seq_ids,
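
For orientation: per the comment above, the kernel walks a (num_kv_heads, num_logical_q_tiles, num_kv_blks) grid with the KV block index innermost, and _compute_next_block_indices returns the coordinates of the following grid step so its KV copy can be issued one step ahead. A simplified pure-Python version of that wraparound arithmetic (it ignores the per-sequence early exit on kv_len that the real helper applies via advance_logical_q_blk_idx):

def compute_next_block_indices(kv_head_idx, q_blk_idx, kv_blk_idx,
                               num_q_blks, num_kv_blks):
  # Advance (kv_head, q_tile, kv_blk) with kv_blk innermost; the caller
  # detects the end of the grid by next_kv_head_idx == num_kv_heads.
  kv_blk_idx += 1
  if kv_blk_idx == num_kv_blks:
    kv_blk_idx = 0
    q_blk_idx += 1
    if q_blk_idx == num_q_blks:
      q_blk_idx = 0
      kv_head_idx += 1
  return kv_head_idx, q_blk_idx, kv_blk_idx

print(compute_next_block_indices(0, 1, 3, num_q_blks=2, num_kv_blks=4))
# -> (1, 0, 0): both inner dimensions wrapped, so move to the next kv head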
@@ -554,6 +555,7 @@ def advance_logical_q_blk_idx():
       advance_logical_q_blk_idx,
   )
 
+
 # grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks)
 def paged_flash_attention_kernel(
     # prefetch refs
@@ -646,24 +648,29 @@ def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx,
   @pl.when(step == 0)
   def prefetch_first_block():  # pylint: disable=unused-variable
     async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
-        cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index, cur_page_indices_ref)
+        cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index,
+        cur_page_indices_ref)
     async_copy_k.start()
     async_copy_v.start()
 
-  next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = _compute_next_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx, num_logical_q_blks, kv_blk_size, seq_ids, effective_kv_lens_ref)
+  next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = _compute_next_block_indices(
+      kv_head_idx, logical_q_blk_idx, kv_blk_idx, num_logical_q_blks,
+      kv_blk_size, seq_ids, effective_kv_lens_ref)
 
   @pl.when(next_kv_head_idx < num_kv_heads)
   def prefetch_next_block():  # pylint: disable=unused-variable
     next_buffer_index = jnp.where(buffer_index == 0, 1, 0)
     next_seq_idx = seq_ids[next_logical_q_blk_idx]
     async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors(
-        next_seq_idx, next_kv_head_idx, next_kv_blk_idx, next_buffer_index, next_page_indices_ref)
+        next_seq_idx, next_kv_head_idx, next_kv_blk_idx, next_buffer_index,
+        next_page_indices_ref)
     async_copy_next_k.start()
     async_copy_next_v.start()
     buffer_index_ref[0] = next_buffer_index
 
   async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
-      cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index, cur_page_indices_ref)
+      cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index,
+      cur_page_indices_ref)
   k = async_copy_k.wait_and_get_loaded(
   )  # [pages_per_compute_block*page_size,head_dim]
   v = async_copy_v.wait_and_get_loaded()
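
This hunk is the kernel's double buffering: while compute consumes the KV block in one VMEM buffer, the copy for the next grid step is already in flight into the other, and buffer_index_ref flips between the two. A host-side Python sketch of the same overlap pattern (threads and sleeps stand in for DMA; none of this is Pallas API):

import concurrent.futures as cf
import time


def dma_copy(block_id):
  # Stand-in for an async HBM-to-VMEM copy of one KV block.
  time.sleep(0.01)
  return f"kv_block_{block_id}"


def consume(block):
  # Stand-in for the attention compute on one KV block.
  time.sleep(0.01)
  return len(block)


num_blocks = 8
buffers = [None, None]  # the two alternating buffers
with cf.ThreadPoolExecutor(max_workers=1) as pool:
  buffer_index = 0
  # prefetch_first_block: issue the copy for block 0 before the loop.
  buffers[buffer_index] = pool.submit(dma_copy, 0)
  for blk in range(num_blocks):
    if blk + 1 < num_blocks:
      # prefetch_next_block: start the next copy into the *other* buffer.
      next_buffer_index = 1 - buffer_index
      buffers[next_buffer_index] = pool.submit(dma_copy, blk + 1)
    block = buffers[buffer_index].result()  # wait_and_get_loaded
    consume(block)  # compute overlaps the copy still in flight
    buffer_index = 1 - buffer_index  # flip, as buffer_index_ref[0] does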
@@ -704,6 +711,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable
 def _round_up_to_multiple_of_tm(x, tm):
   return (x + tm - 1) // tm * tm
 
+
 MIN_BLOCK_SIZE = 128
 
 
@@ -826,27 +834,35 @@ def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx,
   )
   # Note page_indices.shape=[num_tokens, pages_per_sequence], pages_per_sequence % num_kv_pages_per_block==0
   # Unsqueeze an extra dimension in page_indices so that num_tokens can avoid the 2nd last dimension having to be a multiple of 8.
-  expanded_page_indices = jnp.expand_dims(page_indices, 1) # [num_tokens, 1, pages_per_sequence]
+  expanded_page_indices = jnp.expand_dims(
+      page_indices, 1)  # [num_tokens, 1, pages_per_sequence]
 
   def cur_page_indices_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx,
                                  sequence_metadata, *_):
     seq_ids, physical_q_tile_ids = sequence_metadata
     del physical_q_tile_ids
     seq_id = seq_ids[logical_q_blk_idx]
     return (seq_id, 0, kv_blk_idx)
 
   cur_page_indices_spec = pl.BlockSpec(
       (None, None, num_kv_pages_per_block),
       cur_page_indices_index_map,
       memory_space=pltpu.TPUMemorySpace.SMEM,
   )
   page_size = k_pages.shape[2]
   kv_blk_size = page_size * num_kv_pages_per_block
-  def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx,
-                                         sequence_metadata, num_logical_q_tiles_1d, kv_lens, *_):
+
+  def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx,
+                                         kv_blk_idx, sequence_metadata,
+                                         num_logical_q_tiles_1d, kv_lens, *_):
     seq_ids, physical_q_tile_ids = sequence_metadata
-    next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = _compute_next_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx, num_logical_q_tiles_1d[0], kv_blk_size, seq_ids, kv_lens)
+    next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = _compute_next_block_indices(
+        kv_head_idx, logical_q_blk_idx, kv_blk_idx, num_logical_q_tiles_1d[0],
+        kv_blk_size, seq_ids, kv_lens)
     del physical_q_tile_ids
     next_seq_id = seq_ids[next_logical_q_blk_idx]
     return (next_seq_id, 0, next_kv_blk_idx)
 
   next_page_indices_spec = pl.BlockSpec(
       (None, None, num_kv_pages_per_block),
       next_kv_blk_page_indices_index_map,
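
The unsqueeze in this last hunk is a layout workaround: per the comment, the second-to-last dimension of a block is subject to a multiple-of-8 constraint, so inserting a singleton axis moves num_tokens out of that position. A shape-only illustration (the sizes are made up):

import jax.numpy as jnp

num_tokens, pages_per_sequence = 7, 256  # note 7 is not a multiple of 8
page_indices = jnp.zeros((num_tokens, pages_per_sequence), dtype=jnp.int32)

# After unsqueezing, the singleton axis occupies the second-to-last slot,
# so num_tokens no longer has to satisfy the multiple-of-8 tiling rule.
expanded = jnp.expand_dims(page_indices, 1)
assert expanded.shape == (7, 1, 256)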
