Skip to content

Commit 296599e

Browse files
committed
Refactor all tests
Signed-off-by: Pratikkumar Prajapati <pprajapati@tenstorrent.com>
1 parent 901538e commit 296599e

File tree

13 files changed

+115
-179
lines changed

13 files changed

+115
-179
lines changed

.github/workflows/galaxy-deepseek-tests-long-impl.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ jobs:
6565
timeout-minutes: 360
6666
run: |
6767
uv pip install -r models/demos/deepseek_v3/reference/deepseek/requirements.txt
68-
DEEPSEEK_MAX_SEQ_LEN=16384 pytest models/demos/deepseek_v3/tests/test_mla.py --timeout 3600 --durations=0
69-
DEEPSEEK_MAX_SEQ_LEN=16384 pytest models/demos/deepseek_v3/tests/test_decoder_block.py --timeout 7200 --durations=0
70-
DEEPSEEK_MAX_SEQ_LEN=2048 pytest models/demos/deepseek_v3/tests/test_model.py --timeout 10800 --durations=0
68+
DEEPSEEK_MAX_SEQ_LEN_OVERRIDE=16384 pytest models/demos/deepseek_v3/tests/test_mla.py --timeout 3600 --durations=0
69+
DEEPSEEK_MAX_SEQ_LEN_OVERRIDE=16384 pytest models/demos/deepseek_v3/tests/test_decoder_block.py --timeout 7200 --durations=0
70+
DEEPSEEK_MAX_SEQ_LEN_OVERRIDE=2048 pytest models/demos/deepseek_v3/tests/test_model.py --timeout 10800 --durations=0
7171
- uses: tenstorrent/tt-metal/.github/actions/slack-report@main
7272
if: ${{ failure() }}
7373
with:

models/demos/deepseek_v3/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,13 @@ def hf_config_short(request, hf_config):
182182
Build a shortened DeepSeek config for tests.
183183
184184
Environment variables:
185-
DEEPSEEK_MAX_SEQ_LEN: Optional override for `hf_config_short.max_seq_len`.
185+
DEEPSEEK_MAX_SEQ_LEN_OVERRIDE: Optional override for `hf_config_short.max_seq_len`.
186186
When set (e.g. "32768"), tests that read `hf_config_short.max_seq_len`
187187
can exercise longer sequence lengths without modifying code.
188188
"""
189189
hf_config_out = deepcopy(hf_config)
190190
hf_config_out.num_hidden_layers = getattr(request, "param", 1)
191-
max_seq_len_override = os.getenv("DEEPSEEK_MAX_SEQ_LEN")
191+
max_seq_len_override = os.getenv("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE")
192192
if max_seq_len_override is not None:
193193
hf_config_out.max_seq_len = int(max_seq_len_override)
194194
else:

models/demos/deepseek_v3/tests/pytest_utils.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import pytest
88

9+
DEFAULT_PREFILL_SEQ_LEN = 128
10+
911

1012
def expand_test_cases_with_position_ids_ranges(base_cases):
1113
"""
@@ -121,46 +123,45 @@ def build_expanded_test_ids(expanded_cases):
121123
return expanded_ids
122124

123125

