Hi, I tried to apply APOLLO to Qwen3 finetuning, but throughput and GPU memory usage came out no better than with plain Adam. I wonder if you have any insights? I debugged it and it doesn't look like a misuse; here is how I build the param groups:
```python
from typing import List, Sequence, Tuple

from torch import nn


def build_apollo_param_groups(
    model: nn.Module,
    *,
    # selection
    target_modules: Sequence[str] = ("attn", "mlp"),
    # APOLLO hyperparameters (explicit, no args object)
    rank: int = 256,              # --rank
    update_proj_gap: int = 200,   # --update_proj_gap
    apollo_scale: float = 1.0,    # --apollo_scale
    proj_type: str = "std",       # --proj_type
    proj: str = "random",         # --proj
    scale_type: str = "channel",  # --scale_type
) -> Tuple[List[dict], List[nn.Parameter], List[nn.Parameter]]:
    """
    Build APOLLO param groups:
      - Low-rank group: weights of nn.Linear whose module name contains any token in `target_modules`
      - Regular group: all remaining parameters
    Returns:
      (param_groups, lowrank_params, regular_params)
    """
    lowrank_params: List[nn.Parameter] = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(tk in module_name for tk in target_modules):
            continue
        # only the weight goes to the low-rank group (mirrors the training script)
        lowrank_params.append(module.weight)

    id_low = {id(p) for p in lowrank_params}
    regular_params: List[nn.Parameter] = [p for p in model.parameters() if id(p) not in id_low]

    param_groups = [
        {"params": regular_params},
        {
            "params": lowrank_params,
            "rank": rank,
            "update_proj_gap": update_proj_gap,
            "scale": apollo_scale,
            "proj_type": proj_type,
            "proj": proj,
            "scale_type": scale_type,
        },
    ]

    # print the number of parameters in the low-rank and regular groups
    print(f"Number of parameters in the low-rank group: {sum(p.numel() for p in lowrank_params if p.requires_grad)}")
    print(f"Number of parameters in the regular group: {sum(p.numel() for p in regular_params if p.requires_grad)}")
    return param_groups, lowrank_params, regular_params
```
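For context, here is a quick way to sanity-check the selection. This is only a minimal sketch; it assumes the usual Hugging Face Qwen layout where attention projections sit under `self_attn` (e.g. `model.layers.0.self_attn.q_proj`) and MLP projections under `mlp` (e.g. `model.layers.0.mlp.gate_proj`), so both match the `("attn", "mlp")` substring filter:

```python
from collections import Counter

from torch import nn


def report_lowrank_selection(model: nn.Module, target_modules=("attn", "mlp")) -> None:
    """Print which nn.Linear weights the substring filter routes to the low-rank group."""
    selected = Counter()
    skipped = Counter()
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        leaf = name.split(".")[-1]  # e.g. "q_proj", "gate_proj", "lm_head"
        # same substring test as build_apollo_param_groups
        if any(tk in name for tk in target_modules):
            selected[leaf] += module.weight.numel()
        else:
            skipped[leaf] += module.weight.numel()
    print("low-rank group (by leaf name):", dict(selected))
    print("regular group, Linear only (by leaf name):", dict(skipped))
```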
```python
...
elif optimizer_type == "galore":
    print("Building GaLore param groups...")
    param_groups, _, _ = build_galore_param_groups(
        model,
        rank=128,
        update_proj_gap=50,
        galore_scale=1.0,
        proj_type="std",
        target_modules=("attn", "mlp"),
    )
    optim = GaLoreAdamW(param_groups, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```
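The APOLLO branch is wired up analogously. A minimal sketch (not the exact script), assuming `APOLLOAdamW` is imported from `apollo_torch` and accepts the same AdamW-style keyword arguments as the `GaLoreAdamW` call above:

```python
# Hypothetical APOLLO branch in the same optimizer dispatch as above.
# Assumes: from apollo_torch import APOLLOAdamW
elif optimizer_type == "apollo":
    print("Building APOLLO param groups...")
    param_groups, _, _ = build_apollo_param_groups(
        model,
        rank=256,
        update_proj_gap=200,
        apollo_scale=1.0,
        proj_type="std",
        proj="random",
        scale_type="channel",
        target_modules=("attn", "mlp"),
    )
    optim = APOLLOAdamW(param_groups, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```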