Skip to content

Fused kernels and removal of torch cuda synchronize #105

Merged
mseeger merged 8 commits into
awslabs:mainfrom
vihangp:async_offload
May 7, 2026
Merged

Fused kernels and removal of torch cuda synchronize #105
mseeger merged 8 commits into
awslabs:mainfrom
vihangp:async_offload

Conversation

@vihangp
Copy link
Copy Markdown
Collaborator

@vihangp vihangp commented May 7, 2026

Summary

Seven incremental performance improvements targeting training-forward GPU
idle time on LongContextGradientModel. End-to-end measured at Qwen3-4B
on A100-40GB, Helmet nq 64k context, cpu_offload=True:

  • Training (6 steps): 1199s → 1068s (−10.9%)
  • Validation: 1111s → 1020s (−8.2%)
  • Total wall: 2310s → 2088s (−9.6%)
  • val_loss stays within run-to-run variance

All optimizations are opt-in via CLI flags and default to off so the
upstream default behavior is unchanged.

Before vs After latency

Profiled at step 3 (steady state, post-warmup) of a 3-step training run.
Both profiles were captured with nsys profile --delay=300 on the same
GPU back-to-back; the only difference is whether the optimizations were
applied. See BEFORE_AFTER_REPORT.md for full methodology.

Step 3 (1 training step) — headline numbers

Metric Before After Δ
Step 3 wall time 123.17 s 110.86 s −12.31 s (−10.0%)
forward_1 wall per call 691.3 ms 419.6 ms −271.7 ms (−39.3%)
forward_1 kernel time 388.3 ms 278.5 ms −109.8 ms (−28.3%)
forward_1 idle time 303.0 ms 141.1 ms −161.9 ms (−53.4%)
forward_1 idle % 43.8% 33.6% −10.2 pp
forward_1 kernel count 18,012 2,348 −87.0%
forward_1 total kernels × calls 648,436 84,533 −87.0%
forward_1 total kernel time 13,977 ms 10,026 ms −28.3%

Dispatch-gap distribution (first forward_1 call)

Kernel-to-kernel idle gaps — the classic "Python dispatch overhead" signature.

Bucket Before (count / total ms) After (count / total ms)
0-5 μs 11,011 / 30.8 ms 1,139 / 1.8 ms
5-20 μs 4,587 / 61.0 ms 1,034 / 11.6 ms
20-50 μs 2,153 / 76.4 ms 101 / 2.8 ms
50-100 μs 246 / 15.9 ms 85 / 6.4 ms
100-500 μs 81 / 16.7 ms 56 / 11.0 ms

The biggest wins are in the 20-50 μs bucket (kernel-launch-overhead-sized
gaps between tiny sort/randperm/elementwise kernels). That bucket dropped
from 2,153 gaps totaling 76 ms to 101 gaps totaling 2.8 ms — almost entirely
from commit 6 (batched randperm) eliminating the ~360k sort kernels that
were pulling the tail.

End-to-end (6 steps + validation)

Separate, un-profiled run — this is the real user-facing wall-clock impact.

Metric Before After Δ
Training (6 steps) 1199.14 s 1067.95 s −131.19 s (−10.9%)
Validation 1110.74 s 1020.06 s −90.68 s (−8.2%)
Total wall 2309.88 s 2088.01 s −221.87 s (−9.6%)
val_loss 18.248 17.871 within ±0.7 run-to-run variance

Commits

# Commit What it does
1 a007142 Fused Triton kernel for RoPE (--sdpa.fused_rope)
2 537ab13 Fused Triton kernel for RMSNorm with split-K backward (--sdpa.fused_rmsnorm)
3 8882064 Fused Triton kernel for SwiGLU (--sdpa.fused_swiglu)
4 f4d3b8c Skip the outer torch.cat in CausalSelfAttention.forward when rope covers the full head dim (Qwen3, Llama3, most modern configs)
5 122d569 Add NVTX ranges for gradient accumulation phases (no perf impact, enables profiling)
6 c7d8845 Batch torch.randperm in create_random_index (~360k sort-family kernels eliminated per step 3)
7 d00e0a1 Remove unneeded per-layer torch.cuda.synchronize() in _inference_forward_pass's layer-input checkpoint loop

