-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathexp4021_ablate_shared_expert.py
More file actions
161 lines (131 loc) · 5.15 KB
/
exp4021_ablate_shared_expert.py
File metadata and controls
161 lines (131 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
"""Experiment 4021: ablate shared expert in the good-enough 10T MoE gate.
Run a single ~1e19-FLOP comparison: baseline (with shared expert) vs ablation
(shared_expert_intermediate_dim=0). Both arms use the trial model architecture
and optimizer, differing only in the shared expert. Each arm's step count is
chosen so that 3 * flops_per_token * batch_size * seq_len * steps ≈ 1e19.
"""
import dataclasses
from fray.cluster import ResourceConfig
from levanter.optim import AdamConfig
from levanter.tracker.wandb import WandbConfig
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
from experiments.grug.moe.launch import (
GRUG_MOE_TRIAL_MODEL,
GrugMoeLaunchConfig,
GrugTrainerConfig,
NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
_resolve_run_id,
run_grug_moe,
)
from experiments.grug.moe.model import GrugModelConfig
from experiments.grug.moe.train import GrugEvalConfig
# Total training compute per arm; the comparison is iso-FLOP, not iso-step.
FLOP_BUDGET = 1e19
# Global batch size (sequences per step), shared by both arms.
BATCH_SIZE = 512
# Sequence length is taken from the trial model so token accounting below
# matches what training actually uses.
SEQ_LEN = GRUG_MOE_TRIAL_MODEL.max_seq_len
def _steps_for_budget(flops_per_token: float) -> int:
    """Compute training steps so total FLOPs ≈ FLOP_BUDGET.

    Uses the standard 3x multiplier (forward + backward) on per-token
    forward FLOPs, then converts the resulting token budget into steps
    at BATCH_SIZE * SEQ_LEN tokens per step. Always returns at least 1.
    """
    token_budget = FLOP_BUDGET / (3 * flops_per_token)
    tokens_per_step = BATCH_SIZE * SEQ_LEN
    steps = round(token_budget / tokens_per_step)
    # Guard against a degenerate zero-step schedule for very large models.
    return steps if steps > 1 else 1
def _flops_per_token(model: GrugModelConfig) -> float:
    """Estimate forward-pass FLOPs per token for *model* via levanter's helper."""
    # Imported locally (as in the original) to keep the module import cheap.
    from levanter.utils.flop_utils import lm_flops_per_token

    # A zero shared-expert intermediate dim means the shared expert is absent.
    has_shared_expert = model.shared_expert_intermediate_dim > 0
    return lm_flops_per_token(
        hidden_dim=model.hidden_dim,
        num_layers=model.num_layers,
        num_heads=model.num_heads,
        num_kv_heads=model.num_kv_heads,
        seq_len=model.max_seq_len,
        vocab_size=model.vocab_size,
        glu=True,
        intermediate_dim=model.intermediate_dim,
        shared_intermediate_dim=model.shared_expert_intermediate_dim,
        num_experts=model.num_experts,
        num_experts_per_tok=model.num_experts_per_token,
        num_shared_experts=1 if has_shared_expert else 0,
    )
# ---------------------------------------------------------------------------
# Model configs
# ---------------------------------------------------------------------------
# Baseline arm: the trial model as-is (with its shared expert).
SHARED_MODEL = GRUG_MOE_TRIAL_MODEL  # shared_expert_intermediate_dim=1792
# Ablation arm: identical architecture, shared expert disabled (dim=0).
NO_SHARED_MODEL = dataclasses.replace(GRUG_MOE_TRIAL_MODEL, shared_expert_intermediate_dim=0)
# Per-arm step counts: each arm is sized to the same total FLOP budget, so the
# cheaper no-shared model trains for more steps.
SHARED_STEPS = _steps_for_budget(_flops_per_token(SHARED_MODEL))
NO_SHARED_STEPS = _steps_for_budget(_flops_per_token(NO_SHARED_MODEL))
# ---------------------------------------------------------------------------
# Common training knobs (mirrors trial defaults)
# ---------------------------------------------------------------------------
# Optimizer shared verbatim by both arms so the only difference is the model.
_OPTIMIZER = AdamConfig(
    learning_rate=3e-3,
    weight_decay=0.1,
    lr_schedule="cosine",
    # NOTE(review): presumably the fraction of training spent decaying —
    # confirm against levanter's AdamConfig.
    decay=0.2,
    min_lr_ratio=0.1,  # floor LR at 10% of peak
    warmup=1000,  # presumably warmup steps (not a fraction) — TODO confirm
)
_GRUG_TRAINER = GrugTrainerConfig(
    z_loss_weight=1e-4,
    ema_beta=None,  # no EMA of weights (eval_ema below is also off)
    log_every=1,
)
_EVAL = GrugEvalConfig(
    eval_batch_size=512,
    steps_per_eval=1000,
    max_eval_batches=8,
    eval_current=True,  # evaluate the live weights...
    eval_ema=False,  # ...and skip EMA eval, consistent with ema_beta=None
)
def _wandb(group: str, run_id: str) -> WandbConfig:
    """Build the W&B tracker config for one arm of the ablation.

    Args:
        group: W&B group shared by both arms so they plot side by side.
        run_id: resolved run id for this arm, used as the W&B run name.

    Returns:
        A WandbConfig pointing at the "marin" project with experiment tags.
    """
    return WandbConfig(
        project="marin",
        tags=["grug", "moe", "exp4021", "shared-expert-ablation"],
        group=group,
        # Fix: the original passed name=None, silently ignoring run_id; both
        # call sites supply a distinct id, so use it as the run name.
        name=run_id,
    )
# ---------------------------------------------------------------------------
# Executor steps
# ---------------------------------------------------------------------------
# Resolve run ids once so the launch config and tracker see the same value
# (presumably _resolve_run_id makes these stable/unique — TODO confirm).
_SHARED_RUN_ID = _resolve_run_id("exp4021-shared")
_NO_SHARED_RUN_ID = _resolve_run_id("exp4021-no-shared")
# Baseline arm: trial model unchanged, shared expert enabled.
shared_expert_baseline = ExecutorStep(
    name="grug/exp4021-shared",
    fn=run_grug_moe,
    config=GrugMoeLaunchConfig(
        # versioned(...) presumably folds the value into the step's version /
        # cache key (marin executor convention) — TODO confirm.
        model=versioned(SHARED_MODEL),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        output_path=this_output_path(),
        run_id=_SHARED_RUN_ID,
        resources=versioned(ResourceConfig.with_tpu("v5p-8")),
        # Step count sized so this arm consumes ~FLOP_BUDGET total FLOPs.
        steps=versioned(SHARED_STEPS),
        batch_size=versioned(BATCH_SIZE),
        seed=versioned(0),  # fixed seed so the arms differ only in architecture
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=_wandb("exp4021-ablate-shared-expert", _SHARED_RUN_ID),
        optimizer=versioned(_OPTIMIZER),
        grug_trainer=versioned(_GRUG_TRAINER),
        eval=versioned(_EVAL),
    ),
)
# Ablation arm: same recipe but shared_expert_intermediate_dim=0, with its own
# step count so total FLOPs still ≈ FLOP_BUDGET despite the cheaper model.
no_shared_expert_ablation = ExecutorStep(
    name="grug/exp4021-no-shared",
    fn=run_grug_moe,
    config=GrugMoeLaunchConfig(
        # versioned(...) presumably folds the value into the step's version /
        # cache key (marin executor convention) — TODO confirm.
        model=versioned(NO_SHARED_MODEL),
        data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION,
        output_path=this_output_path(),
        run_id=_NO_SHARED_RUN_ID,
        resources=versioned(ResourceConfig.with_tpu("v5p-8")),
        steps=versioned(NO_SHARED_STEPS),
        batch_size=versioned(BATCH_SIZE),
        seed=versioned(0),  # fixed seed so the arms differ only in architecture
        mp=versioned("params=float32,compute=bfloat16,output=bfloat16"),
        tracker=_wandb("exp4021-ablate-shared-expert", _NO_SHARED_RUN_ID),
        optimizer=versioned(_OPTIMIZER),
        grug_trainer=versioned(_GRUG_TRAINER),
        eval=versioned(_EVAL),
    ),
)
# Script entry point: submit both arms to the marin executor.
if __name__ == "__main__":
    executor_main(
        steps=[shared_expert_baseline, no_shared_expert_ablation],
        description="Exp 4021: ablate shared expert at ~1e19 FLOPs. Fixes #4021.",
    )