
Commit 927a540

#33349: TopK general update and bug fix (#36136)
### Ticket
#33349

### Problem description
TopK produced wrong results for some shapes from Qwen3-32b. Some functionality, such as the `indices_tensor` parameter, was available only in the multicore version.

### What's changed
- General code cleanup and refactor.
- Unified functionality across both program factories (some features were previously available only in one or the other).
- Allowed the SingleCore approach to utilize the full number of cores by parallelizing over Ht.
- Allowed the use of all tensor shapes and dims.
- Resolved the accuracy drop with Qwen3-32b shapes.
- Updated implementation docs.
- Removed a redundant kernel.

A minimal usage sketch of the op follows the checklists below.

### Checklist
- [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)
- [x] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes) - OK, the eltwise test failure is not connected to this PR
- [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)
- [x] New/existing tests provide coverage for the changes

<img width="867" height="537" alt="New implementation and Current implementation" src="https://github.com/user-attachments/assets/913fc106-613b-44bb-90a0-23aad0dfe11a" />

#### Discovered issues
#36329

#### Model tests
If your changes cover model-related code, you should run the tests corresponding to the affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends it with additional ones - use your best judgement in deciding which is most appropriate for your PR.
- [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)
  - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)
  - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:mgajewskiTT/33349-topk-slight-accuracy-drop-with-qwen3-32b-shapes)
  - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
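For reference, here is the minimal host-side sketch of the `ttnn.topk` call exercised by the updated tests, as mentioned above. The call signature (`k`, `dim`, `largest`, `sorted`, plus the optional `sub_core_grids` and `indices_tensor`) is taken from the test code in this commit; the shape, the `k` value, and the `ttnn.open_device`/`ttnn.close_device` wrapping are illustrative assumptions rather than part of the change.

```python
import torch
import ttnn

# Illustrative device setup; the tests receive `device` from a fixture instead.
device = ttnn.open_device(device_id=0)

shape = [1, 1, 32, 8192]  # one of the shapes parametrized in the tests below
input_torch = torch.randn(shape, dtype=torch.bfloat16)
input_ttnn = ttnn.from_torch(input_torch, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)

# Top-k along the last dim. `sub_core_grids` and `indices_tensor` are optional
# keyword arguments; per this change they are supported by both program factories.
values, indices = ttnn.topk(input_ttnn, 50, dim=3, largest=True, sorted=True)

values_torch = ttnn.to_torch(values)    # shape [1, 1, 32, 50]
indices_torch = ttnn.to_torch(indices)  # positions along dim 3

ttnn.close_device(device)
```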
1 parent 5eac985 · commit 927a540

28 files changed: +1831 -1071 lines

models/demos/deepseek_v3/tests/unit/test_topk.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,7 +13,7 @@
 TOPK_MEMORY_CONFIG = ttnn.L1_MEMORY_CONFIG

 # Sub-core grids for mesh device tests
-SUB_CORE_GRIDS = ttnn.CoreRangeSet([ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(8, 9))])
+SUB_CORE_GRIDS = ttnn.CoreRangeSet([ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(6, 7))])


 # k=32 matches the DeepSeek v3 MoE gating configuration, where the gate selects 32 experts per token.
```
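As a quick sanity check on the changed constant, the number of worker cores the new sub-core grid spans can be computed in plain Python, assuming (as in tt-metal) that `CoreRange` endpoints are inclusive:

```python
# The updated range runs from CoreCoord(0, 0) to CoreCoord(6, 7) inclusive,
# i.e. 7 columns x 8 rows of worker cores.
start_x, start_y = 0, 0
end_x, end_y = 6, 7
num_cores = (end_x - start_x + 1) * (end_y - start_y + 1)
print(num_cores)  # 56
```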

tests/ttnn/unit_tests/operations/reduce/test_topk.py

Lines changed: 48 additions & 41 deletions
```diff
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
+# SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC

 # SPDX-License-Identifier: Apache-2.0

@@ -8,49 +8,61 @@

 import torch
 import ttnn
-from tests.ttnn.utils_for_testing import assert_with_pcc
+from tests.ttnn.utils_for_testing import assert_allclose, assert_equal
+
+UINT16_MAX = 65535


 def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device, sub_core_grids=None, pass_indices_tensor=False):
     torch.manual_seed(2005)
+
+    # Input tensor
     shape = [N, C, H, W]
+    ttnn_indices_dtype = ttnn.uint16 if W <= UINT16_MAX else ttnn.uint32
+    torch_indices_dtype = torch.uint16 if W <= UINT16_MAX else torch.uint32
     torch_dtype = torch.bfloat16
     input = torch.randn(shape, dtype=torch_dtype) * 0.9
-    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)
     ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)

+    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)
+
     if pass_indices_tensor:
-        indices_tensor_torch = torch.zeros(shape, dtype=torch.int32)
+        indices_tensor_torch = torch.zeros(shape, dtype=torch_indices_dtype)
         for i in range(W):
             indices_tensor_torch[:, :, :, i] = i
-        indices_tensor = ttnn.from_torch(indices_tensor_torch, ttnn.uint16, layout=ttnn.Layout.TILE, device=device)
+        indices_tensor = ttnn.from_torch(
+            indices_tensor_torch, ttnn_indices_dtype, layout=ttnn.Layout.TILE, device=device
+        )
     else:
         indices_tensor = None

-    try:
-        ttnn_topk_values, ttnn_topk_indices = ttnn.topk(
-            ttnn_input,
-            k,
-            dim=dim,
-            largest=largest,
-            sorted=sorted,
-            sub_core_grids=sub_core_grids,
-            indices_tensor=indices_tensor,
-        )
-    except Exception as e:
-        raise e
+    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(
+        ttnn_input,
+        k,
+        dim=dim,
+        largest=largest,
+        sorted=sorted,
+        sub_core_grids=sub_core_grids,
+        indices_tensor=indices_tensor,
+    )
+
+    # Convert TTNN outputs to Torch for comparison
+    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
+    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices, dtype=torch_indices_dtype)

+    # Assert output shapes
     desired_shape = [N, C, H, W]
     desired_shape[dim] = k
     assert list(ttnn_topk_values.shape) == desired_shape
     assert list(ttnn_topk_indices.shape) == desired_shape

-    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
-    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)
+    # Assert values correctness
     if dtype == ttnn.bfloat8_b:
-        pcc_values = 0.99
+        assert_allclose(ttnn_torch_values, pyt_topk_values, rtol=1e-1, atol=1e-1)
     else:
-        pcc_values = 1.0
+        assert_equal(ttnn_torch_values, pyt_topk_values)
+
+    # Assert indices correctness using gather
     # pcc is not a good measure for the raw indices
     # if index 49 and index 8 are tied, the order of the indices can be different
     # but the values associated with the indices should be the same
```
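The `UINT16_MAX` guard introduced above decides how wide the index tensors need to be. A standalone, plain-Python illustration follows; the widths are taken from the shapes parametrized in this file, and labeling 151936 as the Qwen3-32b vocab width is an assumption based on the PR description.

```python
# A uint16 index can address at most 65536 positions along the reduced dim,
# so vocab-sized widths such as 151936 need 32-bit indices.
UINT16_MAX = 65535

for W in (8192, 64128, 151936):
    index_dtype = "uint16" if W <= UINT16_MAX else "uint32"
    print(f"W={W:>6} -> index dtype {index_dtype}")
# W=  8192 -> index dtype uint16
# W= 64128 -> index dtype uint16
# W=151936 -> index dtype uint32
```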
```diff
@@ -64,8 +76,6 @@ def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device, sub_core_g

     assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"

-    assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)
-

 @pytest.mark.parametrize(
     "dtype",
```
```diff
@@ -83,16 +93,19 @@ def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device, sub_core_g
 @pytest.mark.parametrize(
     "N, C, H, W, dim, k",
     (
-        (1, 1, 32, 8192, 3, 50),  # passed
-        (1, 1, 64, 64, 2, 32),  # passed
-        (1, 1, 64, 64, 2, 64),  # passed
-        (1, 2048, 1, 64, 1, 32),  # skipped
-        (1, 1, 32, 64, 3, 2),  # passed
-        (1, 1, 32, 64, 3, 4),  # passed
-        (1, 1, 32, 8192, 3, 6),  # passed
-        (1, 2048, 1, 64, 1, 8),  # passed
-        (1, 1, 32, 32768, 3, 3000),  # passed
-        (1, 1, 32, 18992, 3, 3000),  # passed
+        (1, 1, 32, 8192, 3, 50),
+        (1, 1, 64, 64, 2, 32),
+        (1, 1, 64, 64, 2, 64),
+        (1, 2048, 1, 64, 1, 32),
+        (1, 1, 32, 64, 3, 2),
+        (1, 1, 32, 64, 3, 4),
+        (1, 1, 32, 8192, 3, 6),
+        (1, 2048, 1, 64, 1, 8),
+        (1, 1, 32, 32768, 3, 3000),
+        (1, 1, 32, 18992, 3, 3000),
+        (1, 1, 32, 18992, 3, 32),
+        (1, 1, 32, 10000, 3, 32),
+        (1, 1, 32, 64128, 3, 32),
     ),
 )
 @pytest.mark.parametrize(
```
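Several of the parametrized widths are not multiples of the 32-wide tile, which is presumably part of what "Allowing the use of all tensor shapes and dim" covers. A quick plain-Python check of the widths above; the padding targets are an assumption about tile alignment, not something stated in this diff.

```python
TILE_WIDTH = 32  # tt-metal tiles are 32x32

for W in (8192, 10000, 18992, 64128, 151936):
    if W % TILE_WIDTH == 0:
        print(f"W={W:>6}: tile-aligned")
    else:
        padded = ((W + TILE_WIDTH - 1) // TILE_WIDTH) * TILE_WIDTH
        print(f"W={W:>6}: not tile-aligned, would pad to {padded}")
# W=  8192: tile-aligned
# W= 10000: not tile-aligned, would pad to 10016
# W= 18992: not tile-aligned, would pad to 19008
# W= 64128: tile-aligned
# W=151936: tile-aligned
```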
```diff
@@ -116,12 +129,6 @@ def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device, sub_core_g
     ],
 )
 def test_topk(N, C, H, W, dim, k, dtype, sorted, largest, device, sub_core_grids):
-    if dim == 0 or dim == 1:
-        # As of now, when we try to get top-k for dim = 0 or 1, we get following error from transpose_op.cpp's validate():
-        # input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32
-        # this is because, transpose.cpp always typecasts bf8 to bf16
-        # and when dim = 0 or 1, transpose converts it into TransposeOpDim::HC & this dim doesnt support bf16 or fp32
-        pytest.skip()
     run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device, sub_core_grids)

```
```diff
@@ -186,8 +193,8 @@ def test_topk_sub_core_grids(N, C, H, W, dim, k, dtype, sorted, largest, device,
 @pytest.mark.parametrize(
     "N, C, H, W, dim, k",
     (
-        (1, 1, 32, 151936, 3, 50),  # passed - customer shape 2
-        (1, 1, 32, 128256, 3, 50),  # passed - customer shape 1
+        (1, 1, 32, 151936, 3, 50),
+        (1, 1, 32, 128256, 3, 50),
     ),
 )
 @pytest.mark.parametrize(
```
