Does APOLLO work with Qwen? #12

@pengzhangzhi

Description

Hi, I tried applying APOLLO to Qwen3 finetuning, and the efficiency, throughput, and GPU memory usage don't improve over Adam. I wonder if you have any insights? I debugged it and there doesn't seem to be any misuse:


```python
from typing import List, Sequence, Tuple

import torch.nn as nn


def build_apollo_param_groups(
    model: nn.Module,
    *,
    # selection
    target_modules: Sequence[str] = ("attn", "mlp"),
    # APOLLO hyperparameters (explicit, no args object)
    rank: int = 256,                 # --rank
    update_proj_gap: int = 200,      # --update_proj_gap
    apollo_scale: float = 1.0,       # --apollo_scale
    proj_type: str = "std",          # --proj_type
    proj: str = "random",            # --proj
    scale_type: str = "channel",     # --scale_type
) -> Tuple[List[dict], List[nn.Parameter], List[nn.Parameter]]:
    """
    Build APOLLO param groups:
      - Low-rank group: weights of nn.Linear whose module name contains any token in `target_modules`
      - Regular group: all remaining parameters

    Returns:
      (param_groups, lowrank_params, regular_params)
    """
    lowrank_params: List[nn.Parameter] = []

    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(tk in module_name for tk in target_modules):
            continue
        # only the weight goes to low-rank group (mirrors your training script)
        lowrank_params.append(module.weight)

    id_low = {id(p) for p in lowrank_params}
    regular_params: List[nn.Parameter] = [p for p in model.parameters() if id(p) not in id_low]

    param_groups = [
        {"params": regular_params},
        {
            "params": lowrank_params,
            "rank": rank,
            "update_proj_gap": update_proj_gap,
            "scale": apollo_scale,
            "proj_type": proj_type,
            "proj": proj,
            "scale_type": scale_type,
        },
    ]
    # print the number of parameters in the low-rank group and regular group
    print(f"Number of parameters in the low-rank group: {sum(p.numel() for p in lowrank_params if p.requires_grad)}")
    print(f"Number of parameters in the regular group: {sum(p.numel() for p in regular_params if p.requires_grad)}")
    return param_groups, lowrank_params, regular_params
```
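
These groups then go straight into the optimizer. For reference, a minimal sketch of my APOLLO branch, assuming the `APOLLOAdamW` optimizer exported by the `apollo-torch` package (it reads the per-group keys set above; the `GaLoreAdamW` branch below is wired up the same way):

```python
from apollo_torch import APOLLOAdamW  # assumption: apollo-torch is installed

# Build the two param groups and hand them to the APOLLO optimizer;
# only the low-rank group carries the APOLLO-specific keys.
param_groups, _, _ = build_apollo_param_groups(
    model,
    rank=256,
    update_proj_gap=200,
    apollo_scale=1.0,
    proj_type="std",
    proj="random",
    scale_type="channel",
    target_modules=("attn", "mlp"),
)
optim = APOLLOAdamW(param_groups, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```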



```python
...
elif optimizer_type == "galore":
    print("Building GaLore param groups...")
    param_groups, _, _ = build_galore_param_groups(
        model,
        rank=128,
        update_proj_gap=50,
        galore_scale=1.0,
        proj_type="std",
        target_modules=("attn", "mlp"),
    )
    optim = GaLoreAdamW(param_groups, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```
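
To quantify the memory claim, I compared peak allocated memory between the Adam and APOLLO runs using standard PyTorch CUDA stats; a minimal sketch (the step loop and `dataloader` are hypothetical stand-ins for my training script):

```python
import torch

# Reset the peak-memory counter, run a few optimizer steps,
# then read back the CUDA high-water mark.
torch.cuda.reset_peak_memory_stats()

for step, batch in enumerate(dataloader):  # hypothetical dataloader
    loss = model(**batch).loss
    loss.backward()
    optim.step()
    optim.zero_grad()
    if step == 10:
        break

print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```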