124-
def get_base_test_cases(users_per_row, prefill_seq_lens, include_decode_random_pos_ids=True):
126+
def get_base_test_cases(users_per_row, prefill_seq_len, include_decode_random_pos_ids=True):
125127
"""
126128
Build base test cases for decode and prefill paths.
127129
128-
Used by:
130+
This helper is only exercised by these tests.:
129131
- models/demos/deepseek_v3/tests/test_mla.py
130132
- models/demos/deepseek_v3/tests/test_decoder_block.py
131133
- models/demos/deepseek_v3/tests/test_model.py
132134
133-
This helper is only exercised by these tests.
134-
135135
Args:
136136
users_per_row: Number of users per row (USERS_PER_ROW).
137-
prefill_seq_lens: Iterable of prefill sequence lengths.
137+
prefill_seq_len: Prefill sequence length to use when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set.
138138
include_decode_random_pos_ids: If True, include ("decode", 1, users_per_row, None).
139139
140-
environment variable DEEPSEEK_MAX_SEQ_LEN is primarily a CI override to expand prefill and decode coverage.
141-
When set, we add specific prefill and decode position_ids (0 and max_seq_len - 1) in
142-
addition to the default random prefill and decode position_ids cases.
140+
The environment variable DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is primarily a CI override to expand
141+
prefill and decode coverage.
143142
144143
Behavior:
145-
- Adds decode cases:
146-
- random position_ids (optional)
147-
- position_id 0
148-
- position_id max_seq_len - 1 when DEEPSEEK_MAX_SEQ_LEN is set
149-
- Adds prefill cases:
150-
- a direct prefill at DEEPSEEK_MAX_SEQ_LEN when set
151-
- the standard prefill seq lens with skip marks when DEEPSEEK_MAX_SEQ_LEN is not set
144+
- Decode cases:
145+
- when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, additionally includes:
146+
- position_id 0
147+
- position_id max_seq_len - 1
148+
- Prefill cases:
149+
- when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set, includes one prefill case using
150+
prefill_seq_len: ("prefill", prefill_seq_len, 1, None)
151+
- when DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, replaces the prefill list with a single case:
152+
("prefill", max_seq_len, 1, None)
152153
153154
"""
154155
base_cases = []
155156
if include_decode_random_pos_ids:
156157
base_cases += [("decode", 1, users_per_row, None)]
157158

