Commit 3168e3f
Fused kernels and removal of torch cuda synchronize (#105)
* Add fused Triton RoPE kernel
Replaces the eager apply_rope sequence (~5-6 elementwise kernels per call)
with a single Triton kernel for both forward and backward. Reduces kernel
launch overhead in the training forward path, where RoPE accounts for
roughly 15-20% of the 21k kernels/call at Qwen3-4B scale.
Opt-in via `--sdpa.fused_rope true`. Falls back to eager apply_rope when
Triton is unavailable or the input shape/dtype is unsupported.
Correctness: verified against fp64 reference across bf16/fp16/fp32 and
Qwen3 shapes. Fused kernel accumulates in fp32 internally and is typically
more accurate than eager (0.69x the eager mean error, both measured against fp64).
Measured on Qwen3-4B, A100-40GB, Helmet nq 64k, cpu_offload=True:
- Training (6 steps): 1199s -> 1174s (-2.1%)
- Validation: 1111s -> 1076s (-3.1%)
- val_loss: 18.248 -> 16.264 (slight improvement, likely fp32 accumulation)
Standalone microbenchmark at Qwen3 Q shape (2, 32, 2048, 128) bf16:
- Forward: 0.394 ms -> 0.070 ms (5.6x)
- Forward+backward: 1.061 ms -> 0.339 ms (3.1x)
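For reference, the eager sequence that the fused kernel replaces can be sketched in plain Python on a single head vector. This is a minimal sketch that assumes the rotate-half RoPE convention and a base of 10000; `rope_tables` and `rope_rotate_half` are illustrative names, not functions from this repository. Each list expression (slice, negate, concatenate, two multiplies, add) corresponds to roughly one elementwise kernel in the eager path.

```python
import math


def rope_tables(position, head_size, base=10000.0):
    """Build cos/sin tables for one position (assumed base, rotate-half layout)."""
    half = head_size // 2
    angles = [position / (base ** (2 * i / head_size)) for i in range(half)]
    # The same angle is used for element i and element i + half.
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    return cos, sin


def rope_rotate_half(x, cos, sin):
    """Eager-style RoPE on one head vector: x * cos + rotate_half(x) * sin."""
    half = len(x) // 2
    # rotate_half(x) = concat(-x[half:], x[:half])
    rotated = [-v for v in x[half:]] + list(x[:half])
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]
```

A fused kernel would do the same arithmetic in one pass over `x`, without materializing `rotated` or the intermediate products.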
* Add fused Triton RMSNorm kernel
Replaces the eager RMSNorm.forward (fp32 cast + x*x + mean + rsqrt + normalize
+ mul weight + cast-back = ~5-6 kernels) with a single Triton forward kernel.
Backward uses two kernels: a row-parallel kernel for dL/dx and a split-K
two-pass reduction for dL/dw.
The dL/dw split-K design: a first pass produces (n_m_blocks, D) partial sums
in parallel across both M and D; a second pass reduces the partials along M for
each column. The launch targets ~432 total programs (4 per SM across the A100's
108 SMs) for full occupancy.
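The two-pass reduction can be illustrated in plain Python. This is only a sketch of the data flow, assuming dL/dw[d] is the sum over rows of dy[m][d] * x_hat[m][d] (no unit offset); `dw_split_k` and `block_m` are hypothetical names, and the real kernels run the blocks in parallel rather than in a loop.

```python
def dw_split_k(dy, x_hat, block_m):
    """Two-pass weight-gradient reduction over M rows and D columns.

    dy and x_hat are lists of M rows with D floats each.
    Returns (dw, partials), where partials has shape (n_m_blocks, D).
    """
    n_rows, n_cols = len(dy), len(dy[0])

    # Pass 1: each block of block_m rows writes its own partial-sum row.
    partials = []
    for start in range(0, n_rows, block_m):
        acc = [0.0] * n_cols
        for m in range(start, min(start + block_m, n_rows)):
            for d in range(n_cols):
                acc[d] += dy[m][d] * x_hat[m][d]
        partials.append(acc)

    # Pass 2: reduce the partials along M, one column at a time.
    dw = [sum(p[d] for p in partials) for d in range(n_cols)]
    return dw, partials
```

Splitting along M means no single program has to walk every row, at the cost of a small (n_m_blocks, D) scratch buffer and a second launch.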
Opt-in via `--sdpa.fused_rmsnorm true`. When enabled, monkey-patches both
`keys_values.model.RMSNorm.forward` and `litgpt.model.RMSNorm.forward` so
all existing call sites (norm_1, post_attention_norm, norm_2, post_mlp_norm,
norm_q, norm_k, ln_f) transparently benefit. Falls back to the original
eager forward when Triton is unavailable, the tensor is on CPU, or D > 16384.
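The patch-with-fallback pattern described above might look roughly like the following. It is a sketch using a stand-in `RMSNorm` class on plain lists; `enable_fused_rmsnorm`, `fused_forward`, and the `calls` counter are hypothetical, and the placeholder fused path simply reuses the eager math.

```python
MAX_FUSED_DIM = 16384  # D limit quoted in the commit message

calls = {"fused": 0, "eager": 0}  # illustration only: records which path ran


class RMSNorm:
    """Stand-in for a model's RMSNorm module (not the real class)."""

    def __init__(self, weight, eps=1e-6):
        self.weight = weight
        self.eps = eps

    def forward(self, x):
        calls["eager"] += 1
        inv_rms = (sum(v * v for v in x) / len(x) + self.eps) ** -0.5
        return [v * inv_rms * w for v, w in zip(x, self.weight)]


def fused_forward(module, x):
    """Placeholder for a fused kernel launch; same math as the eager path."""
    calls["fused"] += 1
    inv_rms = (sum(v * v for v in x) / len(x) + module.eps) ** -0.5
    return [v * inv_rms * w for v, w in zip(x, module.weight)]


def enable_fused_rmsnorm(cls, fused_available=True):
    """Replace cls.forward with a wrapper that falls back to the original."""
    original = cls.forward

    def patched(self, x):
        if fused_available and len(x) <= MAX_FUSED_DIM:
            return fused_forward(self, x)
        return original(self, x)  # unsupported input: keep eager behavior

    cls.forward = patched
    return original  # returned so the patch can be undone
```

Because the patch is applied to the class, every existing instance picks it up without changes at the call sites.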
Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3
shapes. The fused forward accumulates in fp32 and is slightly more accurate
than eager in the Gemma add_unit_offset path.
Standalone microbenchmark at Qwen3 hidden (2, 2048, 2560) bf16:
- Forward: 0.355 ms -> 0.061 ms (5.8x)
- Forward+backward: 1.452 ms -> 0.430 ms (3.4x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, cpu_offload=True,
combined with --sdpa.fused_rope true:
- Training (6 steps): 1199s -> 1126s (-6.1%)
- Validation: 1111s -> 1034s (-6.9%)
- val_loss: 18.248 -> 18.504 (within variance)
- Total wall-clock: 2310s -> 2160s (-6.5%)
* Add fused Triton SwiGLU kernel
Replaces the eager `F.silu(x_fc_1) * x_fc_2` inside LLaMAMLP.forward
(2 kernels: silu + elementwise multiply) with a single fused Triton
kernel. Backward kernel computes both dL/dx_fc_1 and dL/dx_fc_2 in one
pass.
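The math a one-pass backward has to produce can be written out per element. The sketch below uses scalar Python floats with illustrative function names; it follows directly from differentiating `silu(a) * b`, where `silu(a) = a * sigmoid(a)`.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def swiglu_forward(a, b):
    """silu(a) * b for one element, with a = x_fc_1 and b = x_fc_2."""
    return a * sigmoid(a) * b


def swiglu_backward(a, b, grad_out):
    """Both input gradients from a single read of a, b, and grad_out."""
    s = sigmoid(a)
    silu = a * s
    # d silu / da = s + a * s * (1 - s) = s + silu * (1 - s)
    grad_a = grad_out * b * (s + silu * (1.0 - s))
    grad_b = grad_out * silu
    return grad_a, grad_b
```

Both gradients share `s` and `silu`, which is why computing them together avoids re-reading the inputs.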
Opt-in via `--sdpa.fused_swiglu true`. When enabled, monkey-patches both
`keys_values.lora.LLaMAMLP.forward` and `litgpt.model.LLaMAMLP.forward`.
Falls back to eager when inputs are not on CUDA or dtypes mismatch.
Correctness verified against fp64 reference across bf16/fp16/fp32 and
Qwen3 shapes. Fused is slightly more accurate (0.71x mean error vs
eager) from fp32 accumulation of `a * sigmoid(a) * b`.
Standalone microbenchmark at Qwen3 intermediate (2, 2048, 6912) bf16:
- Forward: 0.215 ms -> 0.130 ms (1.65x)
- Forward+backward: 0.909 ms -> 0.645 ms (1.41x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacking on top of
--sdpa.fused_rope true --sdpa.fused_rmsnorm true:
- Training (6 steps): 1126s -> 1121s (-0.4%)
- Validation: 1034s -> 1030s (-0.4%)
- Total wall-clock: 2160s -> 2150s (-0.4%)
Cumulative speedup of all three kernels (RoPE + RMSNorm + SwiGLU) vs
baseline: training 1199s -> 1121s (-6.5%), total 2310s -> 2150s (-6.9%).
* Skip RoPE outer cat when rope covers the full head dim
`CausalSelfAttention.forward` wraps RoPE with `torch.cat((q_roped,
q[..., rope_n_elem:]), dim=-1)` to splice the rotated prefix back with
the un-roped tail for partial-rope models. When `rope_n_elem ==
head_size` (Qwen3, Llama3, most modern configs), the tail slice is
empty but cat still launches a copy kernel.
Add a conditional: when rope covers the full head dim, assign
`q = q_roped` directly. For partial-rope models (rotary_percentage < 1.0,
e.g., litgpt default 0.25) the else-branch preserves the original cat
behavior exactly.
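The shape of the conditional, sketched on plain Python lists for one head vector (`splice_rope` is a hypothetical helper name; the real code operates on tensors inside `CausalSelfAttention.forward`):

```python
def splice_rope(q, q_roped, rope_n_elem):
    """Combine the rotated prefix with the untouched tail of q."""
    if rope_n_elem == len(q):
        # Full-rope model: the tail is empty, so skip the concatenation.
        return q_roped
    # Partial-rope model: keep the original splice.
    return q_roped + q[rope_n_elem:]
```

In the full-rope branch no new buffer is built, which corresponds to the copy kernel that the tensor version no longer launches.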
Eliminates 2 kernel launches per Block.forward call (one for q, one for
k) on full-rope models.
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of
`--sdpa.fused_rope=true --sdpa.fused_rmsnorm=true --sdpa.fused_swiglu=true`:
- Training (6 steps): 1121s -> 1120s (within noise)
- Validation: 1030s -> 1004s (-2.5%)
- Total wall-clock: 2150s -> 2124s (-1.1%)
Cumulative vs untouched baseline: 2310s -> 2124s (-8.0%).
val_loss stays within variance (18.440 vs baseline 18.248, ±0.3).
* Add NVTX ranges for gradient accumulation phases
Adds `torch.cuda.nvtx.range_push/pop` around the major phases of
`GradientAccumulator.run` so nsys profiles can attribute GPU work to
meaningful regions:
- `accumulator_run_{first_layer_idx}` (outer, in main.py)
- `compute_checkpoints`, `sync_checkpoints`
- Per-cell: `cell_inputs_{col}`, `forward_{col}`, `backward_{col}`,
`write_head_grads_{col}`
No performance impact: NVTX ranges are CPU-side markers that nsys reads. When
nsys is not running, the calls are no-ops.
These annotations match the names used in prior profiling work
(PROFILING_RESULTS.md), so before/after comparisons are directly
aligned.
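One way to keep push/pop pairs balanced is a small context manager. This is a sketch, not the code added in this commit: `nvtx_range` is a hypothetical helper, and it degrades to a no-op when torch or CUDA is unavailable.

```python
from contextlib import contextmanager

try:
    import torch

    _nvtx = torch.cuda.nvtx if torch.cuda.is_available() else None
except ImportError:
    _nvtx = None  # torch not installed: ranges become no-ops


@contextmanager
def nvtx_range(name):
    """Wrap a code region in an NVTX range, popping even if it raises."""
    if _nvtx is not None:
        _nvtx.range_push(name)
    try:
        yield
    finally:
        if _nvtx is not None:
            _nvtx.range_pop()
```

Usage would be along the lines of `with nvtx_range(f"forward_{col}"): ...` around each phase.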
* Batch create_random_index randperm calls
`create_random_index` previously ran a Python for-loop over (batch, n_heads),
calling `torch.randperm(length)[:num]` on each iteration. Each CUDA randperm
launches ~7 small kernels (histogram, radix sort passes, exclusive sum,
duplicate handling, arange, fill), so the loop produced batch*n_heads*7
kernel launches.
Replace with batched argsort(rand_vals): draw uniform random values of
shape (batch, n_heads, length) in one call, argsort along the last dim
to produce independent permutations per (b, h), slice to `num`. Same
semantics, one sort-kernel sequence instead of batch*n_heads of them.
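The underlying idea, that sorting the indices of i.i.d. uniform keys yields a uniformly random permutation, can be checked in plain Python. The sketch below uses a hypothetical `random_index_argsort` helper and explicit loops; on tensors, the loops would presumably collapse into a single `torch.rand(batch, n_heads, length)` followed by `argsort(dim=-1)[..., :num]`.

```python
import random


def random_index_argsort(batch, n_heads, length, num, rng):
    """Pick num distinct indices in range(length) per (batch, head) pair."""
    index = []
    for _ in range(batch):
        per_head = []
        for _ in range(n_heads):
            # Uniform keys; argsort of the keys is a random permutation.
            keys = [rng.random() for _ in range(length)]
            perm = sorted(range(length), key=keys.__getitem__)
            per_head.append(perm[:num])
        index.append(per_head)
    return index
```

Each (b, h) pair draws its own keys, so the permutations stay independent, matching the per-iteration `randperm` semantics.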
Discovered via an nsys re-profile after the three Triton kernels landed.
In step 3, forward_1 accounted for ~358k radix-sort kernels (49% of the sort
machinery's total), traceable to this single code path through
`create_ext_annotations` in the saved_tensors_hooks annotation
pipeline.
Keeps the CPU loop as a fallback for non-CUDA devices (randperm on
CPU is a single call, no benefit from batching).
Standalone microbenchmark (batch=2, n_heads=8 or 32, num=64, length=2048):
- Small (n_kv_heads=8): 0.985 ms -> 0.083 ms (11.9x)
- Large (n_heads=32): 3.832 ms -> 0.082 ms (46.7x)
End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of all
three Triton fused kernels + RoPE cat skip:
- Training (6 steps): 1120.23s -> 1077.40s (-3.8%)
- Validation: 1004.14s -> 1007.39s (~neutral, no annotations
in inference path)
- Total wall-clock: 2124.37s -> 2084.79s (-1.9%)
- val_loss: 18.440 -> 18.237 (within baseline variance)
Cumulative vs untouched baseline: 2309.88s -> 2084.79s (-9.7%).
* Drop per-layer sync in layer-input checkpoint loop
`_inference_forward_pass` had a `torch.cuda.synchronize()` at the end of
each middle-loop (per-layer) iteration, intended to wait for the D2H of
the just-written layer-input checkpoint before overwriting the shared
quantizer buffer on the next layer.
The D2H in `QuantizerState.copy_` already uses `non_blocking=True` on
the default stream, and the next layer's quantize kernel runs on the
same stream, so GPU-level ordering was already correct without the
sync. The sync only stalled the CPU thread; GPU work wasn't waiting
for it.
Removing the sync lets CPU dispatch run ahead of the GPU pipeline
again. The original TODO comment suspected this already.
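The ordering argument can be illustrated with a toy model of a single in-order stream. This is a simulation only, with hypothetical names (`ToyStream`, `run_layers`); it does not touch CUDA, and it models just the same-stream ordering between the copy and the next quantize.

```python
from collections import deque


class ToyStream:
    """In-order work queue standing in for one CUDA stream."""

    def __init__(self):
        self._queue = deque()

    def enqueue(self, fn):
        self._queue.append(fn)  # returns immediately, like an async launch

    def drain(self):
        while self._queue:
            self._queue.popleft()()  # strictly first-in, first-out


def run_layers(n_layers):
    stream = ToyStream()
    device_buffer = [None]  # shared quantizer buffer, reused by every layer
    host_checkpoints = []

    def quantize(layer):
        device_buffer[0] = f"layer{layer}"

    def copy_to_host():
        host_checkpoints.append(device_buffer[0])

    for layer in range(n_layers):
        stream.enqueue(lambda layer=layer: quantize(layer))
        stream.enqueue(copy_to_host)
        # No per-layer sync: the host loop runs ahead of the queued work.

    stream.drain()
    return host_checkpoints
```

Every checkpoint holds its own layer's data even though the host never waited, because the copy for layer i is queued ahead of the quantize for layer i+1.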
Measured at Qwen3-4B on A100-40GB, Helmet nq 64k, cpu_offload=True,
stacked on top of all six prior async_offload commits:
- Step 1: 276.6s -> 261.8s (-5%)
- Step 3: 154.7s -> 148.7s (-4%)
- 3 steps: 612.8s -> 592.1s (-3.4%)
- Losses within prior variance; no corruption.
A follow-up (`d2h_stream` + event-based ordering) would let the D2H
actually overlap with the next layer's quantize on a separate stream;
this commit is the zero-infrastructure part of that path.
---------
Co-authored-by: VihangPatil <pvihang@amazon.com>
Co-authored-by: Matthias Seeger <matthis@amazon.de>
1 parent b733be1 commit 3168e3f
11 files changed
Lines changed: 1056 additions & 16 deletions
File tree
- keys_values
- finetune
- kvcache/gradient