[Not for land] GaLore example
[ghstack-poisoned]
awgu committed Jul 29, 2024
1 parent b322446 commit b1225c2
Showing 4 changed files with 304 additions and 11 deletions.
24 changes: 24 additions & 0 deletions torchtitan/config_manager.py
@@ -187,6 +187,30 @@ def __init__(self):
action="store_true",
help="Whether the fused implementation(CUDA only) is used.",
)
self.parser.add_argument(
"--optimizer.galore_rank", type=int, default=128, help="GaLore rank"
)
self.parser.add_argument(
"--optimizer.galore_update_proj_gap",
type=int,
default=200,
help="GaLore update projection gap",
)
self.parser.add_argument(
"--optimizer.galore_scale", type=float, default=1.0, help="GaLore scale"
)
self.parser.add_argument(
"--optimizer.galore_proj_type",
type=str,
default="std",
help="GaLore projection type",
)
self.parser.add_argument(
"--optimizer.galore_in_backward",
default=False,
action="store_true",
help="Whether to apply GaLore in backward"
)

# training configs
self.parser.add_argument(
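
The new flags plug into torchtitan's dotted --optimizer.* namespace. A minimal sketch of reading them back, not part of this diff, assuming JobConfig.parse_args accepts an argv-style list and exposes parsed options as attributes under job_config.optimizer (the same way the existing --optimizer.* flags are consumed in train.py):

from torchtitan.config_manager import JobConfig

job_config = JobConfig()
# Hypothetical invocation; flag names match the add_argument calls above.
job_config.parse_args(
    [
        "--optimizer.name", "GaLoreAdamW",
        "--optimizer.galore_rank", "64",
        "--optimizer.galore_update_proj_gap", "200",
        "--optimizer.galore_in_backward",
    ]
)
print(job_config.optimizer.galore_rank)         # 64
print(job_config.optimizer.galore_in_backward)  # True
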
22 changes: 11 additions & 11 deletions torchtitan/lr_scheduling.py
@@ -33,16 +33,6 @@ def linear_warmup_linear_decay(


def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(
linear_warmup_linear_decay, warmup_steps, decay_steps
)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_scheduler

class SchedulersContainer:
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

@@ -54,5 +44,15 @@ def step(self):
schedulers.step()

return SchedulersContainer(
[_get_lr_scheduler(optimizer) for optimizer in optimizers]
[_get_lr_scheduler(job_config, optimizer) for optimizer in optimizers]
)

def _get_lr_scheduler(job_config: JobConfig, optimizer):
"""Build a linear warmup and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(
linear_warmup_linear_decay, warmup_steps, decay_steps
)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_scheduler
210 changes: 210 additions & 0 deletions torchtitan/optims/galore_adamw.py
@@ -0,0 +1,210 @@
"""
Credit: https://github.com/jiaweizzhao/GaLore/tree/master
(copied over and condensed for convenience)
"""

import math
from typing import Callable, Iterable, Optional, Tuple

import torch
import torch.nn as nn

from torchtitan.logging_utils import logger


class GaLoreAdamW(torch.optim.Optimizer):
def __init__(
self,
params: Iterable[nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)"
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)"
)
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)

@torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad

state = self.state[p]
if "step" not in state:
state["step"] = 0
if "dim" not in group:
group["dim"] = 2

if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])

if "exp_avg" not in state:
state["exp_avg"] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = (
step_size * math.sqrt(bias_correction2) / bias_correction1
)

norm_grad = exp_avg / denom

if "rank" in group:
norm_grad = state["projector"].project_back(norm_grad)

p.add_(norm_grad, alpha=-step_size)

if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

return loss


class GaLoreProjector:
def __init__(
self,
rank: int,
update_proj_gap: int = 200,
scale: float = 1.0,
proj_type: str = "std",
):
self.rank = rank
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type

def project(self, full_rank_grad: torch.Tensor, iter_idx: int):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "right":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "left":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "full":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="full"
)
low_rank_grad = (
torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
@ self.ortho_matrix[1].t()
            )
        else:
            raise ValueError(f"Unknown projection type: {self.proj_type}")

        return low_rank_grad

def project_back(self, low_rank_grad: torch.Tensor):
if self.proj_type == "std":
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "reverse_std":
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "right":
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "left":
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "full":
full_rank_grad = (
torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
            )
        else:
            raise ValueError(f"Unknown projection type: {self.proj_type}")

        return full_rank_grad * self.scale

def get_orthogonal_matrix(self, param: torch.Tensor, rank: int, type: str):
U, s, Vh = torch.linalg.svd(param.detach(), full_matrices=False)
if type == "right":
B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
return B
elif type == "left":
A = U[:, :rank].to(device=param.device, dtype=param.dtype)
return A
elif type == "full":
A = U[:, :rank].to(device=param.device, dtype=param.dtype)
B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
return [A, B]
else:
raise ValueError(f"type should be left, right, or full but got {type}")
59 changes: 59 additions & 0 deletions train.py
@@ -116,6 +116,65 @@ def _build_optimizer(model):
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
elif name == "AdamW":
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
elif name == "GaLoreAdamW":
from torchtitan.optims.galore_adamw import GaLoreAdamW

optimizer_kwargs.pop("fused")
optimizer_kwargs.pop("foreach")
galore_kwargs = {}
galore_kwargs["rank"] = job_config.optimizer.galore_rank
galore_kwargs["update_proj_gap"] = (
job_config.optimizer.galore_update_proj_gap
)
galore_kwargs["scale"] = job_config.optimizer.galore_scale
galore_kwargs["proj_type"] = job_config.optimizer.galore_proj_type
nongalore_params = []
galore_params = []
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
if (
isinstance(module, torch.nn.Linear)
and "weight" in param_name
and (
"attention" in module_name or "feed_forward" in module_name
)
):
galore_params.append(param)
else:
nongalore_params.append(param)
if not job_config.optimizer.galore_in_backward:
param_groups = [
{"params": nongalore_params},
{"params": galore_params, **galore_kwargs},
]
optimizer = GaLoreAdamW(param_groups, **optimizer_kwargs)
return optimizer
else:
from torchtitan.lr_scheduling import _get_lr_scheduler

param_to_optim: Dict[nn.Parameter, torch.optim.Optimizer] = {}
for param in nongalore_params:
param_to_optim[param] = GaLoreAdamW([param], **optimizer_kwargs)
for param in galore_params:
param_group = [{"params": [param], **galore_kwargs}]
param_to_optim[param] = GaLoreAdamW(param_group, **optimizer_kwargs)

            param_to_scheduler: Dict[nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
for param, optim in param_to_optim.items():
param_to_scheduler[param] = _get_lr_scheduler(job_config, optim)

def optimizer_hook(param: torch.nn.Parameter) -> None:
if param.grad is None:
return
optim = param_to_optim[param]
optim.step()
optim.zero_grad()
param_to_scheduler[param].step()

for param in param_to_optim:
param.register_post_accumulate_grad_hook(optimizer_hook)

return GaLoreAdamW([torch.empty(0)], **optimizer_kwargs)
else:
raise NotImplementedError(f"Optimizer {name} not added.")
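
With --optimizer.galore_in_backward, the real updates happen in per-parameter optimizers driven by register_post_accumulate_grad_hook; the GaLoreAdamW([torch.empty(0)], ...) returned at the end is only a placeholder so the rest of the training loop still has an optimizer object to call. A stripped-down sketch of the mechanism (plain AdamW here for brevity, not the commit's exact wiring):

import torch

model = torch.nn.Linear(8, 8)
param_to_optim = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Fires once param.grad has finished accumulating during backward.
    optim = param_to_optim[param]
    optim.step()
    optim.zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # parameters are updated here; no separate optimizer.step() call
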