158-
max_seq_len_env = os.getenv("DEEPSEEK_MAX_SEQ_LEN")
159+
max_seq_len_env = os.getenv("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE")
159160
if max_seq_len_env is None:
160-
# If DEEPSEEK_MAX_SEQ_LEN is not set, use the default prefill sequence length.
161-
base_cases += [("prefill", seq_len, 1, None) for seq_len in prefill_seq_lens]
161+
# If DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is not set, use the default prefill sequence length.
162+
base_cases += [("prefill", prefill_seq_len, 1, None)]
162163
else:
163-
# If DEEPSEEK_MAX_SEQ_LEN is set, use it to expand prefill and decode coverage.
164+
# If DEEPSEEK_MAX_SEQ_LEN_OVERRIDE is set, use it to expand prefill and decode coverage.
164165
max_seq_len = int(max_seq_len_env)
165166
base_cases += [
166167
("decode", 1, users_per_row, 0), # decode position_id 0
@@ -170,7 +171,7 @@ def get_base_test_cases(users_per_row, prefill_seq_lens, include_decode_random_p
170171
return base_cases
171172

172173

173-
def build_test_cases_and_ids(users_per_row, prefill_seq_lens, include_decode_random_pos_ids=True):
174+
def build_test_cases_and_ids(users_per_row, prefill_seq_len, include_decode_random_pos_ids=True):
174175
"""
175176
Build base test cases and return expanded cases with matching pytest IDs.
176177
@@ -179,7 +180,7 @@ def build_test_cases_and_ids(users_per_row, prefill_seq_lens, include_decode_ran
179180
- expand_test_cases_with_position_ids_ranges
180181
- build_expanded_test_ids
181182
"""
182-
base_cases = get_base_test_cases(users_per_row, prefill_seq_lens, include_decode_random_pos_ids)
183+
base_cases = get_base_test_cases(users_per_row, prefill_seq_len, include_decode_random_pos_ids)
183184
expanded_cases = expand_test_cases_with_position_ids_ranges(base_cases)
184185
expanded_ids = build_expanded_test_ids(expanded_cases)
185186
return expanded_cases, expanded_ids

models/demos/deepseek_v3/tests/test_decoder_block.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from transformers.configuration_utils import PretrainedConfig
1111

1212
import ttnn
13-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
1413
from models.demos.deepseek_v3.reference.modeling_deepseek import DeepseekV3DecoderLayer
15-
from models.demos.deepseek_v3.tests.pytest_utils import build_test_cases_and_ids
14+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN, build_test_cases_and_ids
1615
from models.demos.deepseek_v3.tt.decoder_block.decoder_block_2d import DecoderBlock2D
1716
from models.demos.deepseek_v3.tt.decoder_block.decoder_block_2d_base import DecoderBlock2DBase
1817
from models.demos.deepseek_v3.tt.decoder_block.moe_decoder_block_2d import MoEDecoderBlock2D
@@ -196,7 +195,7 @@ def run_test_forward_pass_decoder2d(
196195

197196
TEST_CASES, TEST_IDS = build_test_cases_and_ids(
198197
USERS_PER_ROW,
199-
PREFILL_SEQ_LENS[0:1], # list of default prefill sequence lengths to test
198+
DEFAULT_PREFILL_SEQ_LEN, # default prefill sequence length to test
200199
include_decode_random_pos_ids=True, # include decode random position_ids case
201200
)
202201

models/demos/deepseek_v3/tests/test_embedding.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44

5+
import os
6+
57
import pytest
68
import torch
79
from loguru import logger
@@ -10,7 +12,7 @@
1012
from torch.nn import Embedding as EmbeddingReference
1113

1214
import ttnn
13-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
15+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN
1416
from models.demos.deepseek_v3.tt.embedding.embedding1d import Embedding1D
1517
from models.demos.deepseek_v3.tt.embedding.embedding2d import Embedding2D
1618
from models.demos.deepseek_v3.utils.config_helpers import sub_state_dict
@@ -23,52 +25,27 @@
2325
run_module_forward,
2426
)
2527

28+
_max_seq_len_env = os.getenv("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE")
29+
_prefill_seq_len = int(_max_seq_len_env) if _max_seq_len_env is not None else DEFAULT_PREFILL_SEQ_LEN
30+
2631

2732
@pytest.mark.parametrize(
28-
"device_params",
33+
"EmbeddingClass,mode,batch_size_or_seq_len",
2934
[
30-
{"fabric_config": ttnn.FabricConfig.FABRIC_1D},
35+
pytest.param(Embedding1D, "decode", 32, marks=pytest.mark.requires_device(["TG"])),
36+
pytest.param(Embedding2D, "decode", 128, marks=pytest.mark.requires_device(["TG", "DUAL", "QUAD"])),
37+
pytest.param(Embedding1D, "prefill", _prefill_seq_len, marks=pytest.mark.requires_device(["TG"])),
38+
pytest.param(
39+
Embedding2D, "prefill", _prefill_seq_len, marks=pytest.mark.requires_device(["TG", "DUAL", "QUAD"])
40+
),
3141
],
32-
indirect=True,
3342
)
3443
@pytest.mark.parametrize(
35-
"EmbeddingClass,mode,batch_size_or_seq_len",
44+
"device_params",
3645
[
37-
pytest.param(Embedding1D, "decode", 32, marks=pytest.mark.requires_device(["TG"])),
38-
pytest.param(Embedding2D, "decode", 128, marks=pytest.mark.requires_device(["TG", "DUAL", "QUAD"])),
39-
]
40-
+ [
41-
pytest.param(Embedding1D, "prefill", seq_len, marks=pytest.mark.requires_device(["TG"]))
42-
if seq_len == 128
43-
else pytest.param(
44-
Embedding1D,
45-
"prefill",
46-
seq_len,
47-
marks=[
48-
pytest.mark.requires_device(["TG"]),
49-
pytest.mark.skip(
50-
f"Skipping prefilling with seq_len={seq_len} since this would cause us to exceed our available CI workload time"
51-
),
52-
],
53-
)
54-
for seq_len in PREFILL_SEQ_LENS
55-
]
56-
+ [
57-
pytest.param(Embedding2D, "prefill", seq_len, marks=pytest.mark.requires_device(["TG", "DUAL", "QUAD"]))
58-
if seq_len == 128
59-
else pytest.param(
60-
Embedding2D,
61-
"prefill",
62-
seq_len,
63-
marks=[
64-
pytest.mark.requires_device(["TG", "DUAL", "QUAD"]),
65-
pytest.mark.skip(
66-
f"Skipping prefilling with seq_len={seq_len} since this would cause us to exceed our available CI workload time"
67-
),
68-
],
69-
)
70-
for seq_len in PREFILL_SEQ_LENS
46+
{"fabric_config": ttnn.FabricConfig.FABRIC_1D},
7147
],
48+
indirect=True,
7249
)
7350
@pytest.mark.parametrize(
7451
"generate_reference_io",

models/demos/deepseek_v3/tests/test_lm_head.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44

5+
import os
56
from pathlib import Path
67
from typing import Any
78

@@ -11,7 +12,7 @@
1112
from loguru import logger
1213

1314
import ttnn
14-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
15+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN
1516
from models.demos.deepseek_v3.tt.ccl import CCL
1617
from models.demos.deepseek_v3.tt.lm_head import LMHead
1718
from models.demos.deepseek_v3.utils.config_helpers import sub_state_dict
@@ -39,30 +40,23 @@ def forward(self, hidden_states):
3940
return self.lm_head(hidden_states)
4041

4142

43+
_max_seq_len_env = os.getenv("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE")
44+
_prefill_seq_len = int(_max_seq_len_env) if _max_seq_len_env is not None else DEFAULT_PREFILL_SEQ_LEN
45+
46+
4247
@pytest.mark.parametrize(
43-
"device_params",
48+
"mode,seq_len",
4449
[
45-
{"fabric_config": ttnn.FabricConfig.FABRIC_1D},
50+
("decode", 32),
51+
("prefill", _prefill_seq_len),
4652
],
47-
indirect=True,
4853
)
4954
@pytest.mark.parametrize(
50-
"mode,seq_len",
55+
"device_params",
5156
[
52-
("decode", 32),
53-
]
54-
+ [
55-
("prefill", seq_len)
56-
if seq_len == 128
57-
else pytest.param(
58-
"prefill",
59-
seq_len,
60-
marks=pytest.mark.skip(
61-
f"Skipping prefilling with seq_len={seq_len} since this would cause us to exceed our available CI workload time"
62-
),
63-
)
64-
for seq_len in PREFILL_SEQ_LENS
57+
{"fabric_config": ttnn.FabricConfig.FABRIC_1D},
6558
],
59+
indirect=True,
6660
)
6761
@pytest.mark.requires_device(["TG"])
6862
def test_forward_pass(

models/demos/deepseek_v3/tests/test_mla.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111

1212
import ttnn
1313
from models.common.utility_functions import comp_pcc
14-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
1514
from models.demos.deepseek_v3.reference.modeling_deepseek import DeepseekV3Attention
16-
from models.demos.deepseek_v3.tests.pytest_utils import build_test_cases_and_ids
17-
from models.demos.deepseek_v3.tt.mla.mla1d import MLA1D
15+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN, build_test_cases_and_ids
1816
from models.demos.deepseek_v3.tt.mla.mla2d import MLA2D
1917
from models.demos.deepseek_v3.utils.config_helpers import USERS_PER_ROW, sub_state_dict
2018
from models.demos.deepseek_v3.utils.run_config import create_run_config
@@ -312,7 +310,7 @@ def run_test_forward_pass_mla2d(
312310

313311
TEST_CASES, TEST_IDS = build_test_cases_and_ids(
314312
USERS_PER_ROW,
315-
PREFILL_SEQ_LENS[0:1], # list of default prefill sequence lengths to test
313+
DEFAULT_PREFILL_SEQ_LEN, # default prefill sequence length to test
316314
include_decode_random_pos_ids=True, # include decode random position_ids case
317315
)
318316

models/demos/deepseek_v3/tests/test_mlp.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44

5+
import os
6+
57
import pytest
68
import torch
79
from loguru import logger
810

911
import ttnn
1012
from models.common.utility_functions import comp_pcc
11-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
1213
from models.demos.deepseek_v3.reference.modeling_deepseek import DeepseekV3MLP
14+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN
1315
from models.demos.deepseek_v3.tt.mlp.mlp import MLP
1416
from models.demos.deepseek_v3.tt.mlp.mlp_dequant import MLPDequant
1517
from models.demos.deepseek_v3.tt.mlp.non_expert import NonExpert
@@ -128,6 +130,17 @@ def run_weight_conversion_test(MLPClass, hf_config, state_dict, tmp_path, refere
128130
ttnn.deallocate(w1_ttnn)
129131

130132

133+
_max_seq_len_env = os.getenv("DEEPSEEK_MAX_SEQ_LEN_OVERRIDE")
134+
_prefill_seq_len = int(_max_seq_len_env) if _max_seq_len_env is not None else DEFAULT_PREFILL_SEQ_LEN
135+
136+
137+
@pytest.mark.parametrize(
138+
"mode,seq_len",
139+
[
140+
("decode", 32),
141+
("prefill", _prefill_seq_len),
142+
],
143+
)
131144
@pytest.mark.parametrize("device_params", [{"fabric_config": ttnn.FabricConfig.FABRIC_1D}], indirect=True)
132145
@pytest.mark.parametrize(
133146
"MLPClass,module_path",
@@ -137,24 +150,6 @@ def run_weight_conversion_test(MLPClass, hf_config, state_dict, tmp_path, refere
137150
(SharedExpert, "model.layers.3.mlp.shared_experts"),
138151
],
139152
)
140-
@pytest.mark.parametrize(
141-
"mode,seq_len",
142-
[
143-
("decode", 32),
144-
]
145-
+ [
146-
("prefill", seq_len)
147-
if seq_len == 128
148-
else pytest.param(
149-
"prefill",
150-
seq_len,
151-
marks=pytest.mark.skip(
152-
f"Skipping prefilling with seq_len={seq_len} since this would cause us to exceed our available CI workload time"
153-
),
154-
)
155-
for seq_len in PREFILL_SEQ_LENS
156-
],
157-
)
158153
def test_forward_pass(
159154
MLPClass,
160155
module_path,

models/demos/deepseek_v3/tests/test_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from transformers.configuration_utils import PretrainedConfig
99

1010
import ttnn
11-
from models.demos.deepseek_v3.conftest import PREFILL_SEQ_LENS
1211
from models.demos.deepseek_v3.reference.modeling_deepseek import DeepseekV3ForCausalLM
13-
from models.demos.deepseek_v3.tests.pytest_utils import build_test_cases_and_ids
12+
from models.demos.deepseek_v3.tests.pytest_utils import DEFAULT_PREFILL_SEQ_LEN, build_test_cases_and_ids
1413
from models.demos.deepseek_v3.tt.mla.mla2d import MLA2D
1514
from models.demos.deepseek_v3.tt.model.row_batched_model import RowBatchedModel
1615
from models.demos.deepseek_v3.utils.config_helpers import USERS_PER_ROW, sub_state_dict
@@ -200,7 +199,7 @@ def run_test_forward_pass_dpmodel(
200199

201200
TEST_CASES, TEST_IDS = build_test_cases_and_ids(
202201
USERS_PER_ROW,
203-
PREFILL_SEQ_LENS[0:1], # list of default prefill sequence lengths to test
202+
DEFAULT_PREFILL_SEQ_LEN, # default prefill sequence length to test
204203
include_decode_random_pos_ids=False, # TODO: Remove include_decode_random_pos_ids=False once non-zero position_ids case is working.
205204
)
206205

0 commit comments

Comments
 (0)