Skip to content

Commit 2435ddc

Browse files
committed
More tests for FlashInfer SDPA
1 parent 945c0bb commit 2435ddc

4 files changed

Lines changed: 154 additions & 324 deletions

File tree

keys_values/attention/flashinfer_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def triton_score_sum(
160160
scale: float,
161161
n_kv_heads: int,
162162
group_size: int,
163+
causal_masking: bool = True,
163164
) -> torch.Tensor:
164165
"""Compute attention weight sums using Triton (no V needed).
165166
@@ -174,6 +175,8 @@ def triton_score_sum(
174175
scale: softmax scale factor (1/sqrt(head_size))
175176
n_kv_heads: number of KV heads
176177
group_size: GQA group size (n_head // n_kv_heads)
178+
causal_masking: whether to apply a causal attention mask. Defaults
179+
to `True`.
177180
178181
Returns:
179182
W: [batch, n_kv_heads, kv_len] (fp32) attention weight sums
@@ -258,7 +261,7 @@ def triton_score_sum(
258261
BLOCK_Q=BLOCK_Q,
259262
HEAD_DIM=head_size,
260263
GROUP_SIZE=group_size,
261-
HAS_CAUSAL=True,
264+
HAS_CAUSAL=causal_masking,
262265
num_warps=NUM_WARPS,
263266
num_stages=NUM_STAGES,
264267
)

keys_values/attention/sdpa_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,13 +330,13 @@ def _reorder(
330330

331331
def reorder_key_value(
332332
key: torch.Tensor,
333-
value: torch.Tensor,
333+
value: Optional[torch.Tensor],
334334
token_positions: torch.Tensor,
335335
input_pos: int,
336336
q_len: int,
337337
sort_if_3d: bool = True,
338338
check_token_pos: bool = False,
339-
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
339+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, torch.Tensor]]:
340340
"""
341341
Reorder `key, value` tensors using permutations (for each b, h) which, if
342342
applied to `token_positions`, place `input_pos:(input_pos + q_len)` at the
@@ -364,7 +364,7 @@ def reorder_key_value(
364364
extra_info = dict(index_gat=index_gat, index_scat=index_scat)
365365
return (
366366
reorder_buffer_given_extra_info(key, **extra_info),
367-
reorder_buffer_given_extra_info(value, **extra_info),
367+
None if value is None else reorder_buffer_given_extra_info(value, **extra_info),
368368
extra_info,
369369
)
370370

keys_values/scripts/collect_eval_results.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,20 @@ def main(
6363

6464
if __name__ == "__main__":
6565
base_path = Path.home() / "out/finetune/neurips_exp/lora/qwen3_4b"
66+
dataset_size = "64k"
6667
datasets = [
67-
"helmet_nq_32k",
68-
# "helmet_trivia_qa_32k",
69-
# "helmet_hotpot_qa_32k",
70-
# "helmet_pop_qa_32k",
68+
f"helmet_nq_{dataset_size}",
69+
f"helmet_trivia_qa_{dataset_size}",
70+
f"helmet_hotpot_qa_{dataset_size}",
71+
f"helmet_pop_qa_{dataset_size}",
7172
]
7273
cases = [
73-
"lr_4gpu_lpc2_avg1_lr5",
74+
"lr_4gpu_cs2048_lr5",
75+
"h2o_4gpu_cs2048_lr5",
76+
"slr_4gpu_cs2048_lr5",
77+
# "qh2o_4gpu_cs2048_lr5",
78+
# "h2onorm_4gpu_cs2048_lr5",
79+
# "qh2onorm_4gpu_cs2048_lr5",
7480
]
7581
model_type = "lora"
7682
for dataset, case in product(datasets, cases):

0 commit comments

Comments
 (0)