ttnn.sort returns duplicate indices (not a valid permutation) for tensors wider than 256 elements (VLLM TopK Sampling) #37571

@kmabeeTT

Description

Not a regression, but found while debugging garbage outputs through vLLM with TopK sampling using recent tt-metal 430d1f4 from Feb 7. Possibly LLK related, but not sure which label to tag for now. See details below; thanks to Claude Code for doing most of the debug here :)

Describe the bug

ttnn.sort returns indices that are not a valid permutation when sorting tensors with more than 256 elements (Wt > 8 tiles) along the last dimension. The sorted values are correct and monotonically ordered, but the returned indices contain duplicates and missing values.

Specifically:

  • Index 0 is always duplicated (appears 2x for Wt=16, 4x for Wt=32, 8x for Wt=64, etc.)
  • Multiples of 256 are always missing (index 256, 512, 768, ...)
  • The bug is data-independent — occurs for random values, all-zeros, pre-sorted ascending, and pre-sorted descending inputs
  • The bug affects both ascending and descending sort directions

This breaks any code that relies on sort indices for scatter/gather operations. In our case, it causes vLLM's top-k/top-p sampler (apply_top_k_top_p) to produce corrupted logits, resulting in garbage text output during non-greedy sampling.
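
For illustration, here is a stripped-down sketch of the sort-then-scatter pattern that top-k masking relies on (not vLLM's actual apply_top_k_top_p implementation), showing how a duplicated/missing index corrupts the result; the tensors and the bad_idx example are made up for the demo:

import torch

logits = torch.tensor([0.1, 0.9, 0.4, 0.7])
sorted_logits, sort_idx = torch.sort(logits, descending=True)

# Keep only the top-2 entries, then scatter the mask back to the original positions.
masked = sorted_logits.clone()
masked[2:] = float("-inf")
restored = torch.empty_like(logits).scatter_(0, sort_idx, masked)
print(restored)  # correct: only the two largest logits survive

# With a duplicated/missing index (as ttnn.sort produces past 256 elements),
# one position is written twice and another is never written at all.
bad_idx = torch.tensor([1, 3, 2, 3])  # index 0 missing, index 3 duplicated
corrupted = torch.full_like(logits, float("nan")).scatter_(0, bad_idx, masked)
print(corrupted)  # position 0 never written (NaN); position 3 overwritten with -inf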

Minimal repro

import torch, ttnn

device = ttnn.open_device(device_id=0)

t = torch.randn([1, 512], dtype=torch.bfloat16)
inp = ttnn.from_torch(t, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)
_, indices = ttnn.sort(inp, dim=-1)
idx = ttnn.to_torch(indices)[0].to(torch.int64).tolist()

print(f"Unique indices: {len(set(idx))}/512")   # 511/512
print(f"Index 0 count:   {idx.count(0)}")        # 2  (expected 1)
print(f"Index 256 count: {idx.count(256)}")       # 0  (expected 1)

ttnn.close_device(device)

Output:

Unique indices: 511/512
Index 0 count:   2  (expected 1)
Index 256 count: 0  (expected 1)

Failure boundary

Wt (tiles)  Elements  Status  Index 0 dup count  Missing indices
8           256       PASS    1 (correct)        none
9           288       FAIL    2                  [256]
16          512       FAIL    2                  [256]
32          1024      FAIL    4                  [256, 512, 768]
64          2048      FAIL    8                  [256, 512, ..., 1792]

The bug first appears when the bitonic sort requires a 4th merge stage (Wt=9 pads to 16, needing log2(16)=4 stages vs log2(8)=3 stages for Wt=8).
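
A quick host-side way to see where that extra stage kicks in, assuming (as described above) that Wt is padded up to the next power of two before the bitonic merge:

import math

def merge_stages(n_elements, tile_width=32):
    wt = math.ceil(n_elements / tile_width)   # tiles along the sort dimension
    wt_padded = 1 << (wt - 1).bit_length()    # pad Wt to the next power of two
    return wt, wt_padded, int(math.log2(wt_padded))

for n in (256, 288, 512, 1024, 2048):
    print(n, merge_stages(n))
# 256  -> (8, 8, 3)    last passing width (3 merge stages)
# 288  -> (9, 16, 4)   first failing width (4th merge stage required)
# 512  -> (16, 16, 4)
# 1024 -> (32, 32, 5)
# 2048 -> (64, 64, 6)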

Steps to Reproduce

  1. Run any of the repro scripts from this gist: https://gist.github.com/kmabeeTT/d5d75c62219e21c44cd386b74eb02736

    • test_ttnn_sort_minimal.py — 15-line minimal repro shown above
    • test_ttnn_sort_repro.py — comprehensive sweep across sizes (32 to 151936)
    • test_ttnn_sort_boundary.py — pinpoints exact boundary, tests data independence
  2. Or run the minimal script directly:

python3 test_ttnn_sort_minimal.py

Expected Behavior

ttnn.sort should return indices that form a valid permutation — each index from 0 to N-1 should appear exactly once. This is required for sort+scatter roundtrips to preserve data, and for torch.gather(input, dim, sort_indices) to reconstruct the sorted values.
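
A minimal host-side acceptance check for this, as a sketch that assumes the values and indices have already been copied back to torch as in the repro above:

import torch

def check_sort_indices(original_row, sorted_vals, indices):
    indices = indices.to(torch.int64)
    n = original_row.numel()
    # Valid permutation: every index 0..N-1 appears exactly once.
    assert torch.equal(torch.sort(indices).values, torch.arange(n)), \
        "indices are not a valid permutation"
    # Gather roundtrip: the indices must reconstruct the sorted values.
    assert torch.equal(torch.gather(original_row, 0, indices), sorted_vals), \
        "gather(input, dim, indices) does not reproduce the sorted values"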

Existing test gap

The current test_sort_indices test in tests/ttnn/unit_tests/operations/data_movement/test_sort.py only covers shapes where the sort dimension is ≤ 64 elements (Wt ≤ 2), well below the failure threshold of Wt > 8, so the bug is not caught by any existing test.
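
A sketch of the kind of parametrized case that would close the gap; the device fixture name follows the usual ttnn unit-test convention and is an assumption here, not a claim about the existing file's helpers:

import pytest
import torch
import ttnn

@pytest.mark.parametrize("width", [256, 288, 512, 1024, 2048])
def test_sort_indices_wide(device, width):
    torch_input = torch.randn([1, width], dtype=torch.bfloat16)
    tt_input = ttnn.from_torch(torch_input, ttnn.bfloat16,
                               layout=ttnn.Layout.TILE, device=device)
    _, tt_indices = ttnn.sort(tt_input, dim=-1)
    indices = ttnn.to_torch(tt_indices)[0].to(torch.int64)
    # The indices must form a valid permutation of 0..width-1.
    assert torch.equal(torch.sort(indices).values, torch.arange(width))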

Additional context

  • Tested on Wormhole B0
  • bf16 input, TILE layout
  • Affects all three program factory paths (single-core for Wt ≤ 64, cross-core and multi-core DRAM for larger)
  • The sorted values are always correct — only the indices are wrong
  • Root cause appears to be in the LLK bitonic merge (_bitonic_topk_merge in ckernel_sfpu_topk.h), specifically in how index load/store addresses are calculated when the merge crosses the 8-tile (256-element) boundary
