diff --git a/test/test_pallas.py b/test/test_pallas.py
index 0ca0b2905a4..d49df491dc0 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -567,7 +567,10 @@ def test_paged_attention_multi_queries_wrapper(self):
     max_kv_len = 2048
     query_len = 64
-    kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
+    batch_size = 3
+    kv_seq_lens = torch.randint(
+        query_len, max_kv_len, (batch_size,), dtype=torch.int32)
+    effective_q_lens = torch.full((batch_size,), query_len, dtype=torch.int32)
     assert query_len <= max_kv_len
     for cur_kv_seq in kv_seq_lens:
       assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
@@ -592,6 +595,7 @@ def test_paged_attention_multi_queries_wrapper(self):
     v_pages_xla = v_pages.to("xla")
     kv_seq_lens_xla = kv_seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
+    effective_q_lens_xla = effective_q_lens.to("xla")
 
     output = multi_queries_paged_attention(
         q_xla,
@@ -599,6 +603,7 @@ def test_paged_attention_multi_queries_wrapper(self):
         v_pages_xla,
         kv_seq_lens_xla,
         page_indices_xla,
+        effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
     )
@@ -609,6 +614,7 @@ def test_paged_attention_multi_queries_wrapper(self):
         v_pages_xla,
         kv_seq_lens_xla,
         page_indices_xla,
+        effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
         use_kernel=False,
@@ -619,6 +625,7 @@ def test_paged_attention_multi_queries_wrapper(self):
     v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
     kv_seq_lens_jax = jnp.array(kv_seq_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
     expected_output = torch.from_numpy(
         np.array(
             jax_multi_queries_paged_attention(
@@ -627,6 +634,7 @@ def test_paged_attention_multi_queries_wrapper(self):
                 v_pages_jax,
                 kv_seq_lens_jax,
                 page_indices_jax,
+                effective_q_lens_jax,
                 num_kv_pages_per_compute_block=block_kv_size // page_size,
                 num_queries_per_compute_block=num_queries_per_compute_block,
             )))
@@ -654,7 +662,10 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
     max_kv_len = 2048
     query_len = 64
-    kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
+    batch_size = 3
+    kv_seq_lens = torch.randint(
+        query_len, max_kv_len, (batch_size,), dtype=torch.int32)
+    effective_q_lens = torch.full((batch_size,), query_len, dtype=torch.int32)
     assert query_len <= max_kv_len
     for cur_kv_seq in kv_seq_lens:
       assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
@@ -679,9 +690,10 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
     v_pages_xla = v_pages.to("xla")
     kv_seq_lens_xla = kv_seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
+    effective_q_lens_xla = effective_q_lens.to("xla")
 
     def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
-                                              page_indices,
+                                              page_indices, effective_q_lens,
                                               num_kv_pages_per_compute_block,
                                               num_queries_per_compute_block,
                                               use_kernel):
@@ -691,6 +703,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
           v_pages,
           kv_seq_lens,
           page_indices,
+          effective_q_lens,
           num_kv_pages_per_compute_block,
           num_queries_per_compute_block,
           use_kernel=use_kernel,
@@ -705,6 +718,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
         v_pages_xla,
         kv_seq_lens_xla,
         page_indices_xla,
+        effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
         use_kernel=True,
@@ -716,6 +730,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
         v_pages_xla,
         kv_seq_lens_xla,
         page_indices_xla,
+        effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
         use_kernel=False,
diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index f03763ded0b..80023652d69 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -43,8 +43,9 @@ def _ref_jax_extended_paged_attention(
     q,  # [batch_size, query_len, num_query_heads, head_size]
     k_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
     v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
-    lengths,  # [batch_size]
+    lengths,  # [batch_size], the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
+    effective_q_lens,  # [batch_size], the effective q_length.
 ):
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -72,7 +73,8 @@ def _ref_jax_extended_paged_attention(
     attn = jnp.einsum("qhd,khd->hqk", q[i], k)
     attn = attn.astype('float32')
-    q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
+    effective_q_len = effective_q_lens[i]
+    q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
         jnp.int32, (query_len, kv_len), 0)
     kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
     mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
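
Note (reviewer sketch, not part of the patch): the reference-impl change above computes the causal mask from the per-sequence effective query length rather than the (possibly padded) query_len. A minimal standalone illustration with made-up toy sizes:

import jax
import jax.numpy as jnp

query_len = 4        # padded query length
effective_q_len = 2  # number of real (non-padding) query tokens
kv_len = 5           # effective kv length for this sequence

# Row i of the score matrix corresponds to query token i; its position on the
# kv timeline is (kv_len - effective_q_len) + i, so the mask offset must use
# the effective query length, not the padded one.
q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
    jnp.int32, (query_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
print(mask)  # rows beyond effective_q_len are padding and get sliced off later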
@@ -91,17 +93,16 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
 
   def setUp(self):
     super().setUp()
-
-# def test_paged_attention(
-#     self,
-# ):
-#   dtype = jnp.bfloat16
-#   page_size=16
-#   num_kv_heads = 8
-#   q_kv_head_ratio = 4
-#   head_dim = 256
-#   num_queries_per_compute_block = 32
-#   block_kv_size = 256
+  # def test_paged_attention(
+  #     self,
+  # ):
+  #   dtype = jnp.bfloat16
+  #   page_size=16
+  #   num_kv_heads = 8
+  #   q_kv_head_ratio = 4
+  #   head_dim = 256
+  #   num_queries_per_compute_block = 32
+  #   block_kv_size = 256
 
   @parameterized.product(
       dtype=(jnp.float32, jnp.bfloat16),
@@ -112,7 +113,7 @@ def setUp(self):
       num_queries_per_compute_block=(16, 32),
       block_kv_size=(128, 256),
   )
-  def test_paged_attention(
+  def test_paged_attention_without_query_padding(
       self,
       dtype,
       page_size,
@@ -125,13 +126,13 @@ def test_paged_attention(
 
     max_kv_len = 2048
     query_len = 64
+    batch_size = 3
     kv_seq_lens = jax.random.randint(
-        jax.random.key(0), (3,), query_len, max_kv_len)
+        jax.random.key(0), (batch_size,), query_len, max_kv_len)
     assert query_len <= max_kv_len
     for cur_kv_seq in kv_seq_lens:
       assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
 
-    batch_size = len(kv_seq_lens)
     pages_per_sequence = max_kv_len // page_size
     total_num_pages = batch_size * pages_per_sequence
     assert max_kv_len <= total_num_pages * page_size
@@ -150,12 +151,14 @@ def test_paged_attention(
 
     print(f'Running paged_attention with {query_len=}')
     num_kv_pages_per_compute_block = block_kv_size // page_size
+    effective_q_lens = jnp.full_like(kv_seq_lens, query_len)
     actual_output = paged_attention(
         q,
         k_pages,
         v_pages,
         kv_seq_lens,
         page_indices,
+        effective_q_lens,
         num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
         num_queries_per_compute_block=num_queries_per_compute_block,
     )
@@ -168,6 +171,7 @@ def test_paged_attention(
         v_pages,
         kv_seq_lens,
         page_indices,
+        effective_q_lens,
     )
 
     self.assertEqual(actual_output.shape, expected_output.shape)
@@ -183,5 +187,110 @@ def test_paged_attention(
     self.assertTrue(
         jnp.allclose(expected_output, actual_output, atol=atol, rtol=rtol))
 
+  # def test_paged_attention_query_len_longer_than_kv_seq_len(
+  #     self,
+  # ):
+  #   dtype = jnp.float32
+  #   page_size=16
+  #   num_kv_heads = 8
+  #   q_kv_head_ratio = 4
+  #   head_dim = 256
+  #   num_queries_per_compute_block = 32
+  #   block_kv_size = 256
+  # In practice, vLLM pads the query so that the query seq len can be longer than the kv seq len: the query seq len may be padded, but the kv seq len is not.
+  # When this happens, we need the additional parameter `effective_q_lens` in paged_attention to set the causal mask correctly.
+  @parameterized.product(
+      dtype=(jnp.float32, jnp.bfloat16),
+      page_size=(16, 32, 64),
+      num_kv_heads=(1, 8),
+      q_kv_head_ratio=(1, 4, 8),
+      head_dim=(128, 256),
+      num_queries_per_compute_block=(16, 32),
+      block_kv_size=(128, 256),
+  )
+  def test_paged_attention_with_query_padding(
+      self,
+      dtype,
+      page_size,
+      num_kv_heads,
+      q_kv_head_ratio,
+      head_dim,
+      num_queries_per_compute_block,
+      block_kv_size,
+  ):
+
+    max_kv_len = 2048
+    # Set query_len > kv_seq_lens.
+    query_len = max_kv_len
+    batch_size = 3
+    kv_seq_lens = jax.random.randint(
+        jax.random.key(0), (batch_size,), 0, max_kv_len)
+    effective_q_lens = jax.random.randint(
+        jax.random.key(0), (batch_size,), 0, kv_seq_lens)
+    for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
+      assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'
+
+    pages_per_sequence = max_kv_len // page_size
+    total_num_pages = batch_size * pages_per_sequence
+    assert max_kv_len <= total_num_pages * page_size
+    q, k_pages, v_pages, page_indices = _generate_qkv(
+        kv_seq_lens,
+        page_size,
+        max_kv_len,
+        query_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+        jax.random.key(0),
+        dtype,
+    )
+
+    print(
+        f'Running paged_attention with {query_len=}, {kv_seq_lens=}, {effective_q_lens=}'
+    )
+    num_kv_pages_per_compute_block = block_kv_size // page_size
+    actual_output = paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lens,
+        page_indices,
+        effective_q_lens,
+        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+    )
+    # actual_output = jax.block_until_ready(actual_output)
+
+    # Run the ref impl.
+    expected_output = _ref_jax_extended_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lens,
+        page_indices,
+        effective_q_lens,
+    )
+
+    self.assertEqual(actual_output.shape, expected_output.shape)
+
+    if dtype == jnp.float32:
+      atol = 2e-2
+      rtol = 1e-2
+    elif dtype == jnp.bfloat16:
+      atol = 6e-1
+      rtol = 1e-1
+    else:
+      self.fail(f'Unsupported dtype: {dtype}')
+    for b in range(batch_size):
+      # N.B. Along the query_len dim of the output ([batch_size, query_len, num_q_heads, head_dim]), all values past effective_q_len are thrown away because the query seq len is padded. Those values may differ between the kernel and the ref impl because of the causal mask.
+      effective_q_len = effective_q_lens[b]
+      self.assertTrue(
+          jnp.allclose(
+              expected_output[b, :effective_q_len],
+              actual_output[b, :effective_q_len],
+              atol=atol,
+              rtol=rtol))
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index d840a10b55b..fdc5992c3b0 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -487,8 +487,9 @@ def _multi_queries_paged_attention_nonkernel(
     q,  # [batch_size, query_len, num_heads, head_size]
     k_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
     v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
-    lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
+    lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
+    effective_q_lens,  # [batch_size], the effective q_length
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -528,7 +529,8 @@ def _multi_queries_paged_attention_nonkernel(
                        k)  # [num_query_heads, query_len, kv_len]
     attn = attn.float()
     empty_mask = torch.ones(query_len, kv_len, device=attn.device)
-    mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
+    effective_q_len = effective_q_lens[i]
+    mask = torch.triu(empty_mask, diagonal=kv_len - effective_q_len + 1).bool()
     attn.masked_fill_(mask, float("-inf"))
     attn = torch.softmax(
         attn, dim=-1).to(v.dtype)  # [num_query_heads, query_len, kv_len]
@@ -547,6 +549,7 @@ def multi_queries_paged_attention(
     v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
     page_indices,  # [batch_size, pages_per_sequence]
+    effective_q_lens,  # [batch_size]
     num_kv_pages_per_compute_block,
     num_queries_per_compute_block,
     use_kernel=True,
@@ -559,6 +562,7 @@ def multi_queries_paged_attention(
         v_pages,
         lengths,
         page_indices,
+        effective_q_lens,
     )
 
   # Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -572,6 +576,7 @@ def multi_queries_paged_attention(
       v_pages,
       lengths,
       page_indices,
+      effective_q_lens,
       num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
       num_queries_per_compute_block=num_queries_per_compute_block,
       static_argnames=[
@@ -592,6 +597,7 @@ def multi_queries_paged_attention(
       [
           lengths,
           page_indices_reshaped,
+          effective_q_lens,
           buffer_index,
           step,
           q.to(q_dtype_for_kernel_launch),
@@ -1081,7 +1087,7 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 XLA_LIB.define(
-    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
+    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
 )
@@ -1089,10 +1095,10 @@ def paged_attention_non_xla(q: torch.Tensor,
 def multi_queries_paged_attention_xla(
     q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
     lengths: torch.Tensor, page_indices: torch.Tensor,
-    num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
-    use_kernel: bool):
+    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int, use_kernel: bool):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
-                                       page_indices,
+                                       page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
                                        num_queries_per_compute_block,
                                        use_kernel)
@@ -1102,8 +1108,8 @@ def multi_queries_paged_attention_xla(
 def multi_queries_paged_attention_non_xla(
     q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
     lengths: torch.Tensor, page_indices: torch.Tensor,
-    num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
-    use_kernel: bool):
+    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
+    num_queries_per_compute_block: int, use_kernel: bool):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
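
Note (reviewer sketch, not part of the patch): the non-kernel path above builds the same causal mask with torch.triu, offset by kv_len - effective_q_len + 1. A minimal standalone illustration with made-up toy sizes, producing the same mask pattern as the JAX reference snippet earlier:

import torch

query_len = 4        # padded query length
effective_q_len = 2  # number of real (non-padding) query tokens
kv_len = 5           # effective kv length for this sequence

empty_mask = torch.ones(query_len, kv_len)
# Entries at or beyond diagonal offset (kv_len - effective_q_len + 1) are
# future kv positions relative to each query token and get masked out.
mask = torch.triu(empty_mask, diagonal=kv_len - effective_q_len + 1).bool()
print(mask)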
diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index ae2e352c34f..92724b05bfc 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -97,6 +97,7 @@ def wait_and_get_loaded(self) -> jax.Array:
 def _flash_attention(
     q_head_idx_per_kv,  # scalar, ranges from 0 to num_query_heads_per_kv_head
     lengths_ref,  # [batch_size] jax.Array the length of each example
+    effective_q_lens_ref,  # [batch_size] jax.Array the effective query length of each example
     # input
     q_ref,  # [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
     k,  # [pages_per_compute_block*page_size,head_dim]
@@ -148,7 +149,8 @@ def start_new_sequence():
   q_index = q_blk_idx * num_queries_per_compute_block
   kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
   kv_len = lengths_ref[b]
-  row_ids = (kv_len - query_len) + q_index + jax.lax.broadcasted_iota(
+  effective_q_len = effective_q_lens_ref[b]
+  row_ids = (kv_len - effective_q_len) + q_index + jax.lax.broadcasted_iota(
       jnp.int32,
       (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk), 0)
   col_ids = kv_index + jax.lax.broadcasted_iota(
@@ -209,6 +211,7 @@ def paged_flash_attention_kernel(
     lengths_ref,  # [batch_size] jax.Array the length of each example
     # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[batch_size, pages_per_sequence]
     page_indices_ref,
+    effective_q_lens_ref,  # [batch_size] jax.Array the effective query length of each example
     buffer_index_ref,
     step_ref,
     # At caller, q.shape=[batch_size, num_q_heads query_len, head_dim]
@@ -355,6 +358,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
           _flash_attention(
               q_head_idx,
               lengths_ref,
+              effective_q_lens_ref,
               q_ref,  # [1, num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
               k,
               v,
@@ -394,6 +398,7 @@ def paged_attention(
     v_pages: jax.Array | quantization_utils.QuantizedTensor,
     lengths: jax.Array,
     page_indices: jax.Array,
+    effective_q_lens: jax.Array,
     *,
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_compute_block: int,
@@ -405,10 +410,11 @@ def paged_attention(
     q: A [batch_size, query_len, num_q_heads, head_dim] jax.Array.
     k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
-    lengths: A i32[batch_size] jax.Array the length of each example.
+    lengths: A i32[batch_size] jax.Array the effective kv length of each example.
     page_indices: A i32[batch_size, pages_per_sequence] jax.Array. Each entry
       should be in the range of [0, total_num_pages), indicating where to locate
       the page in `k_pages` or `v_pages`.
+    effective_q_lens: A i32[batch_size] jax.Array the effective query length of each example.
     mask_value: The value used for padding in attention. By default it is a very
       negative floating point number.
     num_kv_pages_per_compute_block: how many kv pages to be processed in one flash
@@ -448,6 +454,9 @@ def paged_attention(
                      f" {head_dim} and {head_dim_k}.")
   if lengths.shape != (batch_size,):
     raise ValueError("`lengths` and `q` must have the same batch size")
+  if lengths.shape != effective_q_lens.shape:
+    raise ValueError(
+        "`lengths` and `effective_q_lens` must have the same size: batch_size")
   if batch_size_paged_indices != batch_size:
     raise ValueError("`page_indices` and `q` must have the same batch size")
   if lengths.dtype != jnp.int32:
@@ -576,7 +585,7 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
               mask_value=mask_value,
               query_len=query_len),
           grid_spec=pltpu.PrefetchScalarGridSpec(
-              num_scalar_prefetch=4,
+              num_scalar_prefetch=5,
               in_specs=in_specs,
              out_specs=out_specs,
               grid=grid,
@@ -598,6 +607,7 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
   outs = kernel(
       lengths,
       page_indices_1d,
+      effective_q_lens,
       buffer_index,
       step,
       q.astype(q_dtype_for_kernel_launch),
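
Note (reviewer sketch, not part of the patch): with this change, callers of the torch_xla wrapper pass the extra effective_q_lens tensor alongside lengths and page_indices. A minimal usage sketch mirroring the tests above; all shapes, block sizes, and random inputs below are illustrative only:

import torch
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

batch_size, query_len, num_q_heads, head_dim = 3, 64, 32, 128
num_kv_heads, page_size, max_kv_len = 8, 16, 2048
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence

q = torch.randn(batch_size, query_len, num_q_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_dim).to("xla")
kv_seq_lens = torch.randint(
    query_len, max_kv_len, (batch_size,), dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_num_pages, (batch_size, pages_per_sequence),
    dtype=torch.int32).to("xla")
# The effective (unpadded) query length per sequence; here the query is unpadded,
# so it equals query_len for every sequence.
effective_q_lens = torch.full(
    (batch_size,), query_len, dtype=torch.int32).to("xla")

output = multi_queries_paged_attention(
    q,
    k_pages,
    v_pages,
    kv_seq_lens,
    page_indices,
    effective_q_lens,
    num_kv_pages_per_compute_block=256 // page_size,
    num_queries_per_compute_block=32,
)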