Description
Not a regression, but found when debugging garbage outputs through vLLM with TopK sampling using recent tt-metal 430d1f4 from Feb 7 - maybe LLK related, but not sure which label to tag for now. See details below; thanks to Claude Code for doing most of the debugging here :)
Describe the bug
ttnn.sort returns indices that are not a valid permutation when sorting tensors with more than 256 elements (Wt > 8 tiles) along the last dimension. The sorted values are correct and monotonically ordered, but the returned indices contain duplicates and missing values.
Specifically:
- Index 0 is always duplicated (appears 2x for Wt=16, 4x for Wt=32, 8x for Wt=64, etc.)
- Multiples of 256 are always missing (index 256, 512, 768, ...)
- The bug is data-independent — occurs for random values, all-zeros, pre-sorted ascending, and pre-sorted descending inputs
- The bug affects both ascending and descending sort directions
This breaks any code that relies on sort indices for scatter/gather operations. In our case, it causes vLLM's top-k/top-p sampler (apply_top_k_top_p) to produce corrupted logits, resulting in garbage text output during non-greedy sampling.
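To make the failure mode concrete, here is a sketch of the sort-mask-scatter pattern such samplers use (names illustrative, not vLLM's actual code): with a non-permutation index tensor, the double-written slot is silently overwritten and the never-written slot keeps stale memory.

```python
import torch

logits = torch.randn(8)
vals, idx = torch.sort(logits, descending=True)
vals[4:] = float("-inf")              # keep only the top-4 candidates

restored = torch.full_like(logits, float("nan"))
restored.scatter_(0, idx, vals)       # correct only if idx is a permutation
assert not restored.isnan().any()     # a missing index would leave a NaN slot
```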
Minimal repro
```python
import torch, ttnn

device = ttnn.open_device(device_id=0)
t = torch.randn([1, 512], dtype=torch.bfloat16)
inp = ttnn.from_torch(t, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
_, indices = ttnn.sort(inp, dim=-1)
idx = ttnn.to_torch(indices)[0].to(torch.int64).tolist()
print(f"Unique indices: {len(set(idx))}/512")  # 511/512
print(f"Index 0 count: {idx.count(0)}")        # 2 (expected 1)
print(f"Index 256 count: {idx.count(256)}")    # 0 (expected 1)
ttnn.close_device(device)
```

Output:

```
Unique indices: 511/512
Index 0 count: 2 (expected 1)
Index 256 count: 0 (expected 1)
```
Failure boundary
| Wt (tiles) | Elements | Status | Index 0 dup count | Missing indices |
|---|---|---|---|---|
| 8 | 256 | PASS | 1 (correct) | none |
| 9 | 288 | FAIL | 2 | [256] |
| 16 | 512 | FAIL | 2 | [256] |
| 32 | 1024 | FAIL | 4 | [256, 512, 768] |
| 64 | 2048 | FAIL | 8 | [256, 512, ..., 1792] |
The bug first appears when the bitonic sort requires a 4th merge stage (Wt=9 pads to 16 tiles, needing log2(16)=4 stages vs. log2(8)=3 stages for Wt=8).
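For intuition, the stage-count jump at the boundary can be computed directly (a back-of-envelope sketch, assuming the tile count is padded to the next power of two before merging):

```python
import math

# Merge stages needed for a given tile count, assuming power-of-two padding.
def merge_stages(wt_tiles):
    padded = 1 << math.ceil(math.log2(wt_tiles))
    return int(math.log2(padded))

print(merge_stages(8))   # 3 stages -> PASS
print(merge_stages(9))   # 4 stages -> first size needing a 4th stage, FAIL
print(merge_stages(64))  # 6 stages
```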
Steps to Reproduce

- Run any of the repro scripts from this gist: https://gist.github.com/kmabeeTT/d5d75c62219e21c44cd386b74eb02736
  - test_ttnn_sort_minimal.py: 15-line minimal repro shown above
  - test_ttnn_sort_repro.py: comprehensive sweep across sizes (32 to 151936)
  - test_ttnn_sort_boundary.py: pinpoints the exact boundary and tests data independence
- Or inline:

```
python3 test_ttnn_sort_minimal.py
```

Expected Behavior
ttnn.sort should return indices that form a valid permutation — each index from 0 to N-1 should appear exactly once. This is required for sort+scatter roundtrips to preserve data, and for torch.gather(input, dim, sort_indices) to reconstruct the sorted values.
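For reference, the permutation property can be checked mechanically (a sketch using the torch CPU sort as ground truth):

```python
import torch

t = torch.randn([1, 512], dtype=torch.bfloat16)
vals, idx = torch.sort(t, dim=-1)

# Gathering with the sort indices must reconstruct the sorted values...
assert torch.equal(torch.gather(t, -1, idx), vals)
# ...which in general only holds when the indices are a permutation.
assert sorted(idx[0].tolist()) == list(range(t.shape[-1]))
```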
Existing test gap
The current test_sort_indices test in tests/ttnn/unit_tests/operations/data_movement/test_sort.py only covers shapes where the sort dimension is ≤ 64 elements (Wt ≤ 2), well below the failure threshold of Wt > 8, so the bug is not caught by existing tests.
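A regression test crossing the boundary could look like this (a sketch only; the parametrization and device fixture follow common conventions in the tt-metal test suite but are not wired into test_sort.py):

```python
import pytest
import torch
import ttnn

# Widths straddling the Wt > 8 (256-element) boundary.
@pytest.mark.parametrize("width", [256, 288, 512, 1024, 2048])
def test_sort_indices_are_permutation(device, width):
    t = torch.randn([1, width], dtype=torch.bfloat16)
    inp = ttnn.from_torch(t, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
    _, indices = ttnn.sort(inp, dim=-1)
    idx = ttnn.to_torch(indices)[0].to(torch.int64).tolist()
    assert sorted(idx) == list(range(width)), "indices are not a valid permutation"
```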
Additional context
- Tested on Wormhole B0
- bf16 input, TILE layout
- Affects all three program factory paths (single-core for Wt ≤ 64, cross-core and multi-core DRAM for larger)
- The sorted values are always correct — only the indices are wrong
- Root cause appears to be in the LLK bitonic merge (_bitonic_topk_merge in ckernel_sfpu_topk.h), specifically in how index load/store addresses are calculated when the merge crosses the 8-tile (256-element) boundary (see the reference sketch below)
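For reference, a pure-Python model of a bitonic sorting network that carries indices alongside values (illustrative only; this is NOT the LLK implementation): as long as every compare-exchange moves the (value, index) pair through the same addresses, the index output stays a valid permutation at every size, so a non-permutation result points at the index addressing rather than the compare network itself.

```python
# Reference bitonic sort with index tracking. n must be a power of two.
def bitonic_sort_with_indices(vals):
    vals = list(vals)
    n = len(vals)
    idx = list(range(n))
    k = 2
    while k <= n:              # merge stage: log2(n) stages in total
        j = k // 2
        while j >= 1:          # compare-exchange distance within the stage
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    ascending = (i & k) == 0
                    if (vals[i] > vals[partner]) == ascending:
                        vals[i], vals[partner] = vals[partner], vals[i]
                        # Indices must use the SAME addresses as values; a
                        # mismatch here is exactly how duplicates and missing
                        # entries appear in the index output.
                        idx[i], idx[partner] = idx[partner], idx[i]
            j //= 2
        k *= 2
    return vals, idx

data = [3.0, 1.0, 4.0, 1.5, 5.0, 9.0, 2.0, 6.0]
vals, idx = bitonic_sort_with_indices(data)
assert vals == sorted(data)
assert sorted(idx) == list(range(len(data)))   # indices form a valid permutation
```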