Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/launch_instance.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* AZ: `us-west-2b`

```bash
ssh -i "matthis_deeplearning_uswest2.pem" ubuntu@ec2-35-85-224-176.us-west-2.compute.amazonaws.com
ssh -i "~/matthis_deeplearning_uswest2.pem" ubuntu@ec2-35-85-224-176.us-west-2.compute.amazonaws.com
```

### Instance `P4Research2`
Expand All @@ -24,7 +24,7 @@ ssh -i "matthis_deeplearning_uswest2.pem" ubuntu@ec2-35-85-224-176.us-west-2.com
* AZ: `us-west-2c`

```bash
ssh -i "matthis_deeplearning_uswest2.pem" ubuntu@ec2-34-209-209-37.us-west-2.compute.amazonaws.com
ssh -i "~/matthis_deeplearning_uswest2.pem" ubuntu@ec2-34-209-209-37.us-west-2.compute.amazonaws.com
```

### Instance `P4Research3`
Expand All @@ -38,7 +38,7 @@ Note: `P4Research2` and `P4Research2` share the same EFS volume.
* AZ: `us-west-2c`

```bash
ssh -i "matthis_deeplearning_uswest2.pem" ubuntu@ec2-16-147-216-186.us-west-2.compute.amazonaws.com
ssh -i "~/matthis_deeplearning_uswest2.pem" ubuntu@ec2-16-147-216-186.us-west-2.compute.amazonaws.com
```


