
Commit ecf164a

Reduce GQA cpu test combinations (#26897)
The current testing strategy for GQA on CPU runs a Cartesian product of all configuration parameters (batch size, sequence length, rotary embeddings, packed KV, softcap, etc.), producing over 2000 test combinations. This causes significant runtime overhead and potential timeouts.

This PR optimizes `test_gqa_cpu.py` by:

- Replacing the nested loops over all parameters with a round-robin selection strategy (`combo_index`).
- Significantly reducing the number of test cases (from ~2304 to ~32 in pipeline mode) while maintaining coverage of each individual feature (rotary, packed, softcap, etc.).

This keeps the test suite robust while making it much faster: test time drops from minutes to seconds, saving a significant amount of compute in the CI pipeline.
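To see the round-robin idea in isolation, here is a minimal standalone Python sketch with hypothetical, much smaller option lists than the real test uses (this is not code from `test_gqa_cpu.py` itself):

```python
from itertools import product

# Hypothetical option lists for illustration; the real test sweeps many more.
batches = [1, 3]
seqs = [(127, 127), (240, 240)]
local_opts = [False, True]
softcap_opts = [0.0, 50.0]

# Old approach: full Cartesian product -> 2 * 2 * 2 * 2 = 16 cases.
cartesian = list(product(batches, seqs, local_opts, softcap_opts))

# New approach: one shared counter cycles each secondary option list,
# so every primary (batch, seq) combination runs exactly once -> 4 cases.
combo_index = 0
round_robin = []
for b in batches:
    for s, s2 in seqs:
        local = local_opts[combo_index % len(local_opts)]
        softcap = softcap_opts[combo_index % len(softcap_opts)]
        combo_index += 1
        round_robin.append((b, (s, s2), local, softcap))

print(len(cartesian), len(round_robin))  # 16 4
```

One trade-off worth noting: option lists of equal length cycle in lockstep (above, `local` and `softcap` always change together), whereas lists of coprime lengths, such as the three-entry rotary list against the two-entry lists in this change, drift apart and exercise more pairings over the loop.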
1 parent 0d59f8d commit ecf164a

File tree

1 file changed: +93 −90 lines


onnxruntime/test/python/transformers/test_gqa_cpu.py

Lines changed: 93 additions & 90 deletions
@@ -2491,90 +2491,105 @@ def run_test_config(
             print(
                 f"\nRunning tests with precision: {'FLOAT16' if precision['ort_type'] == TensorProto.FLOAT16 else 'FLOAT32'}"
             )
+            local_opts = [additional_params["local"]] if "local" in additional_params else [False, True]
+            rotary_opts = (
+                [(additional_params["rotary"], additional_params["rotary_interleaved"])]
+                if "rotary" in additional_params
+                else [(False, False), (True, False), (True, True)]
+            )
+            packed_opts = [additional_params["packed"]] if "packed" in additional_params else [False, True]
+            softcap_opts = [additional_params["softcap"]] if "softcap" in additional_params else [0.0, 50.0]
+            smooth_opts = (
+                [additional_params["use_smooth_softmax"]]
+                if "use_smooth_softmax" in additional_params
+                else [False, True]
+            )
+            head_sink_opts = [additional_params["head_sink"]] if "head_sink" in additional_params else [False, True]
+
+            combo_index = 0
             for b in batches:
                 for s, s2 in seqs:
                     for n, n2 in num_h:
                         for h in h_sizes:
-                            for local in [False, True]:
-                                for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]:
-                                    for packed in [False, True]:
-                                        for softcap in [0.0, 50.0]:
-                                            for use_smooth_softmax in [False, True]:
-                                                for has_pos, has_attn in pos_ids_attn_bias:
-                                                    for head_sink in [False, True]:
-                                                        if use_smooth_softmax and head_sink:
-                                                            continue
-                                                        for output_qk in qk_output:
-                                                            if config_class == PromptConfig:
-                                                                config = config_class(
-                                                                    b,
-                                                                    s,
-                                                                    s2,
-                                                                    s + s2 + 8,
-                                                                    n,
-                                                                    n2,
-                                                                    h,
-                                                                    has_pos,
-                                                                    has_attn,
-                                                                    head_sink,
-                                                                    output_qk,
-                                                                )
-                                                            else:  # Config
-                                                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                                                                config = config_class(
-                                                                    b,
-                                                                    s,
-                                                                    s2,
-                                                                    sp,
-                                                                    n,
-                                                                    n2,
-                                                                    h,
-                                                                    has_pos,
-                                                                    has_attn,
-                                                                    head_sink,
-                                                                    output_qk,
-                                                                )
-
-                                                            params = {
-                                                                "config": config,
-                                                                "torch_type": precision["torch_type"],
-                                                                "numpy_type": precision["numpy_type"],
-                                                                "ort_type": precision["ort_type"],
-                                                                "rtol": precision["rtol"],
-                                                                "atol": precision["atol"],
-                                                                "local": local,
-                                                                "past_format": Formats.BNSH,
-                                                                "rotary": rotary,
-                                                                "rotary_interleaved": rotary_interleaved,
-                                                                "packed": packed,
-                                                                "softcap": softcap,
-                                                                "use_smooth_softmax": use_smooth_softmax,
-                                                            }
-                                                            params.update(additional_params)
-
-                                                            all_close = test_func(**params)
-                                                            self.assertTrue(all_close)
+                            local = local_opts[combo_index % len(local_opts)]
+                            rotary, rotary_interleaved = rotary_opts[combo_index % len(rotary_opts)]
+                            packed = packed_opts[combo_index % len(packed_opts)]
+                            softcap = softcap_opts[combo_index % len(softcap_opts)]
+                            use_smooth_softmax = smooth_opts[combo_index % len(smooth_opts)]
+
+                            has_pos, has_attn = pos_ids_attn_bias[combo_index % len(pos_ids_attn_bias)]
+                            head_sink = head_sink_opts[combo_index % len(head_sink_opts)]
+                            output_qk = qk_output[combo_index % len(qk_output)]
+
+                            combo_index += 1
+
+                            if rotary and h % 16 != 0:  # rotary requires head_size to be a multiple of 16
+                                continue
+
+                            if use_smooth_softmax and head_sink:
+                                continue
+                            if config_class == PromptConfig:
+                                config = config_class(
+                                    b,
+                                    s,
+                                    s2,
+                                    s + s2 + 8,
+                                    n,
+                                    n2,
+                                    h,
+                                    has_pos,
+                                    has_attn,
+                                    head_sink,
+                                    output_qk,
+                                )
+                            else:  # Config
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = config_class(
+                                    b,
+                                    s,
+                                    s2,
+                                    sp,
+                                    n,
+                                    n2,
+                                    h,
+                                    has_pos,
+                                    has_attn,
+                                    head_sink,
+                                    output_qk,
+                                )
+
+                            params = {
+                                "config": config,
+                                "torch_type": precision["torch_type"],
+                                "numpy_type": precision["numpy_type"],
+                                "ort_type": precision["ort_type"],
+                                "rtol": precision["rtol"],
+                                "atol": precision["atol"],
+                                "local": local,
+                                "past_format": Formats.BNSH,
+                                "rotary": rotary,
+                                "rotary_interleaved": rotary_interleaved,
+                                "packed": packed,
+                                "softcap": softcap,
+                                "use_smooth_softmax": use_smooth_softmax,
+                            }
+                            params.update(additional_params)
+
+                            all_close = test_func(**params)
+                            self.assertTrue(all_close)
 
     def test_gqa_no_past(self):
         print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
-        batches = [3] if pipeline_mode else [1, 3, 5]
+        batches = [1, 3] if pipeline_mode else [1, 3, 5]
         seqs = (
             [(127, 127), (240, 240)]
             if pipeline_mode
             else [(127, 127), (35, 35), (2000, 2000), (200, 200), (240, 240), (8000, 8000)]
         )
-        pos_ids_attn_bias = (
-            [(False, False), (True, True)]
-            if pipeline_mode
-            else [(False, False), (True, True), (False, True), (True, False)]
-        )
+        pos_ids_attn_bias = [(False, False), (True, True), (False, True), (True, False)]
         num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
-        qk_output = (
-            [QKOutputType.NO_OUTPUT]
-            if pipeline_mode
-            else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
-        )
+        h_sizes = [40, 128] if pipeline_mode else [32, 48, 64, 80, 96, 128, 160, 192, 224, 256]
+        qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
 
         # Test with buffer
         self.run_test_config(
@@ -2601,24 +2616,16 @@ def test_gqa_no_past(self):
 
     def test_gqa_past(self):
         print("-------- TEST GQA PAST (TOKEN GEN) ---------")
-        batches = [1] if pipeline_mode else [1, 3, 5]
+        batches = [1, 3] if pipeline_mode else [1, 3, 5]
         seqs = (
             [(1, 128)]
             if pipeline_mode
             else [(1, 128), (1, 339), (1, 1024), (1, 5000), (1, 800), (1, 256), (1, 799), (1, 2048)]
         )
-        pos_ids_attn_bias = (
-            [(False, False), (True, True)]
-            if pipeline_mode
-            else [(False, False), (True, True), (False, True), (True, False)]
-        )
+        pos_ids_attn_bias = [(False, False), (True, True), (False, True), (True, False)]
         num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
-        qk_output = (
-            [QKOutputType.NO_OUTPUT]
-            if pipeline_mode
-            else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
-        )
+        h_sizes = [64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+        qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
 
         # Test with buffer
         self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output)
@@ -2638,18 +2645,14 @@ def test_gqa_interactive_one_batch(self):
         print("-------- TEST GQA INTERACTIVE ---------")
         batches = [1]
         seqs = (
-            [(256, 2048)]
+            [(256, 2048), (1, 128)]
             if pipeline_mode
             else [(1, 128), (1, 339), (1, 1024), (1, 5000), (1, 800), (1, 256), (1, 799), (1, 2048)]
         )
-        pos_ids_attn_bias = (
-            [(False, False), (True, True)]
-            if pipeline_mode
-            else [(False, False), (True, True), (False, True), (True, False)]
-        )
+        pos_ids_attn_bias = [(False, False), (True, True), (False, True), (True, False)]
         qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
         num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-        h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+        h_sizes = [32, 80] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
 
         # Only test softcap=0.0 for interactive case as per original
         self.run_test_config(
