[tt-train] Fused SwiGLU forward + single-sender multicast, packer L1 accumulation, batched tile multicast #34172
Open
mdragulaTT wants to merge 32 commits into main from
Conversation
Pull request overview
This PR optimizes the SwiGLU forward operation by implementing multicast-based weight distribution to reduce DRAM bandwidth consumption. Instead of each core independently reading weight matrices (W1, W2, W3) from DRAM, leftmost column cores (x==0) read weights and multicast them to other cores in their row via NoC. The PR also introduces uniform workload padding to handle imbalanced batch sizes that don't divide evenly across cores.
Changes:
- Split reader kernel into sender (reader_swiglu_fw_sender.cpp) and receiver (reader_swiglu_fw_receiver.cpp) variants
- Implemented multicast synchronization using shared semaphores across W1/W2/W3 (reused since they execute sequentially); see the sketch after this list
- Added a uniform workload padding mechanism where all cores in a grid row loop for max_rows_for_sync iterations to maintain multicast synchronization
- Added reusable multicast helper functions in dataflow_utils.hpp
- Updated the test suite with a SwiGLU_RepeatedRuns_NoHang test for imbalanced workloads and corrected NanoLlama test naming
- Added const-correctness improvements to compute kernels
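For readers unfamiliar with the NoC multicast pattern, the sketch below shows the sender-side handshake the list above refers to. It is illustrative, not the PR's reader_swiglu_fw_sender.cpp: the CB index, argument order, and the elided DRAM read loop are assumptions; only the semaphore and multicast calls are standard tt-metal dataflow APIs.

```cpp
// Illustrative sender kernel for one multicast block of weight tiles (not the PR's exact code).
#include "dataflow_api.h"

void kernel_main() {
    constexpr uint32_t cb_w1 = 1;  // placeholder CB index for W1 tiles
    uint32_t arg = 0;
    const uint32_t num_receivers  = get_arg_val<uint32_t>(arg++);
    const uint32_t block_tiles    = get_arg_val<uint32_t>(arg++);
    const uint32_t rx_noc_x_start = get_arg_val<uint32_t>(arg++);  // physical NoC coords of the
    const uint32_t rx_noc_y_start = get_arg_val<uint32_t>(arg++);  // receiver range in this row,
    const uint32_t rx_noc_x_end   = get_arg_val<uint32_t>(arg++);  // provided by the host
    const uint32_t rx_noc_y_end   = get_arg_val<uint32_t>(arg++);
    const uint32_t ready_sem_id   = get_arg_val<uint32_t>(arg++);
    const uint32_t valid_sem_id   = get_arg_val<uint32_t>(arg++);

    volatile tt_l1_ptr uint32_t* receivers_ready =
        reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_semaphore(ready_sem_id));
    volatile tt_l1_ptr uint32_t* data_valid =
        reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_semaphore(valid_sem_id));

    // 1. Read one block of weight tiles from DRAM into the local CB (read loop elided).
    cb_reserve_back(cb_w1, block_tiles);
    const uint32_t l1_addr = get_write_ptr(cb_w1);
    // ... noc_async_read_tile(...) per tile into l1_addr ...
    noc_async_read_barrier();

    // 2. Wait until every receiver has reserved CB space and bumped our "ready" semaphore.
    noc_semaphore_wait(receivers_ready, num_receivers);
    noc_semaphore_set(receivers_ready, 0);

    // 3. Multicast the block to the receivers. This assumes their CB write pointer sits at the
    //    same L1 address (CBs are configured identically and kept in lockstep).
    const uint32_t block_bytes = block_tiles * get_tile_size(cb_w1);
    const uint64_t mcast_data_addr =
        get_noc_multicast_addr(rx_noc_x_start, rx_noc_y_start, rx_noc_x_end, rx_noc_y_end, l1_addr);
    noc_async_write_multicast(l1_addr, mcast_data_addr, block_bytes, num_receivers);

    // 4. Multicast the "data valid" flag so receivers can push the tiles to their compute kernel.
    noc_semaphore_set(data_valid, 1);
    const uint64_t mcast_sem_addr = get_noc_multicast_addr(
        rx_noc_x_start, rx_noc_y_start, rx_noc_x_end, rx_noc_y_end, get_semaphore(valid_sem_id));
    noc_semaphore_set_multicast(get_semaphore(valid_sem_id), mcast_sem_addr, num_receivers);
    noc_async_write_barrier();

    cb_push_back(cb_w1, block_tiles);
}
```

Receivers mirror this handshake: reserve CB space, increment the sender's ready semaphore, wait for the valid semaphore, then push the received tiles to compute.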
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| swiglu_op_test.cpp | Added test for repeated runs with imbalanced batch sizes; updated test naming from NanoGPT to NanoLlama |
| swiglu_fw_program_factory.hpp | Extended shared variables to track sender/receiver kernels, core counts, and multicast flag |
| swiglu_fw_program_factory.cpp | Implemented multicast topology setup, kernel splitting, uniform padding logic, and runtime argument assignment |
| reader_swiglu_fw_sender.cpp | New sender kernel that reads from DRAM and multicasts W1/W2/W3 to receivers in same row |
| reader_swiglu_fw_receiver.cpp | New receiver kernel that receives W1/W2/W3 via multicast and only reads X from DRAM |
| reader_swiglu_fw_interleaved_start_id.cpp | Removed (replaced by sender/receiver split) |
| writer_swiglu_fw_interleaved_start_id.cpp | Minor const-correctness improvements to runtime arguments |
| swiglu_fw_kernel_m_fits_l1.cpp | Added const qualifiers to function parameters and local register indices |
| swiglu_fw_kernel.cpp | Added const qualifiers to function parameters and local register indices |
| dataflow_utils.hpp | Added multicast synchronization primitives and sender/receiver helper functions; improved const-correctness throughout |
Optimizations:
- Remove the #ifdef W1W3_BATCHED conditional compilation for cleaner code
- Always batch W1/W3 multicast as block_size×block_size tiles (16 tiles)
- Add valid_tiles_per_row/valid_num_rows params to handle unaligned dimensions
- Padding clamps indices to the last valid row/tile when dimensions are not divisible
- Update compute kernel to consume the full tile batch at once (cb_wait_front/pop for all 16 tiles)
- CB sizing always uses 2*block_size*block_size for W1/W3 (32 tiles double-buffered)

Performance results (NanoLlama3: 64×256×384, hidden=1024):
- Fused SwiGLU: 2.14 ms
- Composite baseline: 4.28 ms
- Speedup: 50%

Profiling analysis (per-row breakdown):
- Reader Phase A (W1/W3): 94 μs - batched row reads
- Reader Phase C (W2): 118 μs - column-by-column reads
- Compute Phase A: 96 μs - matmul accumulation
- Compute Phase B: 45 μs - SiLU activation + multiply
- Compute Phase C: 68 μs - final matmul

Key insight: Reader (212 μs) ≈ Compute (209 μs) indicates the system is well balanced. Further gains require W2 transposition or different blocking.

Added PERFORMANCE_ANALYSIS.md documenting optimization opportunities.
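The clamping mentioned above amounts to a couple of std::min calls. The helper below is hypothetical (clamped_tile_id and the stride parameter are not names from the PR) and is shown only to make the padding behaviour concrete; downstream masking handles any duplicated values.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper: when a block row/column falls past the real matrix edge,
// re-use the last valid row/tile instead of reading out of bounds.
uint32_t clamped_tile_id(
    uint32_t block_row, uint32_t block_col,
    uint32_t valid_num_rows, uint32_t valid_tiles_per_row,
    uint32_t tiles_per_row_stride) {
    const uint32_t actual_row  = std::min(block_row, valid_num_rows - 1);
    const uint32_t actual_tile = std::min(block_col, valid_tiles_per_row - 1);
    return actual_row * tiles_per_row_stride + actual_tile;  // linear tile index in the DRAM tensor
}
```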
Batch W2 multicast optimization:
- Add mcast_sender_read_batched_cols_and_send helper for batched column reads
- Update sender/receiver/compute kernels for batched W2 (block_size × block_size tiles)
- Increase the W2 CB size to match the W1/W3 batched sizing
- Performance: 5.68 ms -> 5.14 ms (9.5% improvement)
- Sync reduction: Phase C from 96 syncs/row to 24 syncs/row (4x reduction)

Cleanup and simplification:
- Remove separate swiglu_fw_kernel_m_fits_l1.cpp, merge into the main compute kernel
- Remove the non-M-fits-L1 code path from sender/receiver kernels (unused in practice)
- Add TT_FATAL requiring the M row to fit in L1 (all practical TP configs satisfy this)
- Add weight shape validation with clear error messages for transposed layouts
- Add const qualifiers and comment improvements throughout
- Add shape validation unit tests

Co-authored-by: Cursor <cursoragent@cursor.com>
Optimize SwiGLU forward kernel with single-sender multicast:
- Core (0,0) reads weights from DRAM once and multicasts to ALL cores
- Use loopback multicast APIs so the sender is also a receiver
- Add uniform padding iterations for multicast synchronization
- Fix the compute kernel to drain CBs during padding rows (prevents deadlock)

Performance: Fused SwiGLU now matches composite (4.23 ms vs 4.29 ms baseline)

Changes:
- dataflow_utils.hpp: Add loopback multicast helper functions
- swiglu_fw_program_factory.cpp: Single-sender topology from (0,0)
- reader_swiglu_fw_sender.cpp: Use loopback multicast functions
- swiglu_fw_kernel.cpp: Loop for max_rows_for_sync, drain on padding
- swiglu_op_test.cpp: Add unbalanced-workload edge case test (57 rows)
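A simplified sketch of the "loop for max_rows_for_sync, drain on padding" behaviour in the compute kernel. The function name, CB id, and tile counts are placeholders; the real kernel runs three phases per row and drains every CB that its dataflow counterparts keep filling on padding rows. The key point: padding rows still consume the multicast data, so the sender/receiver handshake cannot deadlock.

```cpp
// Illustrative compute-kernel row loop (not the PR's exact code).
void process_rows(
    uint32_t max_rows_for_sync, uint32_t rows_assigned_to_this_core,
    uint32_t cb_weights, uint32_t weight_tiles_per_row) {
    for (uint32_t row = 0; row < max_rows_for_sync; ++row) {
        const bool is_padding_row = (row >= rows_assigned_to_this_core);
        // Weights are multicast to every core on every iteration, including padding
        // rows, so they must be consumed either way to keep CB counters in lockstep.
        cb_wait_front(cb_weights, weight_tiles_per_row);
        if (!is_padding_row) {
            // ... Phase A (X@W1, X@W3), Phase B (SiLU * XW3), Phase C (M@W2), write Y ...
        }
        cb_pop_front(cb_weights, weight_tiles_per_row);  // drain even when no work was done
    }
}
```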
This implements the 'True Flash' memory optimization for the SwiGLU forward pass:
- Loop inversion: k_block outer, p_block inner (vs original p_block outer)
- Computes M tiles on demand instead of materializing the full M row
- ~50% L1 memory reduction (560 KB -> 280 KB for NanoLlama3)

New files:
- swiglu_fw_true_flash_kernel.cpp: Compute kernel with inverted loop order
- reader_swiglu_fw_true_flash_sender.cpp: Sender dataflow kernel
- reader_swiglu_fw_true_flash_receiver.cpp: Receiver dataflow kernel
- writer_swiglu_fw_true_flash.cpp: Writer kernel for full Y rows
- swiglu_fw_true_flash_program_factory.hpp/cpp: Program factory

API changes:
- Added SwiGLUAlgorithm enum: ORIGINAL, TRUE_FLASH, AUTO
- swiglu_fw() accepts an optional algorithm parameter
- Tests now use TRUE_FLASH by default

All 14 SwiGLU tests pass with the True Flash algorithm.

Note: Performance optimization (Phase 2: block matmul, Phase 3: X caching) will be implemented in follow-up commits.
Replace matmul_tiles with matmul_block using ct_dim=block_size in the accumulate_XW_for_k_block function. This reduces matmul calls from p×k=16 to just p=4 per p_block.

Key changes:
- Use mm_block_init_short with ct_dim=block_size, rt_dim=1
- Iterate over the inner dimension (p_block_size) calling matmul_block
- in0_index increments by 1 (next X tile)
- in1_index increments by block_size (next W row in row-major layout)

All 14 SwiGLU tests pass.
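In rough outline, the accumulate loop after this change looks like the fragment below. cb_x, cb_w, and the loop bounds are placeholders, and the mm_block_init_short/matmul_block argument order follows the tt-metal compute-kernel API as I understand it; it may differ slightly between versions.

```cpp
// Simplified accumulate loop for one k_block (illustrative, not the PR's exact code).
void accumulate_xw_for_k_block(
    uint32_t cb_x, uint32_t cb_w, uint32_t p_block_size, uint32_t block_size) {
    mm_block_init_short(cb_x, cb_w, /*transpose=*/0, /*ct_dim=*/block_size, /*rt_dim=*/1, /*kt_dim=*/1);
    uint32_t in0_index = 0;  // next X tile (advances by 1)
    uint32_t in1_index = 0;  // next W row in the row-major tile layout (advances by block_size)
    for (uint32_t p = 0; p < p_block_size; ++p) {
        // One call accumulates a full row of block_size output tiles into DST 0..block_size-1,
        // replacing block_size separate matmul_tiles calls.
        matmul_block(
            cb_x, cb_w, in0_index, in1_index, /*idst=*/0, /*transpose=*/0,
            /*ct_dim=*/block_size, /*rt_dim=*/1, /*kt_dim=*/1);
        in0_index += 1;
        in1_index += block_size;
    }
}
```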
- Add new rt_dim=2 compute kernel (swiglu_fw_true_flash_kernel_rt2.cpp)
  - Process 2 rows at a time using matmul_block(rt_dim=2, ct_dim=2)
  - DST register layout: [r0_k0, r0_k1, r1_k0, r1_k1] for the first k-half
  - Pack in row-major order: [r0_k0..k3, r1_k0..k3]
- Add rt_dim=2 dataflow kernels
  - reader_swiglu_fw_sender_rt2.cpp: Read 2 rows of X per iteration
  - reader_swiglu_fw_receiver_rt2.cpp: Match the sender's 2-row pattern
  - writer_swiglu_fw_interleaved_start_id_rt2.cpp: Write 2 rows of Y
- Update program factory for rt_dim=2
  - CB sizes: rt_dim * block_size tiles for intermediates
  - Input CB: 2 * rt_dim * block_size for double-buffering
  - max_rows_for_sync rounded up to even for row-pair processing
- Remove L1 check (True Flash doesn't need full-row caching)

Performance: NanoLlama test ~3% faster (55086 ms -> 53304 ms)
- Cache full X row(s) in L1 to avoid re-reading from DRAM for each k_block
- X CB enlarged from 2*rt_dim_tiles to kRtDim*Wt (a full row per row pair)
- Dataflow restructured: read X once per row pair; the k_block loop is the outer loop
- Compute kernel: wait for the full X cache at the start, pop once at the end
- Added TT_FATAL validation that X row caching fits in L1
- Updated should_use_true_flash() memory calculation for Phase 3

Expected benefit: K_blocks x fewer DRAM reads (8x for NanoLlama3)
- Use llk_pack_reconfig_l1_acc() to keep partial results in L1 between pack operations
- Apply L1 accumulation to the XW1/XW3 computation across p_blocks
- Use pack_tile<true> for out-of-order packing to specific CB positions
- Add PACKER_L1_ACC define support to program_utils.hpp
- Reserve the CB once at first_p_block, push once at last_p_block
- Eliminates the manual load of partial results from the CB into DST

This reduces CB traffic during the XW accumulation phase by letting the packer read-add-write directly in L1 instead of round-tripping through the CB.
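A condensed illustration of the reserve-once / accumulate / push-once pattern this commit describes. CB ids and loop bounds are placeholders, and the way the low-level llk_pack_reconfig_l1_acc call is wrapped for the pack thread is elided, so treat this as a sketch of the shape of the change rather than the PR's literal code.

```cpp
// Illustrative packer-L1-accumulation loop for one k_block of XW1 (not the PR's exact code).
void pack_xw_with_l1_acc(uint32_t cb_xw1, uint32_t num_p_blocks, uint32_t block_size) {
    cb_reserve_back(cb_xw1, block_size);               // reserve once, before the first p_block
    for (uint32_t p_block = 0; p_block < num_p_blocks; ++p_block) {
        tile_regs_acquire();
        // ... matmul accumulation for this p_block into DST registers 0..block_size-1 ...
        tile_regs_commit();
        tile_regs_wait();
        // First p_block overwrites the CB; later p_blocks read-add-write in L1.
        llk_pack_reconfig_l1_acc(p_block == 0 ? 0 : 1);
        for (uint32_t t = 0; t < block_size; ++t) {
            // Out-of-order pack: hit the same CB slot every p_block so L1 accumulation
            // targets the same address instead of advancing the write pointer.
            pack_tile<true>(t, cb_xw1, t);
        }
        tile_regs_release();
    }
    llk_pack_reconfig_l1_acc(0);                       // restore normal packing
    cb_push_back(cb_xw1, block_size);                  // publish once, after the last p_block
}
```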
True Flash trades compute for memory by recomputing XW products for each k_block. When both algorithms fit in L1, ORIGINAL is faster due to less recomputation. Only use True Flash when it's the only option that fits. Removed misleading TODO about enabling True Flash by default.
Phase 2 (commit f233e5a) accidentally modified the ORIGINAL algorithm's program factory to use TRUE_FLASH kernels (_rt2 variants with rt_dim=2). This caused the ORIGINAL path to hang.

This commit restores the ORIGINAL factory to its working state (from commit dcb88dc) with the correct single-row processing kernels:
- swiglu_fw_kernel.cpp (compute)
- reader_swiglu_fw_sender.cpp / reader_swiglu_fw_receiver.cpp (dataflow)
- writer_swiglu_fw_interleaved_start_id.cpp (dataflow)

Key differences from TRUE_FLASH:
- No rt_dim=2 row-pair processing
- Full row caching (M-fits-L1 requirement)
- No X row caching or recompute

Performance: 4.19 ms vs 5.67 ms (TRUE_FLASH) on the 64x1x256x384 shape.
Tests: All 12 SwiGLU tests pass (9 regular + 2 NIGHTLY + 3 shape mismatch).
- Replace matmul_tiles with matmul_block(rt_dim=1, ct_dim=block_size) in mul_XW_accumulate_k_block for the X@W1 and X@W3 computations
- This processes all k output tiles per inner-dimension tile, reducing matmul calls from p×k to just p
- Update tests to default to the ORIGINAL algorithm

Co-authored-by: Cursor <cursoragent@cursor.com>
Eliminate 3 partial CBs (cb_xw1_partial, cb_xw3_partial, cb_y_partial) by using packer L1 accumulation to accumulate directly into the final CBs.

- Phase A: XW1/XW3 accumulate across p_blocks via L1 acc into cb_xw1/cb_xw3
- Phase C: Y accumulates across k_blocks via L1 acc into cb_y

Saves ~24 KB of L1 per core. No measurable perf change on NanoLlama3 (3.95 ms), but cleaner code and more L1 headroom for larger hidden_dim.

Co-authored-by: Cursor <cursoragent@cursor.com>
Replace matmul_tiles with matmul_block(rt_dim=1, ct_dim=block_size) for Phase C. Change the W2 CB layout from column-major to row-major to match matmul_block expectations.

Performance: 3.943 ms (marginal improvement from 3.951 ms).

Co-authored-by: Cursor <cursoragent@cursor.com>
Increase the W2 CB from double-buffer (2x) to triple-buffer (3x block_size^2). This allows the weight sender to prefetch W2 tiles during Phase B (SiLU), staying 3 batches ahead of compute instead of 2.

Performance: 3.906 ms (from 3.943 ms = 0.9% improvement). Cumulative: 8.7% faster than composite (4.28 ms).

Co-authored-by: Cursor <cursoragent@cursor.com>
Process 2 SiLU tiles per acquire/commit using all 4 DST registers:
- Tile 0: REGs 0,1,2 (result in REG 0)
- Tile 1: REGs 1,2,3 reused (result in REG 1)
- Pack both results in one commit cycle

Halves the acquire/commit/pack overhead for Phase B.

Performance: 3.905 ms (negligible change; Phase B is not the bottleneck).

Co-authored-by: Cursor <cursoragent@cursor.com>
Dynamically compute the W2 CB size based on the L1 remaining after all other CBs. More buffering lets the sender prefetch more W2 tiles during Phase B (SiLU), eliminating W2 wait time in Phase C. For NanoLlama3 this uses ~24 buffers (384 tiles), filling the available L1 headroom.

Performance: 3.346 ms (from 3.906 ms = 14.3% improvement). Cumulative: 21.8% faster than composite (4.28 ms).

Scales automatically: small hidden_dim = more W2 buffers, large = fewer.

Co-authored-by: Cursor <cursoragent@cursor.com>
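Host-side, the sizing logic amounts to something like the sketch below. The function name and the way "remaining L1" is obtained are assumptions; the PR's program factory derives the same quantity from its own CB bookkeeping and validates that at least the minimum fits.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical host-side sizing: W2 is streamed in block_size x block_size batches, so the
// CB depth is simply the leftover L1 divided by one batch's footprint, floored at
// double-buffering. Extra depth = more prefetch headroom during the SiLU phase.
uint32_t compute_w2_cb_num_batches(
    uint32_t l1_bytes_remaining,   // L1 left after every other CB is allocated
    uint32_t single_tile_bytes,    // tile size for the W2 data format
    uint32_t block_size) {
    const uint32_t batch_bytes = block_size * block_size * single_tile_bytes;
    const uint32_t max_batches = l1_bytes_remaining / batch_bytes;
    return std::max<uint32_t>(2u, max_batches);
}
```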
Remove the TRUE_FLASH algorithm and all associated files:
- Delete TRUE_FLASH compute kernels (basic + rt2)
- Delete TRUE_FLASH dataflow kernels (sender, receiver, writer)
- Delete the TRUE_FLASH program factory
- Delete the old writer kernel (replaced by reader_x_writer_y_swiglu_fw.cpp)
- Remove the SwiGLUAlgorithm enum and algorithm parameter from all APIs
- Single code path: ORIGINAL with dual-NOC + dynamic W2 prefetch

TRUE_FLASH was 32% slower than composite and algorithmically suboptimal (loop-inversion penalties). Composite is the better fallback when ORIGINAL doesn't fit. Galaxy (32 devices) provides enough TP for all production models.

11 files deleted; the API is simplified to just swiglu(input, w1, w2, w3).

Co-authored-by: Cursor <cursoragent@cursor.com>
- Update copyright year 2025 -> 2026 in all SwiGLU files
- Remove dead projected matmul in the backward pass (wasted compute)
- Remove stale comments (SwiGLUAlgorithm refs, removed CBs, old params)
- Remove unused profiler include from the compute kernel
- Fix a misleading comment about the hidden dimension
- Fix duplicate test numbering
- Remove hpp entries from CMakeLists (only cpp needed)

Co-authored-by: Cursor <cursoragent@cursor.com>
- Remove unused c_block_size param from mul_MW2_accumulate_Y_l1
- Use read_full_row_tiles/write_full_row_tiles utilities in the X reader / Y writer
- Deduplicate weight_tile_start in the sender kernel
- Rename m1/m2/m3 -> w1/w2/w3 for consistency in the device operation
- Add constexpr for hidden_Wt_rounded_up and file-level tiles_per_batch
- Add missing const on mask_w, mask_hw, max_rows_for_sync, is_sender
- Make the data_format alias const
- Remove unused num_cores_y from shared_variables_t
- Remove dead projected matmul in the backward pass

Co-authored-by: Cursor <cursoragent@cursor.com>
- Use the api/compute/ include path in swiglu_fw_kernel.cpp to match the repo layout and fix the build after the compute kernel API header move
- Multicast and dataflow cleanup; program factory and benchmark updates

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…d, cleanup
- compute_utils: pack_no_push_block, pack_l1_acc_block, pack_and_push_block with commit/wait notes; pad case uses pack_and_push_block (garbage)
- dataflow_utils: read_batched_rows_with_padding (UseBarrier), std::min for block sizes and actual_row/actual_t; read_full_row_tiles/write use std::min
- program_utils: remove packer_l1_acc and PACKER_L1_ACC from create_compute_kernel
- swiglu_fw: use pack_and_push_block, pack_l1_acc_block, std::min; remove unused ROW_OF_M_FITS_IN_L1 define; TODO above assign_per_core_runtime_args

Co-authored-by: Cursor <cursoragent@cursor.com>
…structure comments

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…x; swiglu: remove redundant tile_regs_wait

Co-authored-by: Cursor <cursoragent@cursor.com>
Ticket
[Feature Request] SwiGLU forward and backward functions - #14195
Context (follow-up PR)
This is the second PR on fused SwiGLU forward. The first PR introduced the dual-path fused SwiGLU: L1-optimized fast path (full rows of XW1/XW3/M in L1 when they fit), fallback streaming path for larger dimensions, and composite backward. It did not use multicast; every core read weights from DRAM independently, so the fused forward was still slow at scale.
This PR builds on that foundation by adding: single-sender multicast for weights, packer L1 accumulation (XW1, XW3, Y), overlapping phases (dual-NOC: X/weights split), batched tile multicast, W2 double-buffering (row-major layout), and related optimizations—bringing fused SwiGLU forward in line with or ahead of composite on the Tracy profiler while reducing activation memory.
Problem description
Before any fusion, the composite path uses separate matmuls and SiLU with each core reading all weights from DRAM. The first fused PR added an L1-optimized kernel (XW1, XW3, M in L1 when they fit) but still had redundant weight DRAM reads (every core read W1/W2/W3) and no multicast, so it was slow at scale.
This (second) PR adds:
- Single-sender multicast for weights: a num_cores× reduction in weight DRAM reads.
- Backward remains composite (unchanged in this PR).

Note on fw+bw measurement: tt-train's autograd does not currently support mixing a custom fused forward with the original composite backward graph. As a workaround, the fused path uses a custom backward that is functionally correct but unoptimized, so fw+bw step times are dominated by the slow backward and are not meaningful for comparing forward performance. All performance data below focuses on forward-only runs. The fw+bw runs are used only to verify that loss decreases correctly with the fused forward.
What's changed
Fused SwiGLU forward kernel
- Dimensions are padded up to block_size for matmul; padding tiles are never read by M@W2.
- Pack helpers: pack_and_push, pack_and_push_block, pack_l1_acc_block.

Single-sender multicast weight distribution
- Reusable multicast helpers in dataflow_utils.hpp (mcast_sender_read_batched_rows_and_send_loopback, mcast_receiver_reserve_and_receive).

Batched tile multicast
- Weights are sent in batches of block_size × block_size tiles per multicast.

Dual-NOC / X vs weights split
Uniform workload padding for multicast sync
- All cores loop for max_rows_for_sync iterations (the max rows assigned to any core).
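max_rows_for_sync is just the maximum per-core row count; a hypothetical host-side helper (the function name and the rows_per_core input are illustrative, not the PR's factory code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Every core loops this many times so the row-level multicast handshake stays in
// lockstep; cores with fewer real rows run padding iterations instead.
uint32_t compute_max_rows_for_sync(const std::vector<uint32_t>& rows_per_core) {
    uint32_t max_rows = 0;
    for (uint32_t rows : rows_per_core) {
        max_rows = std::max(max_rows, rows);
    }
    return max_rows;  // e.g. 57 rows on 56 cores: one core gets 2 rows, so every core loops twice
}
```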
Weight shape validation

- TT_FATAL checks for the fused SwiGLU weight layout: W1[embed, hidden], W3[embed, hidden], W2[hidden, embed].
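A self-contained sketch of the validation intent. The helper name and plain-throw error handling are stand-ins; the PR performs these checks with TT_FATAL directly on the tensor shapes and with its own messages.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper illustrating the expected weight layout checks.
void validate_swiglu_weight_shapes(
    uint32_t embed, uint32_t hidden,
    uint32_t w1_rows, uint32_t w1_cols,
    uint32_t w2_rows, uint32_t w2_cols,
    uint32_t w3_rows, uint32_t w3_cols) {
    auto require = [](bool ok, const std::string& msg) {
        if (!ok) throw std::runtime_error(msg);  // stand-in for TT_FATAL
    };
    require(w1_rows == embed && w1_cols == hidden, "W1 must be [embed, hidden]");
    require(w3_rows == embed && w3_cols == hidden, "W3 must be [embed, hidden]");
    require(w2_rows == hidden && w2_cols == embed,
            "W2 must be [hidden, embed]; transposed layouts are rejected with a clear error");
}
```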
Testing

- SwiGLU_RepeatedRuns_NoHang (batch=100, 56 cores, 3 iterations).
- SwiGLU_UnbalancedWorkload_57x1x32x32 (57 rows on a 56-core grid).

Performance
Methodology: All performance comparisons use forward-only runs (backward disabled), isolating the actual change. Training correctness (loss convergence) is verified separately with fw+bw runs. Step time and loss come from tt-train/scripts/plot_training_comparison.py. Memory comes from the "Memory Usage Summary" FORWARD_PASS block in the training logs. Tracy numbers come from analyze_tracy_output.py on profiler CSV exports (composite = markers mode; fused = operations mode, op name SwiGLU). SwiGLU calls per block: fixed at 1 (one MLP per transformer block, one SwiGLU per MLP).

NanoLlama3 (64×1×256×384, hidden=1024, 6 blocks)
Loss convergence (fw+bw): Composite final loss 0.122, fused final loss 0.117 — loss decreases correctly with the fused forward. (fw+bw step times not comparable due to unoptimized backward workaround; see note above.)
Forward-only step time comparison (composite vs fused):
Peak memory breakdown (composite vs fused, forward pass):
TinyLlama (embed 2048, hidden 5632, 22 blocks)
Loss convergence (fw+bw): Composite final loss 2.135, fused final loss 2.162 — loss decreases correctly.
Why no memory improvement: TinyLlama uses runner_type: memory_efficient (gradient checkpointing). In this mode, block forward runs with gradients disabled; intermediate activations (including the SwiGLU intermediates XW1, XW3, M that fused avoids materializing) are not retained during forward. They are recomputed during backward instead. So the fused path's memory advantage disappears: the intermediates it avoids writing to DRAM were already not being stored. Model (1,863 MB) + optimizer (3,709 MB) dominate the 6,568 MB forward peak; retained activations are only ~193 MB for both paths.

Forward-only step time comparison (composite vs fused):
Memory analysis
Numbers are from the FORWARD_PASS phase of the "Memory Usage Summary" (first-iteration fw+bw run). Activations = segment change (net DRAM allocated during forward and held for backward). Peak during forward = cumulative peak at the end of FORWARD_PASS. See tt-train/docs/MEMORY_TRACKING.md. Fused reduces activation memory by avoiding materializing the full XW1/XW3/M tensors to DRAM.

NanoLlama3 (runner_type: default, all forward activations retained):
Approx. reduction: ~37% activations, ~36% peak during forward.
TinyLlama (runner_type: memory_efficient — gradient checkpointing, intermediates not retained):
No activation reduction. With gradient checkpointing enabled, block-internal intermediates (XW1, XW3, M) are already discarded during forward and recomputed during backward — the fused path's memory advantage does not apply.
NanoLlama peak memory breakdown (composite vs fused):
Composite SwiGLU (Tracy per-block summary, 1 call per block)
composite_swiglu_begin/_end, 6 blocks/step

Fused SwiGLU (Tracy per-block summary, 1 call per block)
Model-size summary: Fused is faster than composite on NanoLlama3 and slower on TinyLlama. TinyLlama scaling requires further investigation (e.g. fused kernel vs batched matmuls at larger embed/hidden).
Summary
- Single-sender multicast cuts weight DRAM reads from num_cores to 1; batched tiles reduce sync overhead.
- TinyLlama uses gradient checkpointing (the memory_efficient runner), so intermediates are not retained and fused shows no memory advantage.

Checklist