Fused kernels and removal of torch cuda synchronize #105
Replaces the eager apply_rope sequence (~5-6 elementwise kernels per call) with a single Triton kernel for both forward and backward. Reduces kernel launch overhead in the training forward path, where RoPE accounts for roughly 15-20% of the 21k kernels/call at Qwen3-4B scale.
Opt-in via `--sdpa.fused_rope true`. Falls back to eager apply_rope when Triton is unavailable or the input shape/dtype is unsupported.
Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. The fused kernel accumulates in fp32 internally and is typically more accurate than eager (0.69x the eager mean error, measured against fp64).
Measured on Qwen3-4B, A100-40GB, Helmet nq 64k, cpu_offload=True:
- Training (6 steps): 1199s -> 1174s (-2.1%)
- Validation: 1111s -> 1076s (-3.1%)
- val_loss: 18.248 -> 16.264 (slight improvement, likely fp32 accumulation)
Standalone microbenchmark at Qwen3 Q shape (2, 32, 2048, 128) bf16:
- Forward: 0.394 ms -> 0.070 ms (5.6x)
- Forward+backward: 1.061 ms -> 0.339 ms (3.1x)
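As a rough illustration of what a fused RoPE forward computes per element, here is a minimal pure-Python sketch. It assumes the rotate-half layout (element i is paired with element i + d/2) and a frequency base of 10000; the function name `rope_reference` is hypothetical and is not taken from the PR. The eager path builds `x * cos + rotate_half(x) * sin` from several elementwise ops, while a fused kernel can produce both output halves in one pass, as below.

```python
import math


def rope_reference(x, pos, base=10000.0):
    """Rotate one head vector x (even length) for token position pos.

    Assumes the rotate-half layout: element i is paired with element i + d/2.
    """
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = x[i], x[i + half]
        # Same result as x * cos + rotate_half(x) * sin, written in one pass.
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out
```

Because each pair is a plain 2D rotation, the vector norm is preserved, which makes a convenient sanity check when comparing a fused kernel against a reference.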
Replaces the eager RMSNorm.forward (fp32 cast + x*x + mean + rsqrt + normalize + mul weight + cast-back = ~5-6 kernels) with a single Triton forward kernel. Backward uses two kernels: a row-parallel kernel for dL/dx and a split-K two-pass reduction for dL/dw.
The dL/dw split-K design: a first pass produces (n_m_blocks, D) partial sums in parallel across both M and D; a second pass reduces partials along M per column. Target ~432 total programs (4 per SM on A100) for full occupancy.
Opt-in via `--sdpa.fused_rmsnorm true`. When enabled, monkey-patches both `keys_values.model.RMSNorm.forward` and `litgpt.model.RMSNorm.forward` so all existing call sites (norm_1, post_attention_norm, norm_2, post_mlp_norm, norm_q, norm_k, ln_f) transparently benefit. Falls back to the original eager forward when Triton is unavailable, the tensor is on CPU, or D > 16384.
Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. The fused forward accumulates in fp32 and is slightly more accurate than eager in the Gemma add_unit_offset path.
Standalone microbenchmark at Qwen3 hidden (2, 2048, 2560) bf16:
- Forward: 0.355 ms -> 0.061 ms (5.8x)
- Forward+backward: 1.452 ms -> 0.430 ms (3.4x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, cpu_offload=True, combined with --sdpa.fused_rope true:
- Training (6 steps): 1199s -> 1126s (-6.1%)
- Validation: 1111s -> 1034s (-6.9%)
- val_loss: 18.248 -> 18.504 (within variance)
- Total wall-clock: 2310s -> 2160s (-6.5%)
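The two-pass dL/dw reduction described above can be sketched in pure Python as follows. This is only an illustration under stated assumptions: plain weight scaling `y = x * rstd * w` (not the Gemma add_unit_offset variant), nested lists instead of tensors, and hypothetical function names. The program target follows from the A100 having 108 SMs.

```python
import math

A100_SMS = 108
TARGET_PROGRAMS = 4 * A100_SMS  # 432, i.e. 4 programs per SM


def rmsnorm_forward(x, w, eps=1e-6):
    """x: list of M rows of length D; w: length-D weight. Returns (y, rstd)."""
    y, rstd = [], []
    for row in x:
        r = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        rstd.append(r)
        y.append([v * r * wj for v, wj in zip(row, w)])
    return y, rstd


def dw_direct(x, dy, rstd):
    """Reference: dL/dw[j] = sum over rows i of dy[i][j] * x[i][j] * rstd[i]."""
    return [
        sum(dy[i][j] * x[i][j] * rstd[i] for i in range(len(x)))
        for j in range(len(x[0]))
    ]


def dw_two_pass(x, dy, rstd, block_m):
    """Same sum, split into per-row-block partials and a final reduction."""
    m, d = len(x), len(x[0])
    # Pass 1: one partial row per block of rows -> (n_m_blocks, D) partials.
    partials = []
    for start in range(0, m, block_m):
        rows = range(start, min(start + block_m, m))
        partials.append(
            [sum(dy[i][j] * x[i][j] * rstd[i] for i in rows) for j in range(d)]
        )
    # Pass 2: reduce the partials along M, one column at a time.
    return [sum(p[j] for p in partials) for j in range(d)]
```

On a GPU, every (row block, column block) pair in pass 1 is independent, which is what lets the first pass spread across all SMs instead of serializing the reduction over M.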
Replaces the eager `F.silu(x_fc_1) * x_fc_2` inside LLaMAMLP.forward (2 kernels: silu + elementwise multiply) with a single fused Triton kernel. Backward kernel computes both dL/dx_fc_1 and dL/dx_fc_2 in one pass.
Opt-in via `--sdpa.fused_swiglu true`. When enabled, monkey-patches both `keys_values.lora.LLaMAMLP.forward` and `litgpt.model.LLaMAMLP.forward`. Falls back to eager when inputs are not on CUDA or dtypes mismatch.
Correctness verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. Fused is slightly more accurate (0.71x mean error vs eager) from fp32 accumulation of `a * sigmoid(a) * b`.
Standalone microbenchmark at Qwen3 intermediate (2, 2048, 6912) bf16:
- Forward: 0.215 ms -> 0.130 ms (1.65x)
- Forward+backward: 0.909 ms -> 0.645 ms (1.41x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacking on top of --sdpa.fused_rope true --sdpa.fused_rmsnorm true:
- Training (6 steps): 1126s -> 1121s (-0.4%)
- Validation: 1034s -> 1030s (-0.4%)
- Total wall-clock: 2160s -> 2150s (-0.4%)
Cumulative speedup of all three kernels (RoPE + RMSNorm + SwiGLU) vs baseline: training 1199s -> 1121s (-6.5%), total 2310s -> 2150s (-6.9%).
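For reference, the scalar math behind a one-pass SwiGLU forward and backward can be written as below. This is a sketch with hypothetical function names, operating on single floats rather than tensors; it only shows why both input gradients can share the same `sigmoid(a)` evaluation.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def swiglu_forward(a, b):
    """out = silu(a) * b, with silu(a) = a * sigmoid(a)."""
    return a * sigmoid(a) * b


def swiglu_backward(a, b, grad_out):
    """Return (dL/da, dL/db) from a single evaluation of sigmoid(a)."""
    s = sigmoid(a)
    silu = a * s
    d_silu = s * (1.0 + a * (1.0 - s))  # derivative of a * sigmoid(a)
    return grad_out * b * d_silu, grad_out * silu
```

A central finite difference on `swiglu_forward` is a cheap way to cross-check the two gradient expressions before comparing a fused kernel against them.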
`CausalSelfAttention.forward` wraps RoPE with `torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)` to splice the rotated prefix back with the un-roped tail for partial-rope models. When `rope_n_elem == head_size` (Qwen3, Llama3, most modern configs), the tail slice is empty but cat still launches a copy kernel.
Add a conditional: when rope covers the full head dim, assign `q = q_roped` directly. For partial-rope models (rotary_percentage < 1.0, e.g., litgpt default 0.25) the else-branch preserves the original cat behavior exactly.
Eliminates 2 kernel launches per Block.forward call (one for q, one for k) on full-rope models.
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of `--sdpa.fused_rope=true --sdpa.fused_rmsnorm=true --sdpa.fused_swiglu=true`:
- Training (6 steps): 1121s -> 1120s (within noise)
- Validation: 1030s -> 1004s (-2.5%)
- Total wall-clock: 2150s -> 2124s (-1.1%)
Cumulative vs untouched baseline: 2310s -> 2124s (-8.0%). val_loss stays within variance (18.440 vs baseline 18.248, ±0.3).
Adds `torch.cuda.nvtx.range_push/pop` around the major phases of
`GradientAccumulator.run` so nsys profiles can attribute GPU work to
meaningful regions:
- `accumulator_run_{first_layer_idx}` (outer, in main.py)
- `compute_checkpoints`, `sync_checkpoints`
- Per-cell: `cell_inputs_{col}`, `forward_{col}`, `backward_{col}`,
`write_head_grads_{col}`
No performance impact — NVTX is a CPU-side marker that nsys reads. When
nsys is not running, the calls are no-ops.
These annotations match the names used in prior profiling work
(PROFILING_RESULTS.md), so before/after comparisons are directly
aligned.
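One way to keep push/pop pairs balanced, even when the wrapped phase raises, is a small context manager. The sketch below is an assumption-laden illustration, not the code in this PR: `profiled_range` is a hypothetical helper, and the push/pop callables are injected so the snippet runs without a GPU.

```python
from contextlib import contextmanager


@contextmanager
def profiled_range(name, push, pop):
    """Open a named range, and always close it, even on exceptions."""
    push(name)
    try:
        yield
    finally:
        pop()
```

With PyTorch available, the callables would be `torch.cuda.nvtx.range_push` and `torch.cuda.nvtx.range_pop`, e.g. `with profiled_range("compute_checkpoints", torch.cuda.nvtx.range_push, torch.cuda.nvtx.range_pop): ...`. PyTorch also ships `torch.cuda.nvtx.range` as a ready-made context manager.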
`create_random_index` used to do a Python for-loop over (batch, n_heads)
calling `torch.randperm(length)[:num]` per iteration. Each CUDA randperm
launches ~7 small kernels (histogram, radix sort passes, exclusive sum,
duplicate handling, arange, fill), so the loop produced batch*n_heads*7
kernel launches.
Replace with batched argsort(rand_vals): draw uniform random values of
shape (batch, n_heads, length) in one call, argsort along the last dim
to produce independent permutations per (b, h), slice to `num`. Same
semantics, one sort-kernel sequence instead of batch*n_heads of them.
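The underlying idea, that sorting independent uniform random keys yields a uniformly random permutation, can be sketched in pure Python. The function name and nested-list layout are hypothetical, and this version still loops per (batch, head) row; it only demonstrates the per-row semantics that the batched tensor form computes for all rows at once.

```python
import random


def batched_random_index(batch, n_heads, length, num, rng=None):
    """For each (b, h), return the first `num` entries of a random permutation."""
    rng = rng or random.Random()
    result = []
    for _ in range(batch):
        per_head = []
        for _ in range(n_heads):
            # Draw one uniform key per position, then argsort the keys.
            keys = [rng.random() for _ in range(length)]
            perm = sorted(range(length), key=keys.__getitem__)
            per_head.append(perm[:num])
        result.append(per_head)
    return result
```

In PyTorch the same thing can be expressed without a Python loop, roughly as `torch.rand(batch, n_heads, length, device=device).argsort(dim=-1)[..., :num]`. For scale, at batch=2 and n_heads=32 the per-row loop issues about 2 * 32 * 7 = 448 small kernel launches per call.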
Discovered via nsys re-profile after the three Triton kernels landed.
forward_1 had ~358k radix-sort kernels in step 3 (49% of the sort
machinery's total), traceable to this single code path through
`create_ext_annotations` in the saved_tensors_hooks annotation
pipeline.
Keeps the CPU loop as a fallback for non-CUDA devices (randperm on
CPU is a single call, no benefit from batching).
Standalone microbenchmark (batch=2, n_heads=8 or 32, num=64, length=2048):
- Small (n_kv_heads=8): 0.985 ms -> 0.083 ms (11.9x)
- Large (n_heads=32): 3.832 ms -> 0.082 ms (46.7x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of all
three Triton fused kernels + RoPE cat skip:
- Training (6 steps): 1120.23s -> 1077.40s (-3.8%)
- Validation: 1004.14s -> 1007.39s (~neutral, no annotations
in inference path)
- Total wall-clock: 2124.37s -> 2084.79s (-1.9%)
- val_loss: 18.440 -> 18.237 (within baseline variance)
Cumulative vs untouched baseline: 2309.88s -> 2084.79s (-9.7%).
`_inference_forward_pass` had a `torch.cuda.synchronize()` at the end of each middle-loop (per-layer) iteration, intended to wait for the D2H of the just-written layer-input checkpoint before overwriting the shared quantizer buffer on the next layer.
The D2H in `QuantizerState.copy_` already uses `non_blocking=True` on the default stream, and the next layer's quantize kernel runs on the same stream, so GPU-level ordering was already correct without the sync. The sync only stalled the CPU thread; GPU work wasn't waiting for it. Removing the sync lets CPU dispatch run ahead of the GPU pipeline again. The original TODO comment suspected this already.
Measured at Qwen3-4B on A100-40GB, Helmet nq 64k, cpu_offload=True, stacked on top of all six prior async_offload commits:
- Step 1: 276.6s -> 261.8s (-5%)
- Step 3: 154.7s -> 148.7s (-4%)
- 3 steps: 612.8s -> 592.1s (-3.4%)
- Losses within prior variance; no corruption.
A followup (`d2h_stream` + event-based ordering) would let the D2H actually overlap with the next layer's quantize on a separate stream; this commit is the zero-infrastructure part of that path.
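The ordering argument can be illustrated with a toy model of a single in-order stream. Everything here is a simplification with hypothetical names: a deque stands in for the CUDA stream, a one-element list for the shared quantizer buffer, and closures for the quantize kernel and the D2H copy. The host enqueues all work without waiting, yet each copy still sees its own layer's data because operations on one stream run in issue order.

```python
from collections import deque


class InOrderStream:
    """Toy stand-in for one CUDA stream: ops run strictly in issue order."""

    def __init__(self):
        self._queue = deque()

    def enqueue(self, op):
        self._queue.append(op)  # the host returns immediately

    def drain(self):
        while self._queue:
            self._queue.popleft()()


def run_layers(n_layers):
    stream = InOrderStream()
    shared_buffer = [None]  # reused by every layer
    checkpoints = [None] * n_layers

    def quantize(layer):
        shared_buffer[0] = f"layer-{layer}-input"

    def d2h_copy(layer):
        checkpoints[layer] = shared_buffer[0]

    for layer in range(n_layers):
        stream.enqueue(lambda layer=layer: quantize(layer))
        stream.enqueue(lambda layer=layer: d2h_copy(layer))
        # No host-side sync here: the next quantize is queued behind the copy.
    stream.drain()
    return checkpoints
```

The model says nothing about when the host may read the copied data; it only covers the device-side ordering between the copy and the next write to the shared buffer, which is the claim made above.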
mseeger
left a comment
Just some comments. But otherwise, can merge this, and I can do the comparison tests, unless you have the code for them already.
| FlashInfer SDPA if summed attention weights are required. If | ||
| `flex_attention == False`, this kernel is also used if attention | ||
| weights are not needed. | ||
| fused_rope: If `True`, replace the eager rotary position embedding |
Are these compiled on the fly? I don't know how Triton works.
And if they speed up things, we should maybe make the default True, no?
There was a problem hiding this comment.
They are compiled on the first run; the speedups show up in the remaining steps.
| if self._verbose_more: | ||
| print("Forward pass to store KV cache checkpoints") | ||
| infer_replay_caches = self._create_inference_replay_caches(model_part) | ||
| torch.cuda.nvtx.range_push("compute_checkpoints") |
Remove this again. Also the other nvtx commands below.
| for b in range(shape[0]): | ||
| for h in range(shape[1]): | ||
| result[b, h, :] = torch.randperm(length, **index_kwargs)[:num] | ||
| # Batched random permutation: draw uniform random values of shape |
Great catch, and nice solution (create random numbers and just sort them).
This should be a function, but I can do such clean-up later.
| perms = perms[..., :num] | ||
| result = perms.to(dtype=dtype) | ||
| else: | ||
| # CPU fallback: keep the original loop — randperm on CPU is a single |
What is wrong with the code above? I'd do the same on CPU; it is always a better solution.
| snapshots = self._record_gpu_memory_snapshots | ||
| else: | ||
| snapshots = None | ||
| torch.cuda.nvtx.range_push(f"accumulator_run_{first_layer_idx}") |
Remove all nvtx annotations
| @@ -0,0 +1,447 @@ | |||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | |||
In general, we need tests to compare these fused operators with the PyTorch code.
I can do this, don't worry. Let us merge it. Unless you have this comparison code already, then let us use it.
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Fused Triton kernel for rotary position embedding (RoPE). |
Also here, need comparison tests. I can do this, unless you already have it.
There was a problem hiding this comment.
For the comparison, I just ran the entire setup with and without it and compared the resulting timings.
| # TODO: Is this really a problem? We don't modify | ||
| # `embeddings`, but just unlink it. | ||
| torch.cuda.synchronize() | ||
| # No explicit sync needed between layers: the D2H in |
| ) | ||
| q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) | ||
| k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) | ||
| # Skip the outer cat when rope covers the full head dim (e.g., |
| def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: | ||
| if _USE_FUSED_ROPE: | ||
| from keys_values.fused_rope import fused_apply_rope, can_use_fused_rope | ||
| if can_use_fused_rope(x, cos, sin): |
Where is this defined? Is this included somewhere?
There was a problem hiding this comment.
This is for the fused triton kernel for rope.
mseeger
left a comment
Approving as is. I'll follow up with another PR adding tests
Summary
Seven incremental performance improvements targeting training-forward GPU
idle time on `LongContextGradientModel`. End-to-end measured at Qwen3-4B on
A100-40GB, Helmet nq 64k context, `cpu_offload=True`:
The three fused Triton kernels are opt-in via CLI flags and default to off,
so the upstream default kernels are unchanged unless a flag is set. Commits
4-7 are always-on and require no flags.
Before vs After latency
Profiled at step 3 (steady state, post-warmup) of a 3-step training run.
Both profiles were captured with `nsys profile --delay=300` on the same
GPU back-to-back; the only difference is whether the optimizations were
applied. See `BEFORE_AFTER_REPORT.md` for full methodology.
Step 3 (1 training step): headline numbers
Dispatch-gap distribution (first `forward_1` call)
Kernel-to-kernel idle gaps: the classic "Python dispatch overhead" signature.
The biggest wins are in the 20-50 μs bucket (kernel-launch-overhead-sized
gaps between tiny sort/randperm/elementwise kernels). That bucket dropped
from 2,153 gaps totaling 76 ms to 101 gaps totaling 2.8 ms — almost entirely
from commit 6 (batched randperm) eliminating the ~360k sort kernels that
were pulling the tail.
End-to-end (6 steps + validation)
Separate, un-profiled run — this is the real user-facing wall-clock impact.
Commits
- a007142: fused RoPE kernel (`--sdpa.fused_rope`)
- 537ab13: fused RMSNorm kernel (`--sdpa.fused_rmsnorm`)
- 8882064: fused SwiGLU kernel (`--sdpa.fused_swiglu`)
- f4d3b8c: skip `torch.cat` in `CausalSelfAttention.forward` when rope covers the full head dim (Qwen3, Llama3, most modern configs)
- 122d569: NVTX range annotations
- c7d8845: batched `torch.randperm` in `create_random_index` (~360k sort-family kernels eliminated per step 3)
- d00e0a1: remove `torch.cuda.synchronize()` in `_inference_forward_pass`'s layer-input checkpoint loop
Correctness
Each kernel is verified against an fp64 reference. Standalone microbenchmarks
are in the kernel source files. End-to-end `val_loss` stays within run-to-run
variance (historic spread across our runs: 17.87 – 18.50).
Experiments that didn't work (reverted, documented)
- `compile_blocks` and `compile_mlp` variants: `torch.compile` corrupts gradients when stacked on the training cell loop's `saved_tensors_hooks` wrapper
- [...] head-gradient checkpoints (Target A): neither produced a wall-clock improvement above the noise floor
Detailed measurements and analysis: `PROFILING_RESULTS.md`, `BEFORE_AFTER_REPORT.md`, `TRANSFER_IDLE_REPORT.md` (local-only).
Reproducing the numbers
All measurements were taken with the command below (same config as
`BENCHMARK_RUNS.md`). The only difference between the `Before` and `After`
rows in the Summary is the three `--sdpa.fused_*` flags; commits 4-7
(cat skip, NVTX ranges, batched randperm, sync removal) are always-on and
require no flags.
Benchmark: 6 training steps + final validation
To reproduce the Before baseline, drop the three `--sdpa.fused_*` flags
(or set them to `false`). Commits 4-7 are always-on, so "Before" here is
commits 1-3 disabled, not literal `upstream/main`. For a pure
`upstream/main` baseline, check out `upstream/main` directly.
Standalone kernel microbenchmarks
Each fused kernel file has a `__main__`-style microbench block. Run
directly, e.g.:
nsys profile (optional)
The 14 NVTX range push/pop pairs added in commit `122d569` let nsys
attribute work to named phases (`compute_checkpoints`, `forward_{col}`,
`backward_{col}`, `write_head_grads_{col}`, etc.). Measured overhead
when no profiler is attached is ~0.3 ms per training step (<0.001% of
step wall). To profile:
The `--delay=300` skips tokenization + model load + step-1 torch.compile
/ Triton JIT warmup. For analysis, export to sqlite
(`nsys export -t sqlite ...`) and query per-phase wall / kernel / idle
breakdowns.
Test plan
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.