Skip to content

Commit 3168e3f

Browse files
vihangpVihangPatilmseeger
authored
Fused kernels and removal of torch cuda synchronize (#105)
* Add fused Triton RoPE kernel Replaces the eager apply_rope sequence (~5-6 elementwise kernels per call) with a single Triton kernel for both forward and backward. Reduces kernel launch overhead in the training forward path, where RoPE accounts for roughly 15-20% of the 21k kernels/call at Qwen3-4B scale. Opt-in via `--sdpa.fused_rope true`. Falls back to eager apply_rope when Triton is unavailable or the input shape/dtype is unsupported. Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. Fused kernel accumulates in fp32 internally and is typically more accurate than eager (0.69x mean error vs fp64). Measured on Qwen3-4B, A100-40GB, Helmet nq 64k, cpu_offload=True: - Training (6 steps): 1199s -> 1174s (-2.1%) - Validation: 1111s -> 1076s (-3.1%) - val_loss: 18.248 -> 16.264 (slight improvement, likely fp32 accumulation) Standalone microbenchmark at Qwen3 Q shape (2, 32, 2048, 128) bf16: - Forward: 0.394 ms -> 0.070 ms (5.6x) - Forward+backward: 1.061 ms -> 0.339 ms (3.1x) * Add fused Triton RMSNorm kernel Replaces the eager RMSNorm.forward (fp32 cast + x*x + mean + rsqrt + normalize + mul weight + cast-back = ~5-6 kernels) with a single Triton forward kernel. Backward uses two kernels: a row-parallel kernel for dL/dx and a split-K two-pass reduction for dL/dw. The dL/dw split-K design: a first pass produces (n_m_blocks, D) partial sums in parallel across both M and D; a second pass reduces partials along M per column. Target ~432 total programs (4 per SM on A100) for full occupancy. Opt-in via `--sdpa.fused_rmsnorm true`. When enabled, monkey-patches both `keys_values.model.RMSNorm.forward` and `litgpt.model.RMSNorm.forward` so all existing call sites (norm_1, post_attention_norm, norm_2, post_mlp_norm, norm_q, norm_k, ln_f) transparently benefit. Falls back to the original eager forward when Triton is unavailable, the tensor is on CPU, or D > 16384. Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. The fused forward accumulates in fp32 and is slightly more accurate than eager in the Gemma add_unit_offset path. Standalone microbenchmark at Qwen3 hidden (2, 2048, 2560) bf16: - Forward: 0.355 ms -> 0.061 ms (5.8x) - Forward+backward: 1.452 ms -> 0.430 ms (3.4x) End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, cpu_offload=True, combined with --sdpa.fused_rope true: - Training (6 steps): 1199s -> 1126s (-6.1%) - Validation: 1111s -> 1034s (-6.9%) - val_loss: 18.248 -> 18.504 (within variance) - Total wall-clock: 2310s -> 2160s (-6.5%) * Add fused Triton SwiGLU kernel Replaces the eager `F.silu(x_fc_1) * x_fc_2` inside LLaMAMLP.forward (2 kernels: silu + elementwise multiply) with a single fused Triton kernel. Backward kernel computes both dL/dx_fc_1 and dL/dx_fc_2 in one pass. Opt-in via `--sdpa.fused_swiglu true`. When enabled, monkey-patches both `keys_values.lora.LLaMAMLP.forward` and `litgpt.model.LLaMAMLP.forward`. Falls back to eager when inputs are not on CUDA or dtypes mismatch. Correctness verified against fp64 reference across bf16/fp16/fp32 and Qwen3 shapes. Fused is slightly more accurate (0.71x mean error vs eager) from fp32 accumulation of `a * sigmoid(a) * b`. Standalone microbenchmark at Qwen3 intermediate (2, 2048, 6912) bf16: - Forward: 0.215 ms -> 0.130 ms (1.65x) - Forward+backward: 0.909 ms -> 0.645 ms (1.41x) End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacking on top of --sdpa.fused_rope true --sdpa.fused_rmsnorm true: - Training (6 steps): 1126s -> 1121s (-0.4%) - Validation: 1034s -> 1030s (-0.4%) - Total wall-clock: 2160s -> 2150s (-0.4%) Cumulative speedup of all three kernels (RoPE + RMSNorm + SwiGLU) vs baseline: training 1199s -> 1121s (-6.5%), total 2310s -> 2150s (-6.9%). * Skip RoPE outer cat when rope covers the full head dim `CausalSelfAttention.forward` wraps RoPE with `torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)` to splice the rotated prefix back with the un-roped tail for partial-rope models. When `rope_n_elem == head_size` (Qwen3, Llama3, most modern configs), the tail slice is empty but cat still launches a copy kernel. Add a conditional: when rope covers the full head dim, assign `q = q_roped` directly. For partial-rope models (rotary_percentage < 1.0, e.g., litgpt default 0.25) the else-branch preserves the original cat behavior exactly. Eliminates 2 kernel launches per Block.forward call (one for q, one for k) on full-rope models. End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of `--sdpa.fused_rope=true --sdpa.fused_rmsnorm=true --sdpa.fused_swiglu=true`: - Training (6 steps): 1121s -> 1120s (within noise) - Validation: 1030s -> 1004s (-2.5%) - Total wall-clock: 2150s -> 2124s (-1.1%) Cumulative vs untouched baseline: 2310s -> 2124s (-8.0%). val_loss stays within variance (18.440 vs baseline 18.248, ±0.3). * Add NVTX ranges for gradient accumulation phases Adds `torch.cuda.nvtx.range_push/pop` around the major phases of `GradientAccumulator.run` so nsys profiles can attribute GPU work to meaningful regions: - `accumulator_run_{first_layer_idx}` (outer, in main.py) - `compute_checkpoints`, `sync_checkpoints` - Per-cell: `cell_inputs_{col}`, `forward_{col}`, `backward_{col}`, `write_head_grads_{col}` No performance impact — NVTX is a CPU-side marker that nsys reads. When nsys is not running, the calls are no-ops. These annotations match the names used in prior profiling work (PROFILING_RESULTS.md), so before/after comparisons are directly aligned. * Batch create_random_index randperm calls `create_random_index` used to do a Python for-loop over (batch, n_heads) calling `torch.randperm(length)[:num]` per iteration. Each CUDA randperm launches ~7 small kernels (histogram, radix sort passes, exclusive sum, duplicate handling, arange, fill), so the loop produced batch*n_heads*7 kernel launches. Replace with batched argsort(rand_vals): draw uniform random values of shape (batch, n_heads, length) in one call, argsort along the last dim to produce independent permutations per (b, h), slice to `num`. Same semantics, one sort-kernel sequence instead of batch*n_heads of them. Discovered via nsys re-profile after the three Triton kernels landed. forward_1 had ~358k radix-sort kernels per step 3 (49% of the sort machinery's total), traceable to this single code path through `create_ext_annotations` in the saved_tensors_hooks annotation pipeline. Keeps the CPU loop as a fallback for non-CUDA devices (randperm on CPU is a single call, no benefit from batching). Standalone microbenchmark (batch=2, n_heads=8 or 32, num=64, length=2048): - Small (n_kv_heads=8): 0.985 ms -> 0.083 ms (11.9x) - Large (n_heads=32): 3.832 ms -> 0.082 ms (46.7x) End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of all three Triton fused kernels + RoPE cat skip: - Training (6 steps): 1120.23s -> 1077.40s (-3.8%) - Validation: 1004.14s -> 1007.39s (~neutral, no annotations in inference path) - Total wall-clock: 2124.37s -> 2084.79s (-1.9%) - val_loss: 18.440 -> 18.237 (within baseline variance) Cumulative vs untouched baseline: 2309.88s -> 2084.79s (-9.7%). * Drop per-layer sync in layer-input checkpoint loop `_inference_forward_pass` had a `torch.cuda.synchronize()` at the end of each middle-loop (per-layer) iteration, intended to wait for the D2H of the just-written layer-input checkpoint before overwriting the shared quantizer buffer on the next layer. The D2H in `QuantizerState.copy_` already uses `non_blocking=True` on the default stream, and the next layer's quantize kernel runs on the same stream, so GPU-level ordering was already correct without the sync. The sync only stalled the CPU thread; GPU work wasn't waiting for it. Removing the sync lets CPU dispatch run ahead of the GPU pipeline again. The original TODO comment suspected this already. Measured at Qwen3-4B on A100-40GB, Helmet nq 64k, cpu_offload=True, stacked on top of all six prior async_offload commits: - Step 1: 276.6s -> 261.8s (-5%) - Step 3: 154.7s -> 148.7s (-4%) - 3 steps: 612.8s -> 592.1s (-3.4%) - Losses within prior variance; no corruption. A followup (`d2h_stream` + event-based ordering) would let the D2H actually overlap with the next layer's quantize on a separate stream; this commit is the zero-infrastructure part of that path. --------- Co-authored-by: VihangPatil <pvihang@amazon.com> Co-authored-by: Matthias Seeger <matthis@amazon.de>
1 parent b733be1 commit 3168e3f

11 files changed

Lines changed: 1056 additions & 16 deletions

File tree

keys_values/finetune/args.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,26 @@ class SDPAArgs:
585585
is available.
586586
dynamo_cache_size_limit: Value for `torch._dynamo.config.cache_size_limit`.
587587
Defaults to 32. The built-in default 8 is too small for our purposes.
588+
fused_rope: If `True`, replace the eager rotary position embedding
589+
(`apply_rope`) with a single fused Triton kernel. Falls back to
590+
eager automatically when Triton is unavailable or the input shape
591+
is incompatible. Correctness is verified against an fp64
592+
reference; the fused kernel accumulates in fp32 internally and is
593+
typically *more* accurate than eager in bf16/fp16. Measured at
594+
Qwen3-4B on A100-40GB: ~2% end-to-end speedup, val_loss matches
595+
or improves. See `keys_values/fused_rope.py`.
596+
fused_rmsnorm: If `True`, patch both `keys_values.model.RMSNorm` and
597+
`litgpt.model.RMSNorm` so their `forward` dispatches to a fused
598+
Triton kernel. Falls back to the original eager forward when
599+
Triton is unavailable, the tensor is on CPU, or the input shape
600+
is unsupported. Correctness verified against an fp64 reference.
601+
See `keys_values/fused_rmsnorm.py`.
602+
fused_swiglu: If `True`, patch `LLaMAMLP.forward` (both
603+
`keys_values.lora` and `litgpt.model` variants) so the
604+
`F.silu(x_fc_1) * x_fc_2` step runs as a single fused Triton
605+
kernel instead of two eager kernels. Falls back to eager when
606+
inputs are not on CUDA or dtypes mismatch. Correctness verified
607+
against an fp64 reference. See `keys_values/fused_swiglu.py`.
588608
flashinfer_attention: If `True` and FlashInfer is available, we use
589609
FlashInfer SDPA if summed attention weights are required. If
590610
`flex_attention == False`, this kernel is also used if attention
@@ -597,6 +617,9 @@ class SDPAArgs:
597617
reorder_sort_if_3d: bool = True
598618
use_flex_for_attn_weights: bool = True
599619
dynamo_cache_size_limit: int = 32
620+
fused_rope: bool = False
621+
fused_rmsnorm: bool = False
622+
fused_swiglu: bool = False
600623
flashinfer_attention: bool = True
601624

602625
def __post_init__(self):

keys_values/finetune/longcontext_full.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,12 @@ def get_mha_and_cache_kwargs(
982982
init_val=limit_gb,
983983
name="attention_forward_temp_size_gb",
984984
)
985+
from keys_values.pos_encoding import set_fused_rope_enabled
986+
set_fused_rope_enabled(sdpa.fused_rope)
987+
from keys_values.fused_rmsnorm import set_fused_rmsnorm_enabled
988+
set_fused_rmsnorm_enabled(sdpa.fused_rmsnorm)
989+
from keys_values.fused_swiglu import set_fused_swiglu_enabled
990+
set_fused_swiglu_enabled(sdpa.fused_swiglu)
985991
mha_kwargs: Dict[str, Any] = dict(
986992
tmp_array_limit_gb=tmp_array_limit_forward,
987993
pos_encoding=position_encoding_factory(config, do_yarn=yarn_rope),

0 commit comments

Comments
 (0)