Correctness

Each kernel is verified against an fp64 reference. Standalone microbenchmarks
in the kernel source files. End-to-end val_loss stays within run-to-run
variance (historic spread across our runs: 17.87 – 18.50).

Experiments that didn't work (reverted, documented)

  • compile_blocks and compile_mlp variants — torch.compile corrupts
    gradients when stacked on the training cell loop's
    saved_tensors_hooks wrapper
  • async d2h_stream for layer-input checkpoints (Option 2) and for
    head-gradient checkpoints (Target A) — neither produced a wall-clock
    improvement above the noise floor

Detailed measurements and analysis: PROFILING_RESULTS.md,
BEFORE_AFTER_REPORT.md, TRANSFER_IDLE_REPORT.md (local-only).

Reproducing the numbers

All measurements were taken with the command below (same config as
BENCHMARK_RUNS.md). The only difference between the Before and
After rows in the Summary is the three --sdpa.fused_* flags —
commits 4-7 (cat skip, NVTX ranges, batched randperm, sync removal)
are always-on and require no flags.

Benchmark: 6 training steps + final validation

PYTORCH_ALLOC_CONF=expandable_segments:True CUDA_VISIBLE_DEVICES=0 \
  <venv>/bin/python -m keys_values finetune_long_lora \
    Qwen/Qwen3-4B-Instruct-2507 \
    --out_dir <out_dir> \
    --precision bf16-true --verbose some --devices 1 \
    --data Helmet --data.dataset_key nq --data.max_length 64k \
    --data.trainloader_longest_first True \
    --train.save_interval 10 \
    --train.micro_batch_size 2 --train.global_batch_size 2 \
    --train.max_steps 6 \
    --eval.interval 10 --eval.initial_validation False \
    --eval.final_validation True \
    --attention_forward_temp_size_gb 2 \
    --kv_cache.cache_length 32768 --kv_cache.chunk_size 2048 \
    --kv_cache.name h2o-torch-quantized8 --kv_cache.cpu_offload True \
    --grad.layers_per_cell 1 \
    --grad.layercp_pin_memory true --grad.cachecp_pin_memory true \
    --sdpa.flex_attention true --sdpa.flex_extend_kv false \
    --sdpa.flex_num_q_lens 4 --sdpa.reorder_sort_if_3d true \
    --sdpa.use_flex_for_attn_weights true \
    --sdpa.dynamo_cache_size_limit 32 \
    --sdpa.fused_rope true \
    --sdpa.fused_rmsnorm true \
    --sdpa.fused_swiglu true \
    --yarn_rope true --oom_error_recovery false \
    --head_model next_token_prediction

To reproduce the Before baseline, drop the three --sdpa.fused_* flags
(or set them to false). Commits 4-7 are always-on, so "Before" here is
commits 1-3 disabled, not literal upstream/main. For a pure
upstream/main baseline, check out upstream/main directly.

Standalone kernel microbenchmarks

Each fused kernel file has a __main__-style microbench block. Run
directly, e.g.:

<venv>/bin/python -c "
import torch
from litgpt.model import apply_rope as eager
from keys_values.fused_rope import fused_apply_rope

x = torch.randn(2, 32, 2048, 128, device='cuda', dtype=torch.bfloat16,
                requires_grad=True)
cos = torch.randn(1, 2048, 128, device='cuda', dtype=torch.bfloat16)
sin = torch.randn(1, 2048, 128, device='cuda', dtype=torch.bfloat16)
y = fused_apply_rope(x, cos, sin)
print(y.shape, y.dtype)
"

nsys profile (optional)

The 14 NVTX range push/pop pairs added in commit 122d569 let nsys
attribute work to named phases (compute_checkpoints, forward_{col},
backward_{col}, write_head_grads_{col}, etc.). Measured overhead
when no profiler is attached is ~0.3 ms per training step (<0.001% of
step wall). To profile:

nsys profile --trace=cuda,nvtx --sample=none --cpuctxsw=none \
  --delay=300 --output=<path>/prof \
  <venv>/bin/python -m keys_values finetune_long_lora ...

