Power sgd #196

Open · wants to merge 14 commits into main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.5.1",
"torch==2.6.0",
"numpy",
"setuptools",
"transformers>=4.44.2",
5 changes: 3 additions & 2 deletions src/zeroband/checkpoint.py
@@ -152,6 +152,7 @@ def non_error_barrier():
dist.barrier()
except Exception as e:
from zeroband.utils.logging import get_logger

get_logger().info(f"Error in data checkpointing barrier: {e}, continuing training")


@@ -174,8 +175,8 @@ def __init__(
self,
config: CkptConfig,
model: nn.Module,
optimizer: Optimizer,
scheduler: LambdaLR,
optimizer: list[Optimizer],
scheduler: list[LambdaLR],
dataloader: StatefulDataLoader,
training_progress: TrainingProgress,
data_rank: int | None,
35 changes: 31 additions & 4 deletions src/zeroband/config.py
@@ -22,8 +22,11 @@ class DataConfig(BaseConfig):
reverse_data_files: bool = False
split_by_data_rank: bool = True


class AdamConfig(BaseConfig):
type: Literal["adam"] = "adam" # the literal is used to distinguish between the different optimizers configuration in the union type
type: Literal["adam"] = (
"adam" # the literal is used to distinguish between the different optimizers configuration in the union type
)
lr: float = 4e-4
weight_decay: float = 0.1
betas1: float = 0.9
@@ -41,11 +44,34 @@ class SoapConfig(BaseConfig):
precondition_frequency: int = 100


OptimizersConfig: TypeAlias = AdamConfig | SoapConfig
class MuonConfig(BaseConfig):
type: Literal["muon"] = "muon"
ns_steps: int = 5
lr: float = 0.02
momentum: float = 0.95
nesterov: bool = True
compression_ratio: float | None = None
compression_step_start: int = 0
lie_compression: bool = False

@model_validator(mode="after")
def validate_compression(self):
if self.compression_ratio is not None:
assert 0 < self.compression_ratio <= 1, "compression_ratio must be between 0 and 1"
return self


OptimizersConfig: TypeAlias = AdamConfig | SoapConfig | MuonConfig


class PowerSGDConfig(BaseConfig):
rank: int = 1
warmup_steps: int = 1000


class OptimConfig(BaseConfig):
optim: OptimizersConfig = AdamConfig()
power_sgd: PowerSGDConfig | None = None

lr: float = 4e-4
weight_decay: float = 0.1
@@ -70,6 +96,7 @@ class DilocoConfig(BaseConfig):

retry_all_reduce: int = 3


class MemoryProfilerConfig(BaseConfig):
freq: int = 10
snapshot_dir: str
@@ -231,7 +258,8 @@ def get_env_config(config: Config | None, item: str | None, default: Any | None

return cfg

def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool:

def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool:
"""
Call get_env_config and convert strings to bools where makes sense.

@@ -248,4 +276,3 @@ def get_env_config_bool(config: Config | None, item: str | None, default: bool
if isinstance(val, str):
return val.lower() == "true" or val.lower() == "1"
return bool(val)
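
For orientation, a minimal sketch (not part of the PR) of how the new optimizer options compose, assuming the classes above are importable from zeroband.config and that the remaining OptimConfig fields keep their defaults; the specific values are illustrative only:

from zeroband.config import MuonConfig, OptimConfig, PowerSGDConfig

optim_cfg = OptimConfig(
    optim=MuonConfig(compression_ratio=0.25, compression_step_start=100),
    power_sgd=PowerSGDConfig(rank=2, warmup_steps=500),
)
print(optim_cfg.optim.type)  # -> "muon"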

3 changes: 2 additions & 1 deletion src/zeroband/models/llama/model.py
@@ -23,7 +23,8 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
from torch.nn.attention import SDPBackend, sdpa_kernel

_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
_flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
# _flex_attention_compiled = flex_attention


# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
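For context on the dynamic=True switch above, a toy sketch (not from the PR, using a plain function rather than flex_attention): with dynamic shape tracing, a single compiled artifact can serve multiple sequence lengths instead of recompiling per shape.

import torch

def attention_like(x):
    # stand-in for an attention-shaped computation
    return torch.softmax(x @ x.transpose(-2, -1), dim=-1) @ x

compiled = torch.compile(attention_like, dynamic=True)
for seq_len in (128, 256, 512):  # different lengths reuse the same compiled graph
    out = compiled(torch.randn(2, seq_len, 64))
    print(seq_len, tuple(out.shape))
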
273 changes: 273 additions & 0 deletions src/zeroband/muon.py
@@ -0,0 +1,273 @@
# copied from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
import torch
from torch import Tensor
import torch.distributed as dist


@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X


@torch.compile
def low_rank_approximation_zeropower_via_newtonschulz5(G: Tensor, rank: int, steps: int = 5) -> Tensor:
"""
Compute a low-rank approximation of matrix G using Newton-Schulz iteration.
Returns the expanded approximated matrix directly.

Args:
G: Input tensor of shape (..., m, n)
rank: Target rank for the approximation
steps: Number of Newton-Schulz iterations

Returns:
G_approx: Low rank approximation of G with the same shape as G
"""
assert G.ndim >= 2
assert rank > 0 and rank <= min(G.size(-2), G.size(-1))

# Constants for quintic iteration
a, b, c = (3.4445, -4.7750, 2.0315)

# Convert to bfloat16
G = G.bfloat16()

# Initialize random projection matrix Q in bfloat16
n = G.size(-1)
Q = torch.randn((*G.shape[:-2], n, rank), device=G.device).bfloat16()
Q = Q / (Q.norm(dim=(-2, -1), keepdim=True) + 1e-7)

# Power iteration to find approximate range
Y = G @ Q

# Normalize Y
Y = Y / (Y.norm(dim=(-2, -1), keepdim=True) + 1e-7)

# Newton-Schulz iterations for orthogonalization
for _ in range(steps):
A = Y @ Y.mT
B = b * A + c * A @ A
Y = a * Y + B @ Y

# Compute factors and immediately expand
U = Y
V = G.mT @ U
G_approx = U @ V.mT

return G_approx


@torch.compile
def low_rank_approximation_via_newtonschulz_lie(G: Tensor, rank: int, steps: int = 5) -> Tensor:
"""
Compute a low-rank approximation of matrix G using Newton-Schulz iteration with Lie group structure.
Returns the approximated matrix in the form Q = (I + UVᵀ)diag(d).

Args:
G: Input tensor of shape (..., m, n)
rank: Target rank for the approximation
steps: Number of Newton-Schulz iterations

Returns:
G_approx: Low rank approximation of G with the same shape as G
"""
assert G.ndim >= 2
assert rank > 0 and rank <= min(G.size(-2), G.size(-1))

# Constants for quintic iteration
a, b, c = (3.4445, -4.7750, 2.0315)

# Convert to bfloat16
G = G.bfloat16()
m, n = G.size(-2), G.size(-1)

# Initialize random projection matrix Q
Q = torch.randn((*G.shape[:-2], n, rank), device=G.device).bfloat16()
Q = Q / (Q.norm(dim=(-2, -1), keepdim=True) + 1e-7)

# Power iteration to find approximate range
Y = G @ Q
Y = Y / (Y.norm(dim=(-2, -1), keepdim=True) + 1e-7)

# Newton-Schulz iterations for orthogonalization
for _ in range(steps):
A = Y @ Y.mT
B = b * A + c * A @ A
Y = a * Y + B @ Y

# Compute U factor
U = Y

# Compute V factor through projection
V = G.mT @ U

# Normalize U and V to have unit norm
U_norms = torch.sum(U * U, dim=-1, keepdim=True).sqrt()
V_norms = torch.sum(V * V, dim=-1, keepdim=True).sqrt()

U = U / (U_norms + 1e-7)
V = V / (V_norms + 1e-7)

# Create identity matrix of appropriate size
Id = torch.eye(m, device=G.device, dtype=G.dtype)
Id = Id.expand(*G.shape[:-2], m, m)

# Compute diagonal scaling
d = torch.diagonal(G @ G.mT, dim1=-2, dim2=-1).sqrt()
d = d / (d.norm(dim=-1, keepdim=True) + 1e-7)

# Construct the low-rank product U @ Vᵀ (Id and d computed above are not applied here)
G_approx = U @ V.mT

# Scale the approximation to match G's magnitude
G_norms = torch.sum(G * G, dim=(-2, -1), keepdim=True).sqrt()
G_approx_norms = torch.sum(G_approx * G_approx, dim=(-2, -1), keepdim=True).sqrt()
scale = G_norms / (G_approx_norms + 1e-7)
G_approx = G_approx * scale

return G_approx


class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.

Some warnings:
- This optimizer assumes that all parameters passed in are 2D.
- It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
- We believe it is unlikely to work well for training with small batch size.
- We believe it may not work well for finetuning pretrained models, but we haven't tested this.
- We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""

def __init__(
self,
params,
lr=0.02,
momentum=0.95,
nesterov=True,
ns_steps=5,
rank=0,
world_size=1,
compression_ratio: float | None = None,
compression_step_start: int = 0,
lie_compression: bool = False,
):
self.rank = rank
self.world_size = world_size
self.compression_ratio = compression_ratio
self.compression_step_start = compression_step_start
self.lie_compression = lie_compression
self._step_count = 0 # Add step counter

defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
assert all(isinstance(p, Tensor) for p in params)
sizes = {p.numel() for p in params}

def create_update_buffer(size: int):
b = torch.empty(self.world_size, size, dtype=torch.bfloat16, device="cuda")
return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(self.world_size)])

param_groups = [
dict(params=[p for p in params if p.numel() == size], **create_update_buffer(size)) for size in sizes
]
super().__init__(param_groups, defaults)

@torch.no_grad()
def step(self):
self._step_count += 1 # Increment step counter

for group in self.param_groups:
lr = group["lr"]
momentum = group["momentum"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
update_buffer = group["update_buffer"]
update_buffer_views: list[Tensor] = group["update_buffer_views"]
params: list[Tensor] = group["params"]
handle = None
params_world = None

def update_prev():
if params_world is None:
return
assert handle is not None
handle.wait()
for p_world, g_world in zip(params_world, update_buffer_views):
p_world.add_(
g_world.view_as(p_world),
alpha=-lr * max(1, p_world.size(-2) / p_world.size(-1)) ** 0.5,
)
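# Work is dealt out round-robin across ranks: each rank orthogonalizes every
# world_size-th parameter, and the per-rank results are reassembled via the
# asynchronous all_gather below.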

for base_i in range(len(params))[:: self.world_size]:
if base_i + self.rank < len(params):
p = params[base_i + self.rank]
g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - momentum)
g = g.lerp_(buf, momentum) if nesterov else buf

# Only apply compression if we've reached the start step and compression ratio is set
if self.compression_ratio is not None and self._step_count >= self.compression_step_start:
mat_rank = int(g.shape[0] * self.compression_ratio)
if self.lie_compression:
g = low_rank_approximation_via_newtonschulz_lie(g, mat_rank, steps=ns_steps).flatten()
else:
g = low_rank_approximation_zeropower_via_newtonschulz5(
g, mat_rank, steps=ns_steps
).flatten()
else:
g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
else:
g = update_buffer_views[self.rank]
update_prev()
handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
params_world = params[base_i : base_i + self.world_size]
update_prev()
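
For orientation, a minimal single-process sketch (not part of the PR) of driving this optimizer; it assumes one CUDA device, an NCCL process group, and an illustrative layer size, with the compression settings mirroring the new MuonConfig fields:

import os
import torch
import torch.distributed as dist
from zeroband.muon import Muon

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

model = torch.nn.Linear(512, 512, bias=False).cuda()  # Muon expects 2D parameters only
opt = Muon(
    model.parameters(),
    lr=0.02,
    momentum=0.95,
    rank=0,
    world_size=1,
    compression_ratio=0.25,       # keep roughly a quarter of each update's rank
    compression_step_start=0,     # apply low-rank compression from the first step
)

loss = model(torch.randn(32, 512, device="cuda")).square().mean()
loss.backward()
opt.step()
dist.destroy_process_group()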