Skip to content

Commit ff8edef

Browse files
Pavle Josipovicclaude
andcommitted
Fix SDPA decode watcher errors with half-tile (16x32) CBs
`generate_reduce_scaler` hardcoded 2048 bytes and 4 faces, assuming full 32x32 bf16 tiles. When circular buffers use half tiles (1024B, 2 faces), this overwrites adjacent L1 memory causing watcher-detected corruption. Restore the `half_tile` template parameter (previously removed in cleanup) so the zero-fill size and face iteration adapt to the actual tile dimensions. Also fix idle core runtime args count mismatch in sdpa_decode_program_factory. Fixes: #37631 Fixes: #29225 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 89175b6 commit ff8edef

File tree

4 files changed

+8
-9
lines changed

4 files changed

+8
-9
lines changed

models/demos/deepseek_v3_b1/tests/unit_tests/test_flash_mla.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from loguru import logger
1313

1414
import ttnn
15-
from models.common.utility_functions import comp_pcc, is_blackhole, is_watcher_enabled
15+
from models.common.utility_functions import comp_pcc
1616
from models.demos.deepseek_v3_b1.micro_ops.flash_mla.op import FlashMLADecode
1717

1818

@@ -22,8 +22,6 @@
2222
@pytest.mark.parametrize("max_seq_len", [32 * 1024]) # 32k max sequence length per chip
2323
def test_flash_mla_decode(device, batch_size, num_chunks, k_chunk_size, max_seq_len):
2424
"""Test FlashMLADecode op."""
25-
if is_blackhole() and is_watcher_enabled():
26-
pytest.skip("Skipping test on Blackhole with watcher enabled, see issue #37631")
2725

2826
# Calculate decode_position from num_chunks and k_chunk_size
2927
decode_position = num_chunks * k_chunk_size - 1

ttnn/cpp/ttnn/kernel/dataflow/generate_reduce_scaler.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
// Tile is assumed to have 16-bit elements
1010
// Scaler is assumed to be a 16-bit value double packed into a u32
11+
template <bool half_tile = false>
1112
FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t scaler) {
1213
cb_reserve_back(cb_id, 1);
1314

14-
constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
15+
constexpr uint32_t num_zeros_reads = (half_tile ? 1024 : 2048) / MEM_ZEROS_SIZE;
1516
static_assert(num_zeros_reads > 0, "num_zeros_reads must be greater than 0");
1617
uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
1718
uint32_t write_addr = get_write_ptr(cb_id);
@@ -27,7 +28,7 @@ FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t sc
2728
noc_async_read_barrier();
2829

2930
if (scaler != 0) {
30-
for (int k = 0; k < 4; ++k) {
31+
for (int k = 0; k < (half_tile ? 2 : 4); ++k) {
3132
uint32_t idx = k << 7;
3233
for (int j = 0; j < 8; ++j) {
3334
ptr[idx + j] = scaler;

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ void kernel_main() {
134134
constexpr uint32_t cb_out_l = tt::CBIndex::c_18;
135135

136136
// generate and send scaler to compute
137-
// These helper functions respect tile size of CBs (ie. no need for special handling of tiny tiles)
138-
generate_reduce_scaler(cb_identity_scale_in, identity_scalar_packed);
139-
generate_reduce_scaler(cb_zero_in, zero_scalar_packed);
137+
constexpr bool is_half_tile = (get_tile_size(cb_identity_scale_in) < 2 * tt::constants::TILE_HW);
138+
generate_reduce_scaler<is_half_tile>(cb_identity_scale_in, identity_scalar_packed);
139+
generate_reduce_scaler<is_half_tile>(cb_zero_in, zero_scalar_packed);
140140
generate_bcast_col_scalar(cb_col_identity, identity_scalar_packed);
141141

142142
if (k_chunk_start == window_start_chunk && window_start_unaligned > 0) {

ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ SdpaDecodeProgramFactory::cached_program_t SdpaDecodeProgramFactory::create(
10111011
for (auto core : core_group_idle) {
10121012
log_debug(tt::LogOp, "Setting core {} to idle", core);
10131013
// reader runtime args
1014-
std::vector<uint32_t> reader_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
1014+
std::vector<uint32_t> reader_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
10151015

10161016
// writer runtime args
10171017
std::vector<uint32_t> writer_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

0 commit comments

Comments
 (0)