Merged
23 changes: 23 additions & 0 deletions keys_values/finetune/args.py
@@ -585,6 +585,26 @@ class SDPAArgs:
is available.
dynamo_cache_size_limit: Value for `torch._dynamo.config.cache_size_limit`.
Defaults to 32. The built-in default 8 is too small for our purposes.
fused_rope: If `True`, replace the eager rotary position embedding
Contributor

Are these compiled on the fly? I don't know how Triton works.

And if they speed things up, maybe we should make the default `True`, no?

Collaborator Author

The kernels are compiled on the first run, and you then see the speedups in the remaining steps.

(`apply_rope`) with a single fused Triton kernel. Falls back to
eager automatically when Triton is unavailable or the input shape
is incompatible. Correctness is verified against an fp64
reference; the fused kernel accumulates in fp32 internally and is
typically *more* accurate than eager in bf16/fp16. Measured at
Qwen3-4B on A100-40GB: ~2% end-to-end speedup, val_loss matches
or improves. See `keys_values/fused_rope.py`.
fused_rmsnorm: If `True`, patch both `keys_values.model.RMSNorm` and
`litgpt.model.RMSNorm` so their `forward` dispatches to a fused
Triton kernel. Falls back to the original eager forward when
Triton is unavailable, the tensor is on CPU, or the input shape
is unsupported. Correctness verified against an fp64 reference.
See `keys_values/fused_rmsnorm.py`.
fused_swiglu: If `True`, patch `LLaMAMLP.forward` (both
`keys_values.lora` and `litgpt.model` variants) so the
`F.silu(x_fc_1) * x_fc_2` step runs as a single fused Triton
kernel instead of two eager kernels. Falls back to eager when
inputs are not on CUDA or dtypes mismatch. Correctness verified
against an fp64 reference. See `keys_values/fused_swiglu.py`.
flashinfer_attention: If `True` and FlashInfer is available, we use
FlashInfer SDPA if summed attention weights are required. If
`flex_attention == False`, this kernel is also used if attention
@@ -597,6 +617,9 @@ class SDPAArgs:
reorder_sort_if_3d: bool = True
use_flex_for_attn_weights: bool = True
dynamo_cache_size_limit: int = 32
fused_rope: bool = False
fused_rmsnorm: bool = False
fused_swiglu: bool = False
flashinfer_attention: bool = True

def __post_init__(self):
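
The docstrings above say each fused kernel is checked against an fp64 reference. The sketch below illustrates that kind of check for the `F.silu(x_fc_1) * x_fc_2` step only. It is not the PR's test code: the helper names (`to_fp32`, `swiglu_reference`, `swiglu_fp32`, `max_abs_error`) are hypothetical, and fp32 rounding through `struct` merely stands in for a low-precision kernel so that the example runs without a GPU or Triton.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x), written as x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


def swiglu_reference(a, b):
    """fp64 reference for the elementwise step silu(a) * b."""
    return [silu(x) * y for x, y in zip(a, b)]


def swiglu_fp32(a, b):
    """Same step with every intermediate rounded to fp32.

    Hypothetical stand-in for a lower-precision kernel under test.
    """
    out = []
    for x, y in zip(a, b):
        x32, y32 = to_fp32(x), to_fp32(y)
        denom = to_fp32(1.0 + to_fp32(math.exp(-x32)))
        out.append(to_fp32(to_fp32(x32 / denom) * y32))
    return out


def max_abs_error(candidate, reference):
    """Largest elementwise absolute deviation from the reference."""
    return max(abs(c - r) for c, r in zip(candidate, reference))


a = [-3.0, -0.5, 0.0, 0.7, 2.5]
b = [1.5, -2.0, 4.0, 0.25, -1.0]
ref = swiglu_reference(a, b)
err = max_abs_error(swiglu_fp32(a, b), ref)
```

A real check would presumably run the Triton kernel on CUDA tensors and compare against the eager computation cast to `torch.float64`; the tolerance would depend on the input dtype.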
6 changes: 6 additions & 0 deletions keys_values/finetune/longcontext_full.py
@@ -982,6 +982,12 @@ def get_mha_and_cache_kwargs(
init_val=limit_gb,
name="attention_forward_temp_size_gb",
)
from keys_values.pos_encoding import set_fused_rope_enabled
set_fused_rope_enabled(sdpa.fused_rope)
from keys_values.fused_rmsnorm import set_fused_rmsnorm_enabled
set_fused_rmsnorm_enabled(sdpa.fused_rmsnorm)
from keys_values.fused_swiglu import set_fused_swiglu_enabled
set_fused_swiglu_enabled(sdpa.fused_swiglu)
mha_kwargs: Dict[str, Any] = dict(
tmp_array_limit_gb=tmp_array_limit_forward,
pos_encoding=position_encoding_factory(config, do_yarn=yarn_rope),
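
The hunk above forwards each `SDPAArgs` flag to a module-level setter such as `set_fused_rmsnorm_enabled`. The bodies of those setters are not part of this diff, so the following is only a minimal sketch of how a switch with automatic fallback could look. Everything in it (`set_fused_enabled`, `select_path`, the `on_gpu` argument, the pure-Python `rmsnorm`) is hypothetical, and it uses a flag checked at call time rather than the `forward` patching the docstring describes.

```python
import importlib.util
import math

# Hypothetical module-level switch; off by default, like the new flags.
_FUSED_ENABLED = False


def set_fused_enabled(flag: bool) -> None:
    """Turn the fused path on or off for later calls."""
    global _FUSED_ENABLED
    _FUSED_ENABLED = bool(flag)


def _triton_available() -> bool:
    """True if the `triton` package can be imported in this environment."""
    return importlib.util.find_spec("triton") is not None


def select_path(on_gpu: bool) -> str:
    """Pick "fused" only if enabled, on GPU, and Triton is installed."""
    if _FUSED_ENABLED and on_gpu and _triton_available():
        return "fused"
    return "eager"


def _eager_rmsnorm(x, weight, eps):
    """Plain RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def rmsnorm(x, weight, eps=1e-6, on_gpu=False):
    """Dispatch to the selected path.

    The fused branch is a placeholder: a real implementation would launch
    a Triton kernel here. This sketch reuses the eager code so it runs
    anywhere.
    """
    if select_path(on_gpu) == "fused":
        return _eager_rmsnorm(x, weight, eps)  # placeholder for the kernel
    return _eager_rmsnorm(x, weight, eps)
```

With the switch left at its default, `select_path` always returns `"eager"`, which mirrors the `False` defaults added to `SDPAArgs`.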