44 changes: 31 additions & 13 deletions atom/plugin/attention.py
@@ -21,7 +21,6 @@
@dataclass
class AiterFlashAttentionPhaseMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor

@@ -58,7 +57,6 @@ class AiterChunkContextMetadata:
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
chunk_context_metadata: AiterChunkContextMetadata
@@ -300,17 +298,31 @@ def build(
num_extend_tokens,
num_prefill_tokens,
) = split_ret
prefill_only = num_decodes == 0 and num_extends == 0 and num_prefills > 0
decode_only = num_decodes > 0 and num_extends == 0 and num_prefills == 0
mixed_request = not (prefill_only or decode_only)

query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
if mixed_request:
seq_lens = common_attn_metadata.seq_lens.cpu()

Collaborator

Doesn't vllm have something like seq_lens_cpu?

Contributor Author

vllm deprecated seq_lens_cpu in common_metadata a while ago.

query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
else:
seq_lens = None
query_lens_cpu = None

decode_metadata = None
if num_decodes > 0:
decode_metadata = AiterFlashAttentionDecodeMetadata(
max_query_len=query_lens_cpu[:num_decodes].max().item(),
min_query_len=query_lens_cpu[:num_decodes].min().item(),
max_seq_len=seq_lens[:num_decodes].max().item(),
max_query_len=(
common_attn_metadata.max_query_len
if decode_only
else query_lens_cpu[:num_decodes].max().item()
),
max_seq_len=(
common_attn_metadata.max_seq_len
if decode_only
else seq_lens[:num_decodes].max().item()
),
query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
)
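
For intuition, here is a minimal, self-contained sketch of the query_start_loc arithmetic the builder relies on above; the names follow the diff, but the token counts are made up for illustration.

import torch

# Hypothetical mixed batch: two decodes (1 token each) and one
# prefill (5 tokens); query_start_loc holds cumulative token offsets.
query_start_loc_cpu = torch.tensor([0, 1, 2, 7])
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
assert query_lens_cpu.tolist() == [1, 1, 5]

# The decode sub-batch is the first num_decodes requests, so its
# max query length is computed from that slice alone.
num_decodes = 2
assert query_lens_cpu[:num_decodes].max().item() == 1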

@@ -435,26 +447,32 @@ def build(
)
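# num_extends > 0 rules out both prefill_only and decode_only, so
# mixed_request is True and seq_lens / query_lens_cpu are CPU tensors here.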
extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
max_query_len=query_lens_for_extend.max().item(),
min_query_len=query_lens_for_extend.min().item(),
max_seq_len=seq_lens[num_extends_slice].max().item(),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
chunk_context_metadata=chunk_context_metadata,
)

prefill_metadata = None
if num_prefills > 0:
query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
query_start_loc_device = common_attn_metadata.query_start_loc[
num_decodes + num_extends :
]
prefill_metadata = AiterFlashAttentionPrefillMetadata(
max_query_len=query_lens_for_prefill.max().item(),
min_query_len=query_lens_for_prefill.min().item(),
max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
max_query_len=(
common_attn_metadata.max_query_len
if prefill_only
else query_lens_cpu[num_decodes + num_extends :].max().item()
),
max_seq_len=(
common_attn_metadata.max_seq_len
if prefill_only
else query_lens_cpu[num_decodes + num_extends :].max().item()

Copilot AI Mar 11, 2026

In the prefill metadata construction, the non-prefill-only branch sets max_seq_len using query_lens_cpu[...] instead of sequence lengths. This reports max_seq_len as the max query length for mixed batches (decodes/extends + prefills), which can lead to incorrect kernel configuration and bounds. Use the corresponding seq_lens[...] slice (or another true seq-lens source) for max_seq_len here.

Suggested change
- else query_lens_cpu[num_decodes + num_extends :].max().item()
+ else seq_lens[num_decodes + num_extends :].max().item()
),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
)
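
To make the reviewer's point above concrete, a small sketch with made-up lengths: a prefill request's seq_len includes its cached context, so it can far exceed its query_len.

import torch

# Hypothetical mixed batch: one decode, plus one chunked prefill with
# 100 tokens already in the KV cache and 8 new query tokens.
query_lens_cpu = torch.tensor([1, 8])
seq_lens = torch.tensor([50, 108])
num_decodes, num_extends = 1, 0
# Buggy branch: yields 8, the max *query* length of the prefill slice.
wrong = query_lens_cpu[num_decodes + num_extends :].max().item()
# Reviewer's suggestion: yields 108, the true max sequence length.
right = seq_lens[num_decodes + num_extends :].max().item()
assert (wrong, right) == (8, 108)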

num_actual_kv_tokens = torch.sum(seq_lens).item()
# num_actual_kv_tokens = torch.sum(seq_lens).item()
num_actual_kv_tokens = 0
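
Since seq_lens is now None for homogeneous batches, the original sum would raise there; a hypothetical guard (an assumption, not part of this diff) could preserve the count whenever the CPU tensor exists:

# Hypothetical alternative, not in this PR: fall back to 0 only when
# the per-request CPU lengths were skipped for a homogeneous batch.
num_actual_kv_tokens = (
    int(seq_lens.sum().item()) if seq_lens is not None else 0
)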

use_cascade = False