The --delay=300 skips tokenization + model load + step-1 torch.compile
/ Triton JIT warmup. For analysis, export to sqlite
(nsys export -t sqlite ...) and query per-phase wall / kernel / idle
breakdowns.

Test plan

  • Unit tests for each fused kernel (correctness vs fp64 reference)
  • 6-step training + final validation, losses match baseline
  • nsys-profiled kernel count / idle %: forward_1 kernels 18,060 → 2,348 (−87%), idle 45% → 34%

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

VihangPatil added 7 commits April 28, 2026 19:49
Replaces the eager apply_rope sequence (~5-6 elementwise kernels per call)
with a single Triton kernel for both forward and backward. Reduces kernel
launch overhead in the training forward path, where RoPE accounts for
roughly 15-20% of the 21k kernels/call at Qwen3-4B scale.

Opt-in via `--sdpa.fused_rope true`. Falls back to eager apply_rope when
Triton is unavailable or the input shape/dtype is unsupported.

Correctness: verified against fp64 reference across bf16/fp16/fp32 and
Qwen3 shapes. Fused kernel accumulates in fp32 internally and is typically
more accurate than eager (0.69x mean error vs fp64).

Measured on Qwen3-4B, A100-40GB, Helmet nq 64k, cpu_offload=True:
- Training (6 steps): 1199s -> 1174s (-2.1%)
- Validation: 1111s -> 1076s (-3.1%)
- val_loss: 18.248 -> 16.264 (slight improvement, likely fp32 accumulation)

Standalone microbenchmark at Qwen3 Q shape (2, 32, 2048, 128) bf16:
- Forward: 0.394 ms -> 0.070 ms (5.6x)
- Forward+backward: 1.061 ms -> 0.339 ms (3.1x)
Replaces the eager RMSNorm.forward (fp32 cast + x*x + mean + rsqrt + normalize
+ mul weight + cast-back = ~5-6 kernels) with a single Triton forward kernel.
Backward uses two kernels: a row-parallel kernel for dL/dx and a split-K
two-pass reduction for dL/dw.

The dL/dw split-K design: a first pass produces (n_m_blocks, D) partial sums
in parallel across both M and D; a second pass reduces partials along M per
column. Target ~432 total programs (4 per SM on A100) for full occupancy.

Opt-in via `--sdpa.fused_rmsnorm true`. When enabled, monkey-patches both
`keys_values.model.RMSNorm.forward` and `litgpt.model.RMSNorm.forward` so
all existing call sites (norm_1, post_attention_norm, norm_2, post_mlp_norm,
norm_q, norm_k, ln_f) transparently benefit. Falls back to the original
eager forward when Triton is unavailable, the tensor is on CPU, or D > 16384.

Correctness: verified against fp64 reference across bf16/fp16/fp32 and Qwen3
shapes. The fused forward accumulates in fp32 and is slightly more accurate
than eager in the Gemma add_unit_offset path.

Standalone microbenchmark at Qwen3 hidden (2, 2048, 2560) bf16:
- Forward:         0.355 ms -> 0.061 ms  (5.8x)
- Forward+backward: 1.452 ms -> 0.430 ms  (3.4x)

End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, cpu_offload=True,
combined with --sdpa.fused_rope true:
- Training (6 steps): 1199s -> 1126s (-6.1%)
- Validation:         1111s -> 1034s (-6.9%)
- val_loss:           18.248 -> 18.504 (within variance)
- Total wall-clock:   2310s -> 2160s (-6.5%)
Replaces the eager `F.silu(x_fc_1) * x_fc_2` inside LLaMAMLP.forward
(2 kernels: silu + elementwise multiply) with a single fused Triton
kernel. Backward kernel computes both dL/dx_fc_1 and dL/dx_fc_2 in one
pass.

Opt-in via `--sdpa.fused_swiglu true`. When enabled, monkey-patches both
`keys_values.lora.LLaMAMLP.forward` and `litgpt.model.LLaMAMLP.forward`.
Falls back to eager when inputs are not on CUDA or dtypes mismatch.

