Commit be36d59

[tt-train] Add causal mask support for SDPA backward (#36661)
### Ticket
#34531

### Problem description
The SDPA backward pass (sdpa_bw_q and sdpa_bw_kv) required materializing a full S × S triangular attention mask for causal attention, which involved:
- Precomputing the mask tensor on the host
- Transferring the mask to DRAM
- Reading unique mask tiles from DRAM for each Q/K position during backward computation

This introduced an O(S²) memory dependency and unnecessary DRAM traffic, becoming a bottleneck for long-sequence training. The forward pass already supports on-the-fly causal mask generation.

### What's changed
This PR extends on-the-fly causal mask generation to the SDPA backward pass:
- The writer kernel generates a triangular mask tile once during initialization using generate_causal_mask_tile()
- The compute kernel reuses the same mask tile for all diagonal positions
- Masked-out regions are skipped entirely:
  - sdpa_bw_q: K/V rows beyond the causal boundary are not processed
  - sdpa_bw_kv: Q rows before the causal boundary are not processed
- Updated apply_mask_on_reg() with a template parameter to control CB pop behavior, enabling mask tile reuse for causal mode

### Impact
- Removes the O(S²) memory dependency for causal masking in the SDPA backward pass
- Reduces DRAM bandwidth usage by eliminating mask tensor reads
- Avoids unnecessary computation for masked regions
- Enables end-to-end causal attention training without mask tensors

Skipping computation in masked regions provides a significant performance improvement (see the tile-count sketch below).

Baseline (composite SDPA, memory_efficient):
- Mean step time: 2640.16 ms (without fused SDPA we can't fit TinyLlama (1B) on a single Wormhole card)

Fused SDPA kernel:
- Arbitrary mask, memory_efficient: 3082.58 ms
- Causal mask, memory_efficient: 2506.47 ms
- Causal mask, default: 1983.56 ms

### Checklist
- [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=vmelnykov%2Fcasual_mask_support_sdpa_bw)](https://github.com/tenstorrent/tt-metal/actions/runs/21550483959)
- [x] New/Existing tests provide coverage for changes
1 parent: 6d954eb

18 files changed (+354, −66 lines)
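The boundary arithmetic behind the "avoids unnecessary computation" claim is simple enough to check on the host. Below is a minimal, self-contained C++ sketch (not part of the diff; `Ht` stands for the sequence length in tiles, and the counts are illustrative) showing why skipping masked regions roughly halves the tile work:

```cpp
#include <cstdint>
#include <iostream>

int main() {
    const uint32_t Ht = 8;  // sequence length in tiles (S / tile height)

    uint32_t dense_tiles = 0, causal_tiles = 0;
    for (uint32_t t = 0; t < Ht; ++t) {
        dense_tiles += Ht;  // dense path: every Q row tile visits every K/V row tile
        // sdpa_bw_q:  Q row tile t reads K/V tiles 0..t      -> t + 1 tiles
        // sdpa_bw_kv: K/V row tile t reads Q tiles t..Ht-1   -> Ht - t tiles
        causal_tiles += t + 1;  // either way, only the causal triangle is touched
    }
    std::cout << "dense: " << dense_tiles << " tile-pairs, causal: " << causal_tiles << '\n';
    // Prints "dense: 64 tile-pairs, causal: 36" - the ratio tends to 1/2 as Ht grows.
}
```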

tt-train/sources/ttml/metal/ops/sdpa_bw/device/kernels/compute/sdpa_bw_compute_utils.hpp

Lines changed: 3 additions & 2 deletions
```diff
@@ -31,6 +31,9 @@ constexpr uint32_t onetile = 1U;
 // masked ones.
 // This way, after applying softmax, masked positions will effectively become zero,
 // and only the unmasked positions will retain meaningful attention weights
+//
+// Note: Does NOT pop the mask tile - caller must pop explicitly when done with the tile.
+// This allows reusing the same mask tile for causal masks.
 void apply_mask_on_reg(
     const uint32_t register_idx,
     const uint32_t cb_attn_mask,
@@ -62,8 +65,6 @@ void apply_mask_on_reg(
     // unmasked positions remain unchanged
     add_binary_tile_init();
     add_binary_tile(register_idx, mask_register, register_idx);
-
-    cb_pop_front(cb_attn_mask, onetile);
 }

 // Recomputes attention weights from pre-softmax scores using stored statistics.
```
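The removed `cb_pop_front` moves pop responsibility to the callers. As a host-side analogy (stub types standing in for the device circular-buffer API — illustration only, not device code), the new ownership contract looks like this:

```cpp
#include <deque>
#include <iostream>

// Stand-in for the device circular buffer, for illustration only.
using CircularBuffer = std::deque<int>;

void apply_mask_on_reg(int& dest_reg, const CircularBuffer& cb_attn_mask) {
    dest_reg += cb_attn_mask.front();  // reads the front tile but no longer pops it
}

int main() {
    CircularBuffer cb_attn_mask = {1};  // one triangular mask tile
    int dest_reg = 0;
    // Causal mode: the same tile is applied on every diagonal position...
    for (int diagonal = 0; diagonal < 3; ++diagonal) {
        apply_mask_on_reg(dest_reg, cb_attn_mask);  // no pop between uses
    }
    cb_attn_mask.pop_front();  // ...and the caller pops exactly once when done
    std::cout << dest_reg << '\n';  // 3: the single tile was reused three times
}
```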

tt-train/sources/ttml/metal/ops/sdpa_bw/device/kernels/compute/sdpa_bw_kv_compute_kernel.cpp

Lines changed: 55 additions & 8 deletions
```diff
@@ -78,7 +78,9 @@ constexpr uint32_t cb_attn_output = tt::CBIndex::c_1;  // Attention outpu
 constexpr uint32_t cb_query = tt::CBIndex::c_2;      // Original query
 constexpr uint32_t cb_key = tt::CBIndex::c_3;        // Original key
 constexpr uint32_t cb_value = tt::CBIndex::c_4;      // Original value
+#if defined(CAUSAL_MASK) || defined(USE_ATTN_MASK)
 constexpr uint32_t cb_attn_mask = tt::CBIndex::c_5;  // Original mask
+#endif
 constexpr uint32_t cb_intermediates = tt::CBIndex::c_6;      // Forward pass intermediates
 constexpr uint32_t cb_mat_mul_reduction = tt::CBIndex::c_7;  // Temporary computations
 constexpr uint32_t cb_grad_value_accum = tt::CBIndex::c_8;   // L1 accumulator for grad_value
@@ -98,21 +100,45 @@ const uint32_t tiles_per_row = qWt;  // assuming qWt == kWt == vWt
 const uint32_t num_of_interm_tiles = 2U;  // number of tiles in intermediates buffer per head

 void MAIN {
+    // Runtime args - needed for causal mask to know global position within sequence
+    const uint32_t start_row = get_arg_val<uint32_t>(0);
+
     init_sfpu(cb_query, cb_key);
     binary_op_init_common(cb_grad_output, cb_query, cb_key);

     cb_wait_front(cb_mat_mul_reduction, onetile);

     mm_init(cb_query, cb_key, cb_attention_weights);

+#ifdef CAUSAL_MASK
+    // Wait for causal mask tile ONCE - it's generated by writer and will be reused for every diagonal
+    cb_wait_front(cb_attn_mask, onetile);
+#endif
+
     for (uint32_t row = 0; row < num_rows_per_core; ++row) {
         cb_wait_front(cb_key, tiles_per_row);
         cb_wait_front(cb_value, tiles_per_row);

+#ifdef CAUSAL_MASK
+        // Calculate global position for this K/V row
+        const uint32_t global_row_idx = start_row + row;
+        const uint32_t k_row_tile = global_row_idx % Ht;  // position within sequence (0 to Ht-1)
+
+        // For causal mask: only process Q rows from k_row_tile to Ht-1
+        // Q rows 0 to k_row_tile-1 have zero attention weights (can't attend to future keys)
+        const uint32_t q_start_tile = k_row_tile;
+        const uint32_t num_q_tiles_to_process = Ht - k_row_tile;
+#else
+        const uint32_t q_start_tile = 0;
+        const uint32_t num_q_tiles_to_process = Ht;
+#endif
+
         for (uint32_t head_idx = 0; head_idx < heads_per_group; ++head_idx) {
             const uint32_t matmul_accum_reg = 0;

-            for (uint32_t h = 0; h < Ht; ++h) {
+            for (uint32_t q_idx = 0; q_idx < num_q_tiles_to_process; ++q_idx) {
+                const uint32_t h = q_start_tile + q_idx;  // actual Q row tile index
+
                 // Wait for Q, dO, O, mask and intermediates for this K/V row
                 cb_wait_front(cb_query, tiles_per_row);
                 cb_wait_front(cb_grad_output, tiles_per_row);
@@ -134,12 +160,27 @@ void MAIN {
                         /* dst_reg_idx*/ matmul_accum_reg);  // accumulate in dest_reg 0
                 }

-                /*
-                 * apply attention mask on dest_reg.
-                 * function assumes that dest_reg is in acquired state via *acquire_dst* call
-                 * function transforms mask from 1/0 to 0/-inf and applies it on dest_reg
-                 */
+#ifdef CAUSAL_MASK
+                // For causal mask: apply triangular mask on diagonal tile (h == k_row_tile)
+                // Writer generates causal mask tile once, reused for every diagonal
+                if (h == k_row_tile) {
+                    apply_mask_on_reg(matmul_accum_reg, cb_attn_mask, scaler_bits, minus_one_bits, custom_inf_bits);
+                    // Don't pop - causal mask tile is reused for all diagonal positions
+                } else {
+                    // Off-diagonal (h > k_row_tile): just scale, no mask needed
+                    binop_with_scalar_tile_init();
+                    mul_unary_tile(matmul_accum_reg, scaler_bits);
+                }
+#elif defined(USE_ATTN_MASK)
+                // Apply attention mask from DRAM
+                // Transforms mask from 1/0 to 0/-inf and applies it on dest_reg
                 apply_mask_on_reg(matmul_accum_reg, cb_attn_mask, scaler_bits, minus_one_bits, custom_inf_bits);
+                cb_pop_front(cb_attn_mask, onetile);  // Pop each unique mask tile after use
+#else
+                // No mask: just scale
+                binop_with_scalar_tile_init();
+                mul_unary_tile(matmul_accum_reg, scaler_bits);
+#endif
                 tile_regs_commit();
                 tile_regs_wait();
                 pack_reconfig_data_format(cb_attention_weights);
@@ -151,14 +192,15 @@ void MAIN {
                 apply_statistics_inplace(cb_attention_weights, cb_intermediates, num_of_interm_tiles);

                 // Step 3: Accumulate grad_V = Attention^T @ grad_output
+                // For causal mask: first iteration is q_idx=0 (h=k_row_tile), head_idx=0
                 update_grad_value(
                     cb_attention_weights,
                     cb_transpose_wh,
                     cb_grad_output,
                     cb_grad_value_accum,
                     tiles_per_row,
                     block_size,
-                    /* do_accumulate */ h > 0 || head_idx > 0);
+                    /* do_accumulate */ q_idx > 0 || head_idx > 0);
                 cb_wait_front(cb_grad_value_accum, tiles_per_row);

                 // Step 4: calculate u_scalar_row = sum(dO * O) per row
@@ -185,7 +227,7 @@ void MAIN {
                     cb_grad_key_accum,
                     tiles_per_row,
                     block_size,
-                    /* do_accumulate */ h > 0 || head_idx > 0);
+                    /* do_accumulate */ q_idx > 0 || head_idx > 0);
                 cb_wait_front(cb_grad_key_accum, tiles_per_row);

                 // Pop intermediate results used for computing dK and dV
@@ -208,6 +250,11 @@ void MAIN {
         cb_pop_front(cb_key, tiles_per_row);
         cb_pop_front(cb_value, tiles_per_row);
     }
+
+#ifdef CAUSAL_MASK
+    // Pop the causal mask tile after all rows are processed (was reused for every diagonal)
+    cb_pop_front(cb_attn_mask, onetile);
+#endif
 }

 }  // namespace NAMESPACE
```
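A self-contained sketch of the bounds introduced above (a hypothetical helper mirroring the kernel's `q_start_tile`/`num_q_tiles_to_process` arithmetic, not code from the diff):

```cpp
#include <cassert>
#include <cstdint>

// For a K/V row at tile position k_row_tile, only Q tiles [k_row_tile, Ht)
// lie at or below the causal diagonal, so only they are processed.
struct QRange {
    uint32_t start;
    uint32_t count;
};

QRange causal_q_range(uint32_t k_row_tile, uint32_t Ht) {
    return {k_row_tile, Ht - k_row_tile};
}

int main() {
    const uint32_t Ht = 4;
    assert(causal_q_range(0, Ht).count == 4);  // first K/V row: every Q row attends to it
    assert(causal_q_range(3, Ht).count == 1);  // last K/V row: only the last Q row does
    // The diagonal tile (h == k_row_tile) is always the first iteration (q_idx == 0),
    // which is why do_accumulate changes from `h > 0` to `q_idx > 0` in this kernel.
    return 0;
}
```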

tt-train/sources/ttml/metal/ops/sdpa_bw/device/kernels/compute/sdpa_bw_q_compute_kernel.cpp

Lines changed: 51 additions & 6 deletions
```diff
@@ -70,7 +70,9 @@ constexpr uint32_t cb_attn_output = tt::CBIndex::c_1;  // Attention outpu
 constexpr uint32_t cb_query = tt::CBIndex::c_2;      // Original query
 constexpr uint32_t cb_key = tt::CBIndex::c_3;        // Original key
 constexpr uint32_t cb_value = tt::CBIndex::c_4;      // Original value
+#if defined(CAUSAL_MASK) || defined(USE_ATTN_MASK)
 constexpr uint32_t cb_attn_mask = tt::CBIndex::c_5;  // Original mask
+#endif
 constexpr uint32_t cb_intermediates = tt::CBIndex::c_6;      // Forward pass intermediates
 constexpr uint32_t cb_mat_mul_reduction = tt::CBIndex::c_7;  // Temporary computations
 constexpr uint32_t cb_grad_query_accum = tt::CBIndex::c_8;   // L1 accumulator for grad_query
@@ -85,12 +87,20 @@ const uint32_t tiles_per_row = qWt;  // number of tiles per row (qWt == kWt
 const uint32_t num_of_interm_tiles = 2U;  // number of tiles in intermediates buffer per head

 void MAIN {
+    // Runtime args - needed for causal mask to know global position within sequence
+    const uint32_t start_row = get_arg_val<uint32_t>(0);
+
     init_sfpu(cb_query, cb_key);
     binary_op_init_common(cb_grad_output, cb_query, cb_key);

     cb_wait_front(cb_mat_mul_reduction, onetile);
     mm_init(cb_query, cb_key, cb_attention_weights);

+#ifdef CAUSAL_MASK
+    // Wait for causal mask tile ONCE - it's generated by writer and will be reused for every diagonal
+    cb_wait_front(cb_attn_mask, onetile);
+#endif
+
     for (uint32_t row = 0; row < num_rows_per_core; ++row) {
         cb_wait_front(cb_attn_output, tiles_per_row);
         cb_wait_front(cb_grad_output, tiles_per_row);
@@ -101,8 +111,20 @@ void MAIN {
         compute_u_scalar_row(
             cb_grad_output, cb_attn_output, cb_u_scalar_row, cb_mat_mul_reduction, tiles_per_row, scaler_bits);

+#ifdef CAUSAL_MASK
+        // Calculate global position within sequence for causal mask
+        const uint32_t global_row_idx = start_row + row;
+        const uint32_t q_row_tile = global_row_idx % Ht;  // position within sequence (0 to Ht-1)
+
+        // For causal mask: only process K/V tiles up to and including the diagonal
+        // q_row_tile determines how many K/V chunks we need (0..q_row_tile inclusive)
+        const uint32_t num_kv_tiles_to_process = q_row_tile + 1;
+#else
+        const uint32_t num_kv_tiles_to_process = Ht;
+#endif
+
         const uint32_t matmul_accum_reg = 0;
-        for (uint32_t h = 0; h < Ht; ++h) {
+        for (uint32_t h = 0; h < num_kv_tiles_to_process; ++h) {
             cb_wait_front(cb_key, tiles_per_row);
             cb_wait_front(cb_value, tiles_per_row);

@@ -120,12 +142,27 @@ void MAIN {
                     /* dst_reg_idx*/ matmul_accum_reg);  // accumulate in dest_reg 0
             }

-            /*
-             * apply attention mask on dest_reg.
-             * function assumes that dest_reg is in acquired state via *acquire_dst* call
-             * function transforms mask from 1/0 to 0/-inf and applies it on dest_reg
-             */
+#ifdef CAUSAL_MASK
+            // For causal mask: apply triangular mask on diagonal tile (h == q_row_tile)
+            // Writer generates causal mask tile once, reused for every diagonal
+            if (h == q_row_tile) {
+                apply_mask_on_reg(matmul_accum_reg, cb_attn_mask, scaler_bits, minus_one_bits, custom_inf_bits);
+                // Don't pop - causal mask tile is reused for all diagonal positions
+            } else {
+                // Off-diagonal: just scale
+                binop_with_scalar_tile_init();
+                mul_unary_tile(matmul_accum_reg, scaler_bits);
+            }
+#elif defined(USE_ATTN_MASK)
+            // Apply attention mask from DRAM
+            // Transforms mask from 1/0 to 0/-inf and applies it on dest_reg
             apply_mask_on_reg(matmul_accum_reg, cb_attn_mask, scaler_bits, minus_one_bits, custom_inf_bits);
+            cb_pop_front(cb_attn_mask, onetile);  // Pop each unique mask tile after use
+#else
+            // No mask: just scale
+            binop_with_scalar_tile_init();
+            mul_unary_tile(matmul_accum_reg, scaler_bits);
+#endif
             tile_regs_commit();
             tile_regs_wait();
             pack_reconfig_data_format(cb_attention_weights);
@@ -161,6 +198,9 @@ void MAIN {
             cb_pop_front(cb_value, tiles_per_row);
             cb_pop_front(cb_attention_weights, onetile);
             cb_pop_front(cb_grad_attn_weights, onetile);
+            // Note: Mask pops are handled explicitly after apply_mask_on_reg:
+            // - USE_ATTN_MASK: pops each unique mask tile after use
+            // - CAUSAL_MASK: doesn't pop (reuses same tile for all diagonals)
             // Note: cb_grad_scores is popped inside update_grad_query
         }

@@ -173,6 +213,11 @@ void MAIN {
         cb_pop_front(cb_attn_output, tiles_per_row);
         cb_pop_front(cb_grad_output, tiles_per_row);
     }
+
+#ifdef CAUSAL_MASK
+    // Pop the causal mask tile after all rows are processed (was reused for every diagonal)
+    cb_pop_front(cb_attn_mask, onetile);
+#endif
 }

 }  // namespace NAMESPACE
```
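The mirror-image bound for the Q kernel, again as a hypothetical host-side check rather than diff code:

```cpp
#include <cassert>
#include <cstdint>

// A Q row at tile position q_row_tile attends to K/V tiles [0, q_row_tile];
// the diagonal tile is included because it still needs the triangular mask.
uint32_t causal_kv_count(uint32_t q_row_tile) {
    return q_row_tile + 1;
}

int main() {
    const uint32_t Ht = 4;
    assert(causal_kv_count(0) == 1);        // first Q row sees one K/V tile (itself)
    assert(causal_kv_count(Ht - 1) == Ht);  // last Q row sees the whole sequence
    // Unlike sdpa_bw_kv, the loop here still starts at h == 0, so no index
    // remapping is needed; the diagonal (h == q_row_tile) is simply the last iteration.
    return 0;
}
```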

tt-train/sources/ttml/metal/ops/sdpa_bw/device/kernels/dataflow/sdpa_bw_kv_reader_kernel.cpp

Lines changed: 27 additions & 3 deletions
```diff
@@ -27,7 +27,9 @@ void kernel_main() {
     constexpr uint32_t cb_query = tt::CBIndex::c_2;
     constexpr uint32_t cb_key = tt::CBIndex::c_3;
     constexpr uint32_t cb_value = tt::CBIndex::c_4;
+#ifdef USE_ATTN_MASK
     constexpr uint32_t cb_attn_mask = tt::CBIndex::c_5;
+#endif
     constexpr uint32_t cb_intermediates = tt::CBIndex::c_6;
     constexpr uint32_t cb_matmul_reduce = tt::CBIndex::c_7;

@@ -58,7 +60,9 @@ void kernel_main() {
     const auto query_address_generator = TensorAccessor(query_args, query_addr, tile_bytes);
     const auto key_address_generator = TensorAccessor(key_args, key_addr, tile_bytes);
     const auto value_address_generator = TensorAccessor(value_args, value_addr, tile_bytes);
+#ifdef USE_ATTN_MASK
+    const auto mask_address_generator = TensorAccessor(mask_args, mask_addr, tile_bytes);
+#endif
     const auto intermediates_address_generator = TensorAccessor(intermediates_args, intermediates_addr, tile_bytes);

     generate_matmul_row_reduce_tile(cb_matmul_reduce);  // generate tile for matmul row reduce (auto-detects data type)
@@ -81,23 +85,43 @@ void kernel_main() {
         const uint32_t first_q_head_idx = group_idx * heads_per_group;
         const uint32_t q_offset = (batch_idx * q_heads + first_q_head_idx) * Ht * qWt;

+        // k_row_tile = position within sequence (0 to Ht-1)
+        const uint32_t k_row_tile = global_row_idx % Ht;
+
+#ifdef CAUSAL_MASK
+        // For causal mask: only read Q rows from k_row_tile to Ht-1
+        // Q rows 0 to k_row_tile-1 have zero attention weights (can't attend to future keys)
+        const uint32_t q_start_tile = k_row_tile;
+        const uint32_t num_q_tiles_to_read = Ht - k_row_tile;
+#else
+        const uint32_t q_start_tile = 0;
+        const uint32_t num_q_tiles_to_read = Ht;
+#endif
+
+#ifdef USE_ATTN_MASK
         // Mask is (1, 1, S, S) - same mask for all batches/heads, indexed by sequence position only
-        // For KV kernel, we read column (global_row_idx % Ht) from each row h of the mask
-        const uint32_t mask_offset = (global_row_idx % Ht);
+        // For KV kernel, we read column k_row_tile from each row h of the mask
+        const uint32_t mask_offset = k_row_tile;
+#endif

         // add change here: multiply by num_of_interm_tiles because we need to read 2 tiles per head row
         uint32_t intermediates_offset = (batch_idx * q_heads + first_q_head_idx) * Ht * num_of_interm_tiles;

         // TODO: add calculation for dO, O indexes because in forward pass they are stored with shape (B, 1, S,
         // qNH*qEmbd)
         for (uint32_t q_head_idx = 0; q_head_idx < heads_per_group; ++q_head_idx) {
-            for (uint32_t h = 0; h < Ht; ++h) {
+            for (uint32_t q_idx = 0; q_idx < num_q_tiles_to_read; ++q_idx) {
+                const uint32_t h = q_start_tile + q_idx;  // actual Q row tile index
+
                 const uint32_t q_start_idx = q_offset + (q_head_idx * Ht + h) * qWt;
                 read_tiles_by_row(cb_query, query_address_generator, q_start_idx, qWt, tile_bytes, qWt);

+#ifdef USE_ATTN_MASK
                 // read one tile of attn_mask for current row of K and V
                 // row of K define the column in (QK^T) matrix, so it define the column of attn_mask
                 read_one_tile(cb_attn_mask, mask_address_generator, mask_offset + h * Ht);
+#endif
+                // Note: For CAUSAL_MASK, the mask tile is generated once by writer and reused by compute

                 // Read intermediates - one tile per row (contains 1/sum_exp values from forward pass)
                 // TODO[improve](vmelnykov): Now we share two intermediates values per head row: row-wise max value and
```
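The reader's mask indexing can be sanity-checked in isolation. A sketch under the assumption stated in the comments above — the (1, 1, S, S) mask is a row-major Ht × Ht grid of tiles:

```cpp
#include <cassert>
#include <cstdint>

// Q row tile h pairs with K/V column k_row_tile, so the reader fetches mask
// tile (row h, column k_row_tile) from the row-major tile grid:
uint32_t mask_tile_index(uint32_t h, uint32_t k_row_tile, uint32_t Ht) {
    return h * Ht + k_row_tile;  // == mask_offset + h * Ht, with mask_offset = k_row_tile
}

int main() {
    const uint32_t Ht = 4;
    assert(mask_tile_index(0, 2, Ht) == 2);   // first Q row, third mask column
    assert(mask_tile_index(3, 2, Ht) == 14);  // last Q row, same column
    return 0;
}
```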

tt-train/sources/ttml/metal/ops/sdpa_bw/device/kernels/dataflow/sdpa_bw_kv_writer_kernel.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -24,6 +24,12 @@ void kernel_main() {
     constexpr uint32_t q_heads = get_compile_time_arg_val(2);          // number of query heads
     constexpr uint32_t heads_per_group = get_compile_time_arg_val(3);  // heads per group

+#ifdef CAUSAL_MASK
+    // Generate causal mask tile ONCE - will be reused for every diagonal
+    constexpr uint32_t cb_attn_mask = tt::CBIndex::c_5;
+    generate_causal_mask_tile(cb_attn_mask);
+#endif
+
     const uint32_t tile_bytes = get_tile_size(cb_grad_key);

     // TensorAccessor definitions with chained offsets
```
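For intuition, the tile that `generate_causal_mask_tile()` places in `cb_attn_mask` is a lower-triangular 1/0 pattern. A host-side sketch of the equivalent values (assuming a plain 32 × 32 row-major layout and ignoring the hardware tile's internal face ordering):

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t TILE_DIM = 32;  // assumed tile height/width

// Ones at and below the diagonal (a position may attend to itself and the past),
// zeros above it; apply_mask_on_reg later maps this 1/0 pattern to 0/-inf.
std::array<float, TILE_DIM * TILE_DIM> make_causal_tile() {
    std::array<float, TILE_DIM * TILE_DIM> tile{};  // zero-initialized
    for (std::size_t r = 0; r < TILE_DIM; ++r) {
        for (std::size_t c = 0; c <= r; ++c) {
            tile[r * TILE_DIM + c] = 1.0f;
        }
    }
    return tile;
}

int main() {
    const auto tile = make_causal_tile();
    return (tile[0] == 1.0f && tile[1] == 0.0f) ? 0 : 1;  // [0][0] kept, [0][1] masked
}
```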
