@@ -2491,90 +2491,105 @@ def run_test_config(
24912491 print (
24922492 f"\n Running tests with precision: { 'FLOAT16' if precision ['ort_type' ] == TensorProto .FLOAT16 else 'FLOAT32' } "
24932493 )
2494+ local_opts = [additional_params ["local" ]] if "local" in additional_params else [False , True ]
2495+ rotary_opts = (
2496+ [(additional_params ["rotary" ], additional_params ["rotary_interleaved" ])]
2497+ if "rotary" in additional_params
2498+ else [(False , False ), (True , False ), (True , True )]
2499+ )
2500+ packed_opts = [additional_params ["packed" ]] if "packed" in additional_params else [False , True ]
2501+ softcap_opts = [additional_params ["softcap" ]] if "softcap" in additional_params else [0.0 , 50.0 ]
2502+ smooth_opts = (
2503+ [additional_params ["use_smooth_softmax" ]]
2504+ if "use_smooth_softmax" in additional_params
2505+ else [False , True ]
2506+ )
2507+ head_sink_opts = [additional_params ["head_sink" ]] if "head_sink" in additional_params else [False , True ]
2508+
2509+ combo_index = 0
24942510 for b in batches :
24952511 for s , s2 in seqs :
24962512 for n , n2 in num_h :
24972513 for h in h_sizes :
2498- for local in [False , True ]:
2499- for rotary , rotary_interleaved in [(False , False ), (True , False ), (True , True )]:
2500- for packed in [False , True ]:
2501- for softcap in [0.0 , 50.0 ]:
2502- for use_smooth_softmax in [False , True ]:
2503- for has_pos , has_attn in pos_ids_attn_bias :
2504- for head_sink in [False , True ]:
2505- if use_smooth_softmax and head_sink :
2506- continue
2507- for output_qk in qk_output :
2508- if config_class == PromptConfig :
2509- config = config_class (
2510- b ,
2511- s ,
2512- s2 ,
2513- s + s2 + 8 ,
2514- n ,
2515- n2 ,
2516- h ,
2517- has_pos ,
2518- has_attn ,
2519- head_sink ,
2520- output_qk ,
2521- )
2522- else : # Config
2523- sp = random .randint (1 , s2 - s ) if s2 - s > 0 else 0
2524- config = config_class (
2525- b ,
2526- s ,
2527- s2 ,
2528- sp ,
2529- n ,
2530- n2 ,
2531- h ,
2532- has_pos ,
2533- has_attn ,
2534- head_sink ,
2535- output_qk ,
2536- )
2537-
2538- params = {
2539- "config" : config ,
2540- "torch_type" : precision ["torch_type" ],
2541- "numpy_type" : precision ["numpy_type" ],
2542- "ort_type" : precision ["ort_type" ],
2543- "rtol" : precision ["rtol" ],
2544- "atol" : precision ["atol" ],
2545- "local" : local ,
2546- "past_format" : Formats .BNSH ,
2547- "rotary" : rotary ,
2548- "rotary_interleaved" : rotary_interleaved ,
2549- "packed" : packed ,
2550- "softcap" : softcap ,
2551- "use_smooth_softmax" : use_smooth_softmax ,
2552- }
2553- params .update (additional_params )
2554-
2555- all_close = test_func (** params )
2556- self .assertTrue (all_close )
2514+ local = local_opts [combo_index % len (local_opts )]
2515+ rotary , rotary_interleaved = rotary_opts [combo_index % len (rotary_opts )]
2516+ packed = packed_opts [combo_index % len (packed_opts )]
2517+ softcap = softcap_opts [combo_index % len (softcap_opts )]
2518+ use_smooth_softmax = smooth_opts [combo_index % len (smooth_opts )]
2519+
2520+ has_pos , has_attn = pos_ids_attn_bias [combo_index % len (pos_ids_attn_bias )]
2521+ head_sink = head_sink_opts [combo_index % len (head_sink_opts )]
2522+ output_qk = qk_output [combo_index % len (qk_output )]
2523+
2524+ combo_index += 1
2525+
2526+ if rotary and h % 16 != 0 : # rotary requires head_size to be a multiple of 16
2527+ continue
2528+
2529+ if use_smooth_softmax and head_sink :
2530+ continue
2531+ if config_class == PromptConfig :
2532+ config = config_class (
2533+ b ,
2534+ s ,
2535+ s2 ,
2536+ s + s2 + 8 ,
2537+ n ,
2538+ n2 ,
2539+ h ,
2540+ has_pos ,
2541+ has_attn ,
2542+ head_sink ,
2543+ output_qk ,
2544+ )
2545+ else : # Config
2546+ sp = random .randint (1 , s2 - s ) if s2 - s > 0 else 0
2547+ config = config_class (
2548+ b ,
2549+ s ,
2550+ s2 ,
2551+ sp ,
2552+ n ,
2553+ n2 ,
2554+ h ,
2555+ has_pos ,
2556+ has_attn ,
2557+ head_sink ,
2558+ output_qk ,
2559+ )
2560+
2561+ params = {
2562+ "config" : config ,
2563+ "torch_type" : precision ["torch_type" ],
2564+ "numpy_type" : precision ["numpy_type" ],
2565+ "ort_type" : precision ["ort_type" ],
2566+ "rtol" : precision ["rtol" ],
2567+ "atol" : precision ["atol" ],
2568+ "local" : local ,
2569+ "past_format" : Formats .BNSH ,
2570+ "rotary" : rotary ,
2571+ "rotary_interleaved" : rotary_interleaved ,
2572+ "packed" : packed ,
2573+ "softcap" : softcap ,
2574+ "use_smooth_softmax" : use_smooth_softmax ,
2575+ }
2576+ params .update (additional_params )
2577+
2578+ all_close = test_func (** params )
2579+ self .assertTrue (all_close )
25572580
25582581 def test_gqa_no_past (self ):
25592582 print ("-------- TEST GQA NO PAST (PROMPT CASE) ---------" )
2560- batches = [3 ] if pipeline_mode else [1 , 3 , 5 ]
2583+ batches = [1 , 3 ] if pipeline_mode else [1 , 3 , 5 ]
25612584 seqs = (
25622585 [(127 , 127 ), (240 , 240 )]
25632586 if pipeline_mode
25642587 else [(127 , 127 ), (35 , 35 ), (2000 , 2000 ), (200 , 200 ), (240 , 240 ), (8000 , 8000 )]
25652588 )
2566- pos_ids_attn_bias = (
2567- [(False , False ), (True , True )]
2568- if pipeline_mode
2569- else [(False , False ), (True , True ), (False , True ), (True , False )]
2570- )
2589+ pos_ids_attn_bias = [(False , False ), (True , True ), (False , True ), (True , False )]
25712590 num_h = [(32 , 8 )] if pipeline_mode else [(6 , 6 ), (6 , 3 ), (9 , 9 ), (9 , 3 )]
2572- h_sizes = [128 ] if pipeline_mode else [32 , 40 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
2573- qk_output = (
2574- [QKOutputType .NO_OUTPUT ]
2575- if pipeline_mode
2576- else [QKOutputType .NO_OUTPUT , QKOutputType .BEFORE_SOFTMAX , QKOutputType .AFTER_SOFTMAX ]
2577- )
2591+ h_sizes = [40 , 128 ] if pipeline_mode else [32 , 48 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
2592+ qk_output = [QKOutputType .NO_OUTPUT , QKOutputType .BEFORE_SOFTMAX , QKOutputType .AFTER_SOFTMAX ]
25782593
25792594 # Test with buffer
25802595 self .run_test_config (
@@ -2601,24 +2616,16 @@ def test_gqa_no_past(self):
26012616
26022617 def test_gqa_past (self ):
26032618 print ("-------- TEST GQA PAST (TOKEN GEN) ---------" )
2604- batches = [1 ] if pipeline_mode else [1 , 3 , 5 ]
2619+ batches = [1 , 3 ] if pipeline_mode else [1 , 3 , 5 ]
26052620 seqs = (
26062621 [(1 , 128 )]
26072622 if pipeline_mode
26082623 else [(1 , 128 ), (1 , 339 ), (1 , 1024 ), (1 , 5000 ), (1 , 800 ), (1 , 256 ), (1 , 799 ), (1 , 2048 )]
26092624 )
2610- pos_ids_attn_bias = (
2611- [(False , False ), (True , True )]
2612- if pipeline_mode
2613- else [(False , False ), (True , True ), (False , True ), (True , False )]
2614- )
2625+ pos_ids_attn_bias = [(False , False ), (True , True ), (False , True ), (True , False )]
26152626 num_h = [(9 , 3 )] if pipeline_mode else [(6 , 6 ), (6 , 3 ), (9 , 9 ), (9 , 3 )]
2616- h_sizes = [64 ] if pipeline_mode else [32 , 40 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
2617- qk_output = (
2618- [QKOutputType .NO_OUTPUT ]
2619- if pipeline_mode
2620- else [QKOutputType .NO_OUTPUT , QKOutputType .BEFORE_SOFTMAX , QKOutputType .AFTER_SOFTMAX ]
2621- )
2627+ h_sizes = [64 , 256 ] if pipeline_mode else [32 , 40 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
2628+ qk_output = [QKOutputType .NO_OUTPUT , QKOutputType .BEFORE_SOFTMAX , QKOutputType .AFTER_SOFTMAX ]
26222629
26232630 # Test with buffer
26242631 self .run_test_config (parity_check_gqa_past , Config , batches , seqs , num_h , h_sizes , pos_ids_attn_bias , qk_output )
@@ -2638,18 +2645,14 @@ def test_gqa_interactive_one_batch(self):
26382645 print ("-------- TEST GQA INTERACTIVE ---------" )
26392646 batches = [1 ]
26402647 seqs = (
2641- [(256 , 2048 )]
2648+ [(256 , 2048 ), ( 1 , 128 ) ]
26422649 if pipeline_mode
26432650 else [(1 , 128 ), (1 , 339 ), (1 , 1024 ), (1 , 5000 ), (1 , 800 ), (1 , 256 ), (1 , 799 ), (1 , 2048 )]
26442651 )
2645- pos_ids_attn_bias = (
2646- [(False , False ), (True , True )]
2647- if pipeline_mode
2648- else [(False , False ), (True , True ), (False , True ), (True , False )]
2649- )
2652+ pos_ids_attn_bias = [(False , False ), (True , True ), (False , True ), (True , False )]
26502653 qk_output = [QKOutputType .NO_OUTPUT , QKOutputType .BEFORE_SOFTMAX , QKOutputType .AFTER_SOFTMAX ]
26512654 num_h = [(32 , 8 )] if pipeline_mode else [(6 , 6 ), (6 , 3 ), (9 , 9 ), (9 , 3 )]
2652- h_sizes = [32 ] if pipeline_mode else [32 , 40 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
2655+ h_sizes = [32 , 80 ] if pipeline_mode else [32 , 40 , 64 , 80 , 96 , 128 , 160 , 192 , 224 , 256 ]
26532656
26542657 # Only test softcap=0.0 for interactive case as per original
26552658 self .run_test_config (
0 commit comments