Skip to content

Commit 32c8784

Browse files
cjx0709JamesBrianDpengchengneo
authored
perf: optimize get_forward_metadata for SWA models (#930)
perf: optimize get_forward_metadata Co-authored-by: Brian <donghouze666@outlook.com> Co-authored-by: pc-new <pengchengneo@gmail.com>
1 parent 55a8521 commit 32c8784

1 file changed

Lines changed: 31 additions & 60 deletions

File tree

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 31 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,19 @@ def get_forward_metadata(
108108
"""Return the metadata for a forward pass."""
109109
metadata = FlashAttentionMetadata()
110110

111+
# Stride by page_size to pick one slot per page — O(N/page_size) instead of O(N)
111112
indices = np.arange(0, len(batch.cache_loc), self.page_size)
112-
selected_cache_locs = batch.cache_loc[indices]
113-
page_indices = (selected_cache_locs // self.page_size).astype(np.int32)
113+
page_indices = (batch.cache_loc[indices] // self.page_size).astype(np.int32)
114+
115+
# SWA page indices: apply mapping on ~N_pages entries
116+
# instead of ~N_tokens entries (page_size x fewer random accesses)
117+
swa_page_indices = None
118+
swa_mapping = getattr(self, "swa_index_mapping", None)
119+
if swa_mapping is not None:
120+
swa_slots = swa_mapping[batch.cache_loc[indices]]
121+
swa_page_indices = (swa_slots // self.page_size).astype(np.int32)
114122

123+
# cu_q_lens
115124
if batch.forward_mode == ForwardMode.EXTEND:
116125
cu_q_lens = np.concatenate(
117126
[
@@ -120,82 +129,44 @@ def get_forward_metadata(
120129
]
121130
)
122131
elif batch.forward_mode == ForwardMode.DECODE:
123-
cu_q_lens = np.concatenate(
124-
[
125-
np.array([0], dtype=np.int32),
126-
np.cumsum(np.ones(len(batch.seq_lens), dtype=np.int32)),
127-
]
128-
)
132+
cu_q_lens = np.arange(len(batch.seq_lens) + 1, dtype=np.int32)
129133
else:
130134
raise ValueError(f"Invalid forward mode: {batch.forward_mode}")
131135

132-
seq_lens = np.copy(batch.seq_lens)
136+
seq_lens = batch.seq_lens
133137

134138
aligned_seq_lens = (
135139
(batch.seq_lens + self.page_size - 1) // self.page_size
136140
) * self.page_size
141+
142+
# cu_kv_lens
137143
cu_kv_lens = np.concatenate(
138144
[
139145
np.array([0], dtype=np.int32),
140-
np.cumsum(aligned_seq_lens),
146+
np.cumsum(aligned_seq_lens, dtype=np.int32),
141147
]
142148
)
143149

144-
num_seqs = np.sum(batch.seq_lens > 0, dtype=np.int32).reshape(
145-
1,
146-
)
147-
148-
# Construct distribution for V2 kernel: [decode_end, prefill_end, mixed_end]
150+
# distribution for V2 kernel: [decode_end, prefill_end, mixed_end]
151+
num_seqs = np.sum(batch.seq_lens > 0, dtype=np.int32)
149152
if batch.forward_mode == ForwardMode.DECODE:
150-
# All sequences are decode/mixed mode
151-
distribution = np.array([0, 0, num_seqs.item()], dtype=np.int32)
153+
distribution = np.array([0, 0, num_seqs], dtype=np.int32)
152154
elif batch.forward_mode == ForwardMode.EXTEND:
153-
# All sequences are prefill mode
154-
distribution = np.array([0, num_seqs.item(), num_seqs.item()], dtype=np.int32)
155+
distribution = np.array([0, num_seqs, num_seqs], dtype=np.int32)
155156
else:
156157
raise ValueError(f"Invalid forward mode: {batch.forward_mode}")
157158

158-
# Compute swa_page_indices if SWA index mapping is available
159-
swa_page_indices = None
160-
if self.swa_index_mapping is not None:
161-
swa_cache_loc = self.swa_index_mapping[batch.cache_loc]
162-
swa_indices = np.arange(0, len(swa_cache_loc), self.page_size)
163-
swa_selected = swa_cache_loc[swa_indices]
164-
swa_page_indices = (swa_selected // self.page_size).astype(np.int32)
165-
166-
if swa_page_indices is not None:
167-
(
168-
metadata.num_seqs,
169-
metadata.cu_q_lens,
170-
metadata.cu_kv_lens,
171-
metadata.page_indices,
172-
metadata.seq_lens,
173-
metadata.distribution,
174-
metadata.swa_page_indices,
175-
) = device_array(
176-
(
177-
num_seqs,
178-
cu_q_lens,
179-
cu_kv_lens,
180-
page_indices,
181-
seq_lens,
182-
distribution,
183-
swa_page_indices,
184-
),
185-
sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
186-
)
187-
else:
188-
(
189-
metadata.num_seqs,
190-
metadata.cu_q_lens,
191-
metadata.cu_kv_lens,
192-
metadata.page_indices,
193-
metadata.seq_lens,
194-
metadata.distribution,
195-
) = device_array(
196-
(num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution),
197-
sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
198-
)
159+
(
160+
metadata.cu_q_lens,
161+
metadata.cu_kv_lens,
162+
metadata.page_indices,
163+
metadata.swa_page_indices,
164+
metadata.seq_lens,
165+
metadata.distribution,
166+
) = device_array(
167+
(cu_q_lens, cu_kv_lens, page_indices, swa_page_indices, seq_lens, distribution),
168+
sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
169+
)
199170
return metadata
200171

201172
def get_eagle_forward_metadata(self, batch: ModelWorkerBatch):

0 commit comments

Comments
 (0)