@@ -108,10 +108,19 @@ def get_forward_metadata(
108108 """Return the metadata for a forward pass."""
109109 metadata = FlashAttentionMetadata ()
110110
111+ # Stride by page_size to pick one slot per page — O(N/page_size) instead of O(N)
111112 indices = np .arange (0 , len (batch .cache_loc ), self .page_size )
112- selected_cache_locs = batch .cache_loc [indices ]
113- page_indices = (selected_cache_locs // self .page_size ).astype (np .int32 )
113+ page_indices = (batch .cache_loc [indices ] // self .page_size ).astype (np .int32 )
114+
115+ # SWA page indices: apply mapping on ~N_pages entries
116+ # instead of ~N_tokens entries (page_size x fewer random accesses)
117+ swa_page_indices = None
118+ swa_mapping = getattr (self , "swa_index_mapping" , None )
119+ if swa_mapping is not None :
120+ swa_slots = swa_mapping [batch .cache_loc [indices ]]
121+ swa_page_indices = (swa_slots // self .page_size ).astype (np .int32 )
114122
123+ # cu_q_lens
115124 if batch .forward_mode == ForwardMode .EXTEND :
116125 cu_q_lens = np .concatenate (
117126 [
@@ -120,82 +129,44 @@ def get_forward_metadata(
120129 ]
121130 )
122131 elif batch .forward_mode == ForwardMode .DECODE :
123- cu_q_lens = np .concatenate (
124- [
125- np .array ([0 ], dtype = np .int32 ),
126- np .cumsum (np .ones (len (batch .seq_lens ), dtype = np .int32 )),
127- ]
128- )
132+ cu_q_lens = np .arange (len (batch .seq_lens ) + 1 , dtype = np .int32 )
129133 else :
130134 raise ValueError (f"Invalid forward mode: { batch .forward_mode } " )
131135
132- seq_lens = np . copy ( batch .seq_lens )
136+ seq_lens = batch .seq_lens
133137
134138 aligned_seq_lens = (
135139 (batch .seq_lens + self .page_size - 1 ) // self .page_size
136140 ) * self .page_size
141+
142+ # cu_kv_lens
137143 cu_kv_lens = np .concatenate (
138144 [
139145 np .array ([0 ], dtype = np .int32 ),
140- np .cumsum (aligned_seq_lens ),
146+ np .cumsum (aligned_seq_lens , dtype = np . int32 ),
141147 ]
142148 )
143149
144- num_seqs = np .sum (batch .seq_lens > 0 , dtype = np .int32 ).reshape (
145- 1 ,
146- )
147-
148- # Construct distribution for V2 kernel: [decode_end, prefill_end, mixed_end]
150+ # distribution for V2 kernel: [decode_end, prefill_end, mixed_end]
151+ num_seqs = np .sum (batch .seq_lens > 0 , dtype = np .int32 )
149152 if batch .forward_mode == ForwardMode .DECODE :
150- # All sequences are decode/mixed mode
151- distribution = np .array ([0 , 0 , num_seqs .item ()], dtype = np .int32 )
153+ distribution = np .array ([0 , 0 , num_seqs ], dtype = np .int32 )
152154 elif batch .forward_mode == ForwardMode .EXTEND :
153- # All sequences are prefill mode
154- distribution = np .array ([0 , num_seqs .item (), num_seqs .item ()], dtype = np .int32 )
155+ distribution = np .array ([0 , num_seqs , num_seqs ], dtype = np .int32 )
155156 else :
156157 raise ValueError (f"Invalid forward mode: { batch .forward_mode } " )
157158
158- # Compute swa_page_indices if SWA index mapping is available
159- swa_page_indices = None
160- if self .swa_index_mapping is not None :
161- swa_cache_loc = self .swa_index_mapping [batch .cache_loc ]
162- swa_indices = np .arange (0 , len (swa_cache_loc ), self .page_size )
163- swa_selected = swa_cache_loc [swa_indices ]
164- swa_page_indices = (swa_selected // self .page_size ).astype (np .int32 )
165-
166- if swa_page_indices is not None :
167- (
168- metadata .num_seqs ,
169- metadata .cu_q_lens ,
170- metadata .cu_kv_lens ,
171- metadata .page_indices ,
172- metadata .seq_lens ,
173- metadata .distribution ,
174- metadata .swa_page_indices ,
175- ) = device_array (
176- (
177- num_seqs ,
178- cu_q_lens ,
179- cu_kv_lens ,
180- page_indices ,
181- seq_lens ,
182- distribution ,
183- swa_page_indices ,
184- ),
185- sharding = (NamedSharding (self .mesh , P ()) if jax .process_count () == 1 else None ),
186- )
187- else :
188- (
189- metadata .num_seqs ,
190- metadata .cu_q_lens ,
191- metadata .cu_kv_lens ,
192- metadata .page_indices ,
193- metadata .seq_lens ,
194- metadata .distribution ,
195- ) = device_array (
196- (num_seqs , cu_q_lens , cu_kv_lens , page_indices , seq_lens , distribution ),
197- sharding = (NamedSharding (self .mesh , P ()) if jax .process_count () == 1 else None ),
198- )
159+ (
160+ metadata .cu_q_lens ,
161+ metadata .cu_kv_lens ,
162+ metadata .page_indices ,
163+ metadata .swa_page_indices ,
164+ metadata .seq_lens ,
165+ metadata .distribution ,
166+ ) = device_array (
167+ (cu_q_lens , cu_kv_lens , page_indices , swa_page_indices , seq_lens , distribution ),
168+ sharding = (NamedSharding (self .mesh , P ()) if jax .process_count () == 1 else None ),
169+ )
199170 return metadata
200171
201172 def get_eagle_forward_metadata (self , batch : ModelWorkerBatch ):
0 commit comments