Correctness verified against fp64 reference across bf16/fp16/fp32 and
Qwen3 shapes. Fused is slightly more accurate (0.71x mean error vs
eager) from fp32 accumulation of `a * sigmoid(a) * b`.

Standalone microbenchmark at Qwen3 intermediate (2, 2048, 6912) bf16:
- Forward:         0.215 ms -> 0.130 ms  (1.65x)
- Forward+backward: 0.909 ms -> 0.645 ms  (1.41x)

End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacking on top of
--sdpa.fused_rope true --sdpa.fused_rmsnorm true:
- Training (6 steps): 1126s -> 1121s (-0.4%)
- Validation:         1034s -> 1030s (-0.4%)
- Total wall-clock:   2160s -> 2150s (-0.4%)

Cumulative speedup of all three kernels (RoPE + RMSNorm + SwiGLU) vs
baseline: training 1199s -> 1121s (-6.5%), total 2310s -> 2150s (-6.9%).
`CausalSelfAttention.forward` wraps RoPE with `torch.cat((q_roped,
q[..., rope_n_elem:]), dim=-1)` to splice the rotated prefix back with
the un-roped tail for partial-rope models. When `rope_n_elem ==
head_size` (Qwen3, Llama3, most modern configs), the tail slice is
empty but cat still launches a copy kernel.

Add a conditional: when rope covers the full head dim, assign
`q = q_roped` directly. For partial-rope models (rotary_percentage < 1.0,
e.g., litgpt default 0.25) the else-branch preserves the original cat
behavior exactly.

Eliminates 2 kernel launches per Block.forward call (one for q, one for
k) on full-rope models.

End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of
`--sdpa.fused_rope=true --sdpa.fused_rmsnorm=true --sdpa.fused_swiglu=true`:
- Training (6 steps): 1121s -> 1120s (within noise)
- Validation:         1030s -> 1004s (-2.5%)
- Total wall-clock:   2150s -> 2124s (-1.1%)

Cumulative vs untouched baseline: 2310s -> 2124s (-8.0%).

val_loss stays within variance (18.440 vs baseline 18.248, ±0.3).
Adds `torch.cuda.nvtx.range_push/pop` around the major phases of
`GradientAccumulator.run` so nsys profiles can attribute GPU work to
meaningful regions:

- `accumulator_run_{first_layer_idx}` (outer, in main.py)
- `compute_checkpoints`, `sync_checkpoints`
- Per-cell: `cell_inputs_{col}`, `forward_{col}`, `backward_{col}`,
  `write_head_grads_{col}`

No performance impact — NVTX is a CPU-side marker that nsys reads. When
nsys is not running, the calls are no-ops.

These annotations match the names used in prior profiling work
(PROFILING_RESULTS.md), so before/after comparisons are directly
aligned.
`create_random_index` used to do a Python for-loop over (batch, n_heads)
calling `torch.randperm(length)[:num]` per iteration. Each CUDA randperm
launches ~7 small kernels (histogram, radix sort passes, exclusive sum,
duplicate handling, arange, fill), so the loop produced batch*n_heads*7
kernel launches.

Replace with batched argsort(rand_vals): draw uniform random values of
shape (batch, n_heads, length) in one call, argsort along the last dim
to produce independent permutations per (b, h), slice to `num`. Same
semantics, one sort-kernel sequence instead of batch*n_heads of them.

