Conversation
Pull request overview
This PR fixes TT_METAL_WATCHER-detected SDPA decode corruption by making generate_reduce_scaler correctly zero-fill and initialize scalers for both full (32x32) and half (16x32) tiles, and by fixing an idle-core reader runtime-args length mismatch that can trip watcher assertions.
Changes:
- Reintroduce a `half_tile` template parameter in `generate_reduce_scaler` to adjust the zero-fill size and face iteration for half tiles.
- Update the SDPA decode writer kernel to select the correct `generate_reduce_scaler` specialization based on the scaler CB tile size.
- Fix the SDPA decode program factory idle-core reader runtime args count (14 → 15), aligning with the reader kernel's expected argument count; remove a now-unnecessary test skip for Blackhole + watcher.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp` | Fixes the idle-core reader runtime-arg vector length to match the reader kernel's 15 fixed args. |
| `ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp` | Uses a tile-size-derived compile-time `is_half_tile` to call the correct `generate_reduce_scaler` specialization. |
| `ttnn/cpp/ttnn/kernel/dataflow/generate_reduce_scaler.hpp` | Restores half-tile-aware behavior (1024 B / 2 faces vs 2048 B / 4 faces) to prevent L1 overwrite. |
| `models/demos/deepseek_v3_b1/tests/unit_tests/test_flash_mla.py` | Removes the Blackhole + watcher skip tied to the fixed SDPA decode watcher issue. |
@pavlejosipovic we're removing the scaler tile generation for blitz sdpa, so we'd rather this change not go in (the reduce scaler change and the blitz file change). The change to regular sdpa decode is fine to go in. |
What is your proposal? This is causing a bug. |
tt-aho left a comment:
Synced offline. Confusion was that blitz file had the skip but fix is for non-blitz.
Force-pushed 4d91a27 to 0a22f23.
```python
"head.kpt.1": {"input_shapes": [[1, 14, 14, 64]], "reshape_output": True},
"head.kpt.2": {"input_shapes": [[1, 7, 7, 64]], "reshape_output": True},
}
modules = register_module_replacement_dict(model, nn_to_nn, model_config=model_config)
return {}
```
```python
def forward(self, input_tensor: ttnn.Tensor) -> ttnn.Tensor:
    """Forward pass through Upsample layer."""
    batch_size, input_height, input_width, channels = input_tensor.shape
    output_height = input_height * self.scale_factor
    output_width = input_width * self.scale_factor
```
```python
if len(args) == 6:
    assert isinstance(args[3], (type(None), ttnn.Tensor)), "attn_mask must be None or a TTNN tensor."
    attn_mask = args[3]
    dropout_p = args[4]
elif len(args) == 5:
    if isinstance(args[3], (float, int)):
        attn_mask = None
        dropout_p = args[3]
    else:
        assert isinstance(args[3], (type(None), ttnn.Tensor)), "attn_mask must be None or a TTNN tensor."
        attn_mask = args[3]
        dropout_p = 0.0
elif len(args) == 4:
    if isinstance(args[3], (bool, int)):
        attn_mask = None
        dropout_p = 0.0
    else:
        assert isinstance(args[3], (type(None), ttnn.Tensor)), "attn_mask must be None or a TTNN tensor."
        attn_mask = args[3]
        dropout_p = 0.0
```
Force-pushed e97d940 to fc3efec.
`generate_reduce_scaler` hardcoded 2048 bytes and 4 faces, assuming full 32x32 bf16 tiles. When circular buffers use half tiles (1024 B, 2 faces), this overwrites adjacent L1 memory, causing watcher-detected corruption.

Restore the `half_tile` template parameter (previously removed in cleanup) so the zero-fill size and face iteration adapt to the actual tile dimensions. Also fix the idle-core runtime args count mismatch in sdpa_decode_program_factory.

Fixes: #37631
Fixes: #29225

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The watcher skip for issue #37631 was prematurely removed. Restore it until the underlying issue is fully resolved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed fc3efec to cd160f8.
Closing due to bad rebase that pulled in unrelated commits and spammed notifications. Re-opening as a clean PR. |
Commit 89175b6 introduced the SDPA op to the all post commit workflow. This workflow is covered in CI by TT_METAL_WATCHER checks, so those jobs will now start to fail.
`generate_reduce_scaler` hardcoded 2048 bytes and 4 faces, assuming full 32x32 bf16 tiles. When circular buffers use half tiles (1024 B, 2 faces), this overwrites adjacent L1 memory, causing watcher-detected corruption.

Restore the `half_tile` template parameter (previously removed in cleanup) so the zero-fill size and face iteration adapt to the actual tile dimensions. Also fix the idle-core runtime args count mismatch in sdpa_decode_program_factory.

Fixes: #37631
Fixes: #29225