66
77import pytest
88
9+ DEFAULT_PREFILL_SEQ_LEN = 128
10+
911
1012def expand_test_cases_with_position_ids_ranges (base_cases ):
1113 """
@@ -121,46 +123,45 @@ def build_expanded_test_ids(expanded_cases):
121123 return expanded_ids
122124
123125
124- def get_base_test_cases (users_per_row , prefill_seq_lens , include_decode_random_pos_ids = True ):
126+ def get_base_test_cases (users_per_row , prefill_seq_len , include_decode_random_pos_ids = True ):
125127 """
126128 Build base test cases for decode and prefill paths.
127129
128- Used by :
130+ This helper is only exercised by these tests. :
129131 - models/demos/deepseek_v3/tests/test_mla.py
130132 - models/demos/deepseek_v3/tests/test_decoder_block.py
131133 - models/demos/deepseek_v3/tests/test_model.py
132134
133- This helper is only exercised by these tests.
134-
135135 Args:
136136 users_per_row: Number of users per row (USERS_PER_ROW).
137- prefill_seq_lens: Iterable of prefill sequence lengths .
137+ prefill_seq_len: Prefill sequence length to use when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set .
138138 include_decode_random_pos_ids: If True, include ("decode", 1, users_per_row, None).
139139
140- environment variable DEEPSEEK_MAX_SEQ_LEN is primarily a CI override to expand prefill and decode coverage.
141- When set, we add specific prefill and decode position_ids (0 and max_seq_len - 1) in
142- addition to the default random prefill and decode position_ids cases.
140+ The environment variable DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is primarily a CI override to expand
141+ prefill and decode coverage.
143142
144143 Behavior:
145- - Adds decode cases:
146- - random position_ids (optional)
147- - position_id 0
148- - position_id max_seq_len - 1 when DEEPSEEK_MAX_SEQ_LEN is set
149- - Adds prefill cases:
150- - a direct prefill at DEEPSEEK_MAX_SEQ_LEN when set
151- - the standard prefill seq lens with skip marks when DEEPSEEK_MAX_SEQ_LEN is not set
144+ - Decode cases:
145+ - when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, additionally includes:
146+ - position_id 0
147+ - position_id max_seq_len - 1
148+ - Prefill cases:
149+ - when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set, includes one prefill case using
150+ prefill_seq_len: ("prefill", prefill_seq_len, 1, None)
151+ - when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, replaces the prefill list with a single case:
152+ ("prefill", max_seq_len, 1, None)
152153
153154 """
154155 base_cases = []
155156 if include_decode_random_pos_ids :
156157 base_cases += [("decode" , 1 , users_per_row , None )]
157158
158- max_seq_len_env = os .getenv ("DEEPSEEK_MAX_SEQ_LEN " )
159+ max_seq_len_env = os .getenv ("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE " )
159160 if max_seq_len_env is None :
160- # If DEEPSEEK_MAX_SEQ_LEN is not set, use the default prefill sequence length.
161- base_cases += [("prefill" , seq_len , 1 , None ) for seq_len in prefill_seq_lens ]
161+ # If DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set, use the default prefill sequence length.
162+ base_cases += [("prefill" , prefill_seq_len , 1 , None )]
162163 else :
163- # If DEEPSEEK_MAX_SEQ_LEN is set, use it to expand prefill and decode coverage.
164+ # If DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, use it to expand prefill and decode coverage.
164165 max_seq_len = int (max_seq_len_env )
165166 base_cases += [
166167 ("decode" , 1 , users_per_row , 0 ), # decode position_id 0
@@ -170,7 +171,7 @@ def get_base_test_cases(users_per_row, prefill_seq_lens, include_decode_random_p
170171 return base_cases
171172
172173
173- def build_test_cases_and_ids (users_per_row , prefill_seq_lens , include_decode_random_pos_ids = True ):
174+ def build_test_cases_and_ids (users_per_row , prefill_seq_len , include_decode_random_pos_ids = True ):
174175 """
175176 Build base test cases and return expanded cases with matching pytest IDs.
176177
@@ -179,7 +180,7 @@ def build_test_cases_and_ids(users_per_row, prefill_seq_lens, include_decode_ran
179180 - expand_test_cases_with_position_ids_ranges
180181 - build_expanded_test_ids
181182 """
182- base_cases = get_base_test_cases (users_per_row , prefill_seq_lens , include_decode_random_pos_ids )
183+ base_cases = get_base_test_cases (users_per_row , prefill_seq_len , include_decode_random_pos_ids )
183184 expanded_cases = expand_test_cases_with_position_ids_ranges (base_cases )
184185 expanded_ids = build_expanded_test_ids (expanded_cases )
185186 return expanded_cases , expanded_ids
0 commit comments