Expand Down
10 changes: 10 additions & 0 deletions keys_values/finetune/longcontext_eval_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,14 @@
print_with_rank_and_timestamp,
adjust_cache_kwargs,
)
from keys_values.fused import (
set_fused_swiglu_enabled,
set_fused_rmsnorm_enabled,
)
from keys_values.head_model_factory import HeadModelFactory
from keys_values.long_context import LongContextInferenceModel
from keys_values.lora import Config as ConfigLoRA
from keys_values.pos_encoding import set_fused_rope_enabled
from keys_values.utils import (
flush_io_streams,
VerbosityLevels,
Expand Down Expand Up @@ -372,6 +377,11 @@ def main(
data_class_path = _data_class_path
data_init_args = _data_init_args

# Enable/disable fused operators
set_fused_rope_enabled(sdpa.fused_rope)
set_fused_rmsnorm_enabled(sdpa.fused_rmsnorm)
set_fused_swiglu_enabled(sdpa.fused_swiglu)

# Create model
is_lora = model_type == "lora"
if torch.cuda.is_available():
Expand Down
22 changes: 12 additions & 10 deletions keys_values/finetune/longcontext_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@
adjust_cache_kwargs,
copy_config_files,
)
from keys_values.fused import (
set_fused_swiglu_enabled,
set_fused_rmsnorm_enabled,
)
from keys_values.generate.base import generate
from keys_values.gpu_memory import RecordGPUMemory
from keys_values.head_model import HeadModel, CrossEntropyOnLogits
Expand Down Expand Up @@ -122,7 +126,10 @@
from keys_values.optimize.grad_accumulate import CPUOffloadAccumulateGradients
from keys_values.optimize.model_factory import BlockComponentName
from keys_values.parser_config import save_hyperparameters
from keys_values.pos_encoding import position_encoding_factory
from keys_values.pos_encoding import (
position_encoding_factory,
set_fused_rope_enabled,
)
from keys_values.tools.size_log import (
SizeWeightsGradientsLog,
SizeLogMapper,
Expand Down Expand Up @@ -689,6 +696,10 @@ def main(
else:
cpu_offload_device = None
optim_device = fabric.device
# Enable/disable fused operators
set_fused_rope_enabled(sdpa.fused_rope)
set_fused_rmsnorm_enabled(sdpa.fused_rmsnorm)
set_fused_swiglu_enabled(sdpa.fused_swiglu)

if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
Expand Down Expand Up @@ -985,15 +996,6 @@ def get_mha_and_cache_kwargs(
init_val=limit_gb,
name="attention_forward_temp_size_gb",
)
from keys_values.pos_encoding import set_fused_rope_enabled

set_fused_rope_enabled(sdpa.fused_rope)
from keys_values.fused_rmsnorm import set_fused_rmsnorm_enabled

set_fused_rmsnorm_enabled(sdpa.fused_rmsnorm)
from keys_values.fused_swiglu import set_fused_swiglu_enabled

set_fused_swiglu_enabled(sdpa.fused_swiglu)
mha_kwargs: Dict[str, Any] = dict(
tmp_array_limit_gb=tmp_array_limit_forward,
pos_encoding=position_encoding_factory(config, do_yarn=yarn_rope),
Expand Down
38 changes: 38 additions & 0 deletions keys_values/fused/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keys_values.fused.fused_rmsnorm import (
can_use_fused_rmsnorm,
fused_rmsnorm,
set_fused_rmsnorm_enabled,
)
from keys_values.fused.fused_rope import (
can_use_fused_rope,
fused_apply_rope,
)
from keys_values.fused.fused_swiglu import (
can_use_fused_swiglu,
fused_swiglu,
set_fused_swiglu_enabled,
)

__all__ = [
"can_use_fused_rmsnorm",
"can_use_fused_rope",
"can_use_fused_swiglu",
"fused_apply_rope",
"fused_rmsnorm",
"fused_swiglu",
"set_fused_rmsnorm_enabled",
"set_fused_swiglu_enabled",
]
95 changes: 48 additions & 47 deletions keys_values/fused_rmsnorm.py → keys_values/fused/fused_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fused Triton kernel for RMSNorm.

Replaces the eager RMSNorm forward (x.float() + x*x + mean + rsqrt + x*rsqrt
+ mul weight + cast = ~5-6 kernels) with a single Triton kernel. Similarly
for backward. Implemented as torch.autograd.Function so the training cell
loop's saved_tensors_hooks see a normal autograd function.

Forward:
y = (x / sqrt(mean(x**2) + eps)) * weight (add_unit_offset=False)
y = (x / sqrt(mean(x**2) + eps)) * (1 + weight) (add_unit_offset=True)

Backward (per row, with w' = weight or 1+weight):
r = rsqrt(mean_sq + eps)
dL/dw = sum_over_batch( dL/dy * x * r ) (in fp32, reduced across all rows)
dL/dx = r * w' * dL/dy - (r**3 / D) * (sum_j(dL/dy_j * w'_j * x_j)) * x
where D is the norm dim size (last dim).
"""

import torch

_triton_available = False
Expand Down Expand Up @@ -227,41 +209,33 @@ def _fused_rmsnorm_bwd_dw_reduce_kernel(
tl.store(GradW_ptr + d_offsets, acc, mask=d_mask)


def can_use_fused_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
dim: int,
) -> bool:
"""Check if `fused_rmsnorm` can handle this input."""
if not _triton_available:
return False
if not x.is_cuda:
return False
if x.dtype not in (torch.float32, torch.float16, torch.bfloat16):
return False
if x.dim() < 2:
return False
# Only support reducing along the last dim (the common case)
if dim != -1 and dim != x.dim() - 1:
return False
D = x.shape[-1]
if weight.numel() != D:
return False
# Triton block size must fit the hidden dim; cap at 16384 to stay within
# shared memory budgets
if D > 16384:
return False
return True


def _next_power_of_two(n: int) -> int:
p = 1
while p < n:
p *= 2
return p


class _FusedRMSNorm(torch.autograd.Function):
class FusedRMSNorm(torch.autograd.Function):
"""Fused Triton kernel for RMSNorm.

Replaces the eager RMSNorm forward (x.float() + x*x + mean + rsqrt + x*rsqrt
+ mul weight + cast = ~5-6 kernels) with a single Triton kernel.
Implemented as torch.autograd.Function so the training cell
loop's saved_tensors_hooks see a normal autograd function.

Forward:
y = (x / sqrt(mean(x**2) + eps)) * weight (add_unit_offset=False)
y = (x / sqrt(mean(x**2) + eps)) * (1 + weight) (add_unit_offset=True)

Backward (per row, with w' = weight or 1+weight):
r = rsqrt(mean_sq + eps)
dL/dw = sum_over_batch( dL/dy * x * r ) (in fp32, reduced across all rows)
dL/dx = r * w' * dL/dy - (r**3 / D) * (sum_j(dL/dy_j * w'_j * x_j)) * x
where D is the norm dim size (last dim).

"""

@staticmethod
def forward(ctx, x, weight, eps, add_unit_offset):
if not can_use_fused_rmsnorm(x, weight, -1):
Expand Down Expand Up @@ -397,6 +371,33 @@ def backward(ctx, grad_out):
return grad_x.view(ctx.original_shape), grad_w, None, None


def can_use_fused_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
dim: int,
) -> bool:
"""Check if `fused_rmsnorm` can handle this input."""
if not _triton_available:
return False
if not x.is_cuda:
return False
if x.dtype not in (torch.float32, torch.float16, torch.bfloat16):
return False
if x.dim() < 2:
return False
# Only support reducing along the last dim (the common case)
if dim != -1 and dim != x.dim() - 1:
return False
D = x.shape[-1]
if weight.numel() != D:
return False
# Triton block size must fit the hidden dim; cap at 16384 to stay within
# shared memory budgets
if D > 16384:
return False
return True


def fused_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
Expand All @@ -414,7 +415,7 @@ def fused_rmsnorm(
Returns:
Normalized tensor, same shape and dtype as x.
"""
return _FusedRMSNorm.apply(x, weight, eps, add_unit_offset)
return FusedRMSNorm.apply(x, weight, eps, add_unit_offset)


# Module-level flag controlling whether RMSNorm classes are patched to use the
Expand Down
93 changes: 47 additions & 46 deletions keys_values/fused_rope.py → keys_values/fused/fused_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fused Triton kernel for rotary position embedding (RoPE).

Replaces the eager apply_rope() sequence (slice + negate + cat + mul + mul +
add + to-dtype = ~5-6 kernels) with a single Triton kernel. Forward and
backward are both implemented as torch.autograd.Function, so the training
cell loop's saved_tensors_hooks see normal autograd functions and continue
to work unchanged.

Forward:
y = x * cos + rot(x) * sin
where rot(x) = cat(-x[..., half:], x[..., :half], dim=-1).

Backward (see pos_encoding.py derivation):
dL/dx_j for j < half = dL/dy_j * cos_j + dL/dy_{j+half} * sin_{j+half}
dL/dx_j for j >= half = dL/dy_j * cos_j - dL/dy_{j-half} * sin_{j-half}
"""

from typing import Tuple

import torch
Expand Down Expand Up @@ -186,33 +169,6 @@ def _fused_rope_bwd_kernel(
tl.store(GradX_ptr + gx_offsets, gx, mask=t_mask[:, None])


def can_use_fused_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> bool:
"""Check if `fused_apply_rope` can handle this input."""
if not _triton_available:
return False
if not x.is_cuda:
return False
if x.dtype not in (torch.float32, torch.float16, torch.bfloat16):
return False
if x.dim() < 2:
return False
D = x.shape[-1]
if D % 2 != 0:
return False
if cos.shape != sin.shape:
return False
if cos.shape[-1] != D:
return False
T = x.shape[-2]
if cos.numel() != T * D:
return False
return True


def _reshape_inputs(
x: torch.Tensor,
cos: torch.Tensor,
Expand All @@ -230,7 +186,25 @@ def _reshape_inputs(
return x_view, cos_view, sin_view, BH, T, D, original_shape


class _FusedRope(torch.autograd.Function):
class FusedRoPE(torch.autograd.Function):
"""Fused Triton kernel for rotary position embedding (RoPE).

Replaces the eager apply_rope() sequence (slice + negate + cat + mul + mul +
add + to-dtype = ~5-6 kernels) with a single Triton kernel. Forward and
backward are both implemented as torch.autograd.Function, so the training
cell loop's saved_tensors_hooks see normal autograd functions and continue
to work unchanged.

Forward:
y = x * cos + rot(x) * sin
where rot(x) = cat(-x[..., half:], x[..., :half], dim=-1).

Backward (see pos_encoding.py derivation):
dL/dx_j for j < half = dL/dy_j * cos_j + dL/dy_{j+half} * sin_{j+half}
dL/dx_j for j >= half = dL/dy_j * cos_j - dL/dy_{j-half} * sin_{j-half}

"""

@staticmethod
def forward(ctx, x, cos, sin):
if not can_use_fused_rope(x, cos, sin):
Expand Down Expand Up @@ -309,6 +283,33 @@ def backward(ctx, grad_out):
return grad_x.view(ctx.original_shape), None, None


def can_use_fused_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> bool:
"""Check if `fused_apply_rope` can handle this input."""
if not _triton_available:
return False
if not x.is_cuda:
return False
if x.dtype not in (torch.float32, torch.float16, torch.bfloat16):
return False
if x.dim() < 2:
return False
D = x.shape[-1]
if D % 2 != 0:
return False
if cos.shape != sin.shape:
return False
if cos.shape[-1] != D:
return False
T = x.shape[-2]
if cos.numel() != T * D:
return False
return True


def fused_apply_rope(
x: torch.Tensor,
cos: torch.Tensor,
Expand All @@ -324,4 +325,4 @@ def fused_apply_rope(
Returns:
RoPE-transformed tensor, same shape and dtype as x.
"""
return _FusedRope.apply(x, cos, sin)
return FusedRoPE.apply(x, cos, sin)
Loading
Loading