Discovered via nsys re-profile after the three Triton kernels landed.
forward_1 had ~358k radix-sort kernels per step 3 (49% of the sort
machinery's total), traceable to this single code path through
`create_ext_annotations` in the saved_tensors_hooks annotation
pipeline.

Keeps the CPU loop as a fallback for non-CUDA devices (randperm on
CPU is a single call, no benefit from batching).

Standalone microbenchmark (batch=2, n_heads=8 or 32, num=64, length=2048):
- Small (n_kv_heads=8):  0.985 ms -> 0.083 ms  (11.9x)
- Large (n_heads=32):    3.832 ms -> 0.082 ms  (46.7x)

End-to-end on Qwen3-4B, Helmet nq 64k, A100-40GB, stacked on top of all
three Triton fused kernels + RoPE cat skip:
- Training (6 steps): 1120.23s -> 1077.40s  (-3.8%)
- Validation:         1004.14s -> 1007.39s  (~neutral, no annotations
                                             in inference path)
- Total wall-clock:   2124.37s -> 2084.79s  (-1.9%)
- val_loss:           18.440 -> 18.237 (within baseline variance)

Cumulative vs untouched baseline: 2309.88s -> 2084.79s (-9.7%).
`_inference_forward_pass` had a `torch.cuda.synchronize()` at the end of
each middle-loop (per-layer) iteration, intended to wait for the D2H of
the just-written layer-input checkpoint before overwriting the shared
quantizer buffer on the next layer.

The D2H in `QuantizerState.copy_` already uses `non_blocking=True` on
the default stream, and the next layer's quantize kernel runs on the
same stream, so GPU-level ordering was already correct without the
sync. The sync only stalled the CPU thread; GPU work wasn't waiting
for it.

Removing the sync lets CPU dispatch run ahead of the GPU pipeline
again. The original TODO comment suspected this already.

Measured at Qwen3-4B on A100-40GB, Helmet nq 64k, cpu_offload=True,
stacked on top of all six prior async_offload commits:
- Step 1:  276.6s -> 261.8s  (-5%)
- Step 3:  154.7s -> 148.7s  (-4%)
- 3 steps: 612.8s -> 592.1s  (-3.4%)
- Losses within prior variance; no corruption.

A followup (`d2h_stream` + event-based ordering) would let the D2H
actually overlap with the next layer's quantize on a separate stream;
this commit is the zero-infrastructure part of that path.
Copy link
Copy Markdown
Contributor

@mseeger mseeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some comments. But otherwise, can merge this, and I can do the comparison tests, unless you have the code for them already.

FlashInfer SDPA if summed attention weights are required. If
`flex_attention == False`, this kernel is also used if attention
weights are not needed.
fused_rope: If `True`, replace the eager rotary position embedding
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these compiled on the fly? I don't know how Triton works?

And if they speed up things, we should maybe make the default True, no?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first run it is compiled and then you see the speed ups for the remaining steps.

if self._verbose_more:
print("Forward pass to store KV cache checkpoints")
infer_replay_caches = self._create_inference_replay_caches(model_part)
torch.cuda.nvtx.range_push("compute_checkpoints")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this again. Also the other nvtx commands below.

for b in range(shape[0]):
for h in range(shape[1]):
result[b, h, :] = torch.randperm(length, **index_kwargs)[:num]
# Batched random permutation: draw uniform random values of shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch, and nice solution (create random numbers and just sort them).

This should be a function, but I can do such clean-up later.

perms = perms[..., :num]
result = perms.to(dtype=dtype)
else:
# CPU fallback: keep the original loop — randperm on CPU is a single
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is wrong with the code above? I'd do the same on CPU, it is always a better solution

snapshots = self._record_gpu_memory_snapshots
else:
snapshots = None
torch.cuda.nvtx.range_push(f"accumulator_run_{first_layer_idx}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove all nvtx annotations

@@ -0,0 +1,447 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, we need tests to compare these fused operators with the PyTorch code.

I can do this, don't worry. Let us merge it. Unless you have this comparison code already, then let us use it.

Comment thread keys_values/fused_rope.py
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fused Triton kernel for rotary position embedding (RoPE).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here, need comparison tests. I can do this, unless you already have it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I did for the comparison was just run the entire setup with and without it. And compared the resulting timing results.

# TODO: Is this really a problem? We don't modify
# `embeddings`, but just unlink it.
torch.cuda.synchronize()
# No explicit sync needed between layers: the D2H in
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! I was just not sure

Comment thread keys_values/model.py
)
q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)
# Skip the outer cat when rope covers the full head dim (e.g.,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
if _USE_FUSED_ROPE:
from keys_values.fused_rope import fused_apply_rope, can_use_fused_rope
if can_use_fused_rope(x, cos, sin):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this defined? Is this included somewhere?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the fused triton kernel for rope.

Copy link
Copy Markdown
Contributor

@mseeger mseeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving as is. I'll follow up with another PR adding tests

@mseeger mseeger merged commit 3168e3f into awslabs:main May 7, 2026
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants