# train.py (574 lines, 494 loc, 21.5 KB)
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import dataclasses
import functools
import logging
import time
from dataclasses import dataclass, field
import jax
import jax.numpy as jnp
import jmp
import optax
from fray.cluster import ResourceConfig
from haliax import Axis
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.tree_util import register_dataclass
from jaxtyping import PRNGKeyArray
import levanter.callbacks as callbacks
import levanter.tracker
from levanter.callbacks.state_adapter import StateCallbackRunner
from levanter.callbacks.watch import WatchConfig, compute_watch_stats
from levanter.data import AsyncDataset, DataLoader
from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule
from levanter.data.text import GrugLmExample, LmDataConfig
from levanter.data.text.examples import grug_lm_example_from_named
from levanter.eval import TaggedEvaluator, cb_tagged_evaluate
from levanter.models.lm_model import LmExample
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.schedule import BatchSchedule
from levanter.trainer import TrainerConfig
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.jax_utils import parameter_count
from levanter.utils.logging import LoadingTimeTrackerIterator
import equinox as eqx
from experiments.grug.checkpointing import restore_grug_state_from_checkpoint
from experiments.grug.dispatch import dispatch_grug_training_run
from experiments.grug.moe.model import GrugModelConfig, Transformer
# This file intentionally mirrors `experiments/grug/base/train.py` with
# variant-specific model/loss/FLOP wiring, per the grug copy-first workflow in
# `.agents/skills/change-grug/`.
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class GrugTrainerConfig:
    """Runtime knobs for grug training.

    Wraps the underlying ``TrainerConfig`` and adds the grug-specific settings
    consumed by the training loop in this module: batch sharding, data seeding,
    logging cadence, EMA and z-loss.
    """

    # Underlying trainer configuration; the default turns on explicit mesh axes.
    trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(use_explicit_mesh_axes=True))
    # Partition spec for training batches; its first entry is what the train
    # loader maps the batch axis onto.
    train_batch_pspec: P = field(default_factory=lambda: P(("data", "expert")))
    # When set, the data PRNG key is derived from this seed instead of the trainer seed.
    data_seed: int | None = None
    # Step interval for the performance / progress-bar / step-info hooks
    # (values below 1 are clamped to 1 by the training loop).
    log_every: int = 1
    # EMA coefficient for eval/checkpoint model; None disables EMA.
    ema_beta: float | None = None
    # Weight on logsumexp (z-loss) stabilization term.
    z_loss_weight: float = 0.0
@dataclass(frozen=True)
class GrugEvalConfig:
    """Perplexity eval settings for grug training.

    Read when building the tagged evaluator and when registering the periodic
    evaluation hook in the training loop.
    """

    # Size of the "batch" axis used for evaluation batches.
    eval_batch_size: int = 512
    # Partition spec for eval batches; its first entry is the batch-axis resource.
    eval_batch_pspec: P = field(default_factory=lambda: P(("data", "expert")))
    # Evaluate every this many steps; None or a non-positive value disables the hook.
    steps_per_eval: int | None = 1000
    # Optional cap on batches per dataset (converted to an example cap via eval_batch_size).
    max_eval_batches: int | None = None
    # Prefix handed to the tagged-evaluate callback.
    prefix: str = "eval"
    # Whether to evaluate the current (non-EMA) parameters.
    eval_current: bool = True
    # Whether to evaluate the EMA parameters; only effective when EMA is enabled.
    eval_ema: bool = True
    # When True the data config's tokenizer is passed to the evaluator.
    compute_bpb: bool = True
@dataclass(frozen=True)
class GrugRunConfig:
    """Top-level config for grug training.

    The first three fields are required; the rest fall back to their default
    factories.
    """

    # Model architecture configuration.
    model: GrugModelConfig
    # Training / evaluation data configuration.
    data: LmDataConfig
    # Resources forwarded to the job dispatcher.
    resources: ResourceConfig
    # Optimizer configuration; built with the trainer's number of train steps.
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)
    # Grug-specific trainer settings.
    trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig)
    # Evaluation settings; None means no evaluator is built.
    eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig)
def build_train_dataset(
    data_config: LmDataConfig,
    *,
    max_seq_len: int,
    batch_schedule: BatchSchedule,
    key: PRNGKeyArray,
) -> MixtureDataset[GrugLmExample]:
    """Assemble the training mixture dataset described by ``data_config``.

    Args:
        data_config: Source of the component train sets, mixture weights,
            stop strategy and mixture block size.
        max_seq_len: Size of the ``"position"`` axis handed to the train sets.
        batch_schedule: Used to rescale list-valued weight schedules and to
            obtain the batch size at step 0.
        key: PRNG key; split into one key for the mixture and one for the
            component datasets.

    Returns:
        A ``MixtureDataset`` over the configured component datasets.
    """
    mix_key, shuffle_key = jax.random.split(key)

    mixture_weights = data_config.train_weights
    # Only list-valued weights (a schedule) are rescaled against the batch schedule.
    if isinstance(mixture_weights, list):
        mixture_weights = rescale_mixture_schedule_for_batch_schedule(mixture_weights, batch_schedule)

    component_datasets = data_config.train_sets(
        Axis("position", max_seq_len),
        key=shuffle_key,
        initial_batch_size=batch_schedule.batch_size_at_step(0),
    )

    return MixtureDataset(
        datasets=component_datasets,
        weights=mixture_weights,
        stop_strategy=data_config.stop_strategy,
        key=mix_key,
        block_size=data_config.mixture_block_size,
    )
def build_train_loader(
    dataset: AsyncDataset[GrugLmExample],
    *,
    batch_schedule: BatchSchedule,
    mesh: Mesh,
    batch_pspec: P = P(("data", "expert")),
) -> DataLoader[GrugLmExample]:
    """Wrap ``dataset`` in a ``DataLoader`` whose batch axis is sharded over ``mesh``.

    Args:
        dataset: Dataset to draw training examples from.
        batch_schedule: Its ``schedule`` attribute is passed to the loader.
        mesh: Device mesh handed to the loader.
        batch_pspec: Partition spec for batches; only its first entry is used.

    Returns:
        A ``DataLoader`` that rejects non-divisible batch sizes.
    """
    batch_axis = "__BATCH__"
    # DataLoader uses this batch axis mapping to shard batches across the distributed mesh.
    axis_resources = {batch_axis: batch_pspec[0]}
    return DataLoader(
        dataset,
        batch_schedule.schedule,
        mesh=mesh,
        axis_resources=axis_resources,
        batch_axis_name=batch_axis,
        allow_nondivisible_batch_size=False,
    )
def build_tagged_evaluator(
    *,
    data_config: LmDataConfig,
    max_seq_len: int,
    mesh: Mesh,
    eval_cfg: GrugEvalConfig,
) -> TaggedEvaluator[LmExample | GrugLmExample, Transformer] | None:
    """Build the tagged evaluator for the configured eval sets.

    Args:
        data_config: Provides the tagged eval sets and, optionally, the tokenizer.
        max_seq_len: Size of the ``"position"`` axis handed to the eval sets.
        mesh: Device mesh used for the evaluator and for resharding eval outputs.
        eval_cfg: Evaluation settings (batch size, sharding, caps, bpb).

    Returns:
        A ``TaggedEvaluator``, or ``None`` (after logging a warning) when the
        data config yields no tagged eval sets.
    """
    tagged_eval_sets = data_config.tagged_eval_sets(Axis("position", max_seq_len))
    if len(tagged_eval_sets) == 0:
        logger.warning("No evaluation datasets provided.")
        return None

    # Translate the per-dataset batch cap into an example cap, if one is configured.
    if eval_cfg.max_eval_batches is None:
        example_cap = None
    else:
        example_cap = eval_cfg.max_eval_batches * eval_cfg.eval_batch_size

    # The tokenizer is only supplied when bits-per-byte computation is requested.
    tokenizer = data_config.the_tokenizer if eval_cfg.compute_bpb else None

    batch_resource = eval_cfg.eval_batch_pspec[0]
    # Eval outputs are resharded so only their leading axis is partitioned.
    per_position_sharding = NamedSharding(mesh, P(batch_resource, None))

    def eval_loss_fn(model: Transformer, batch: LmExample | GrugLmExample) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Return unreduced loss, loss weights, and tokens rolled left by one."""
        example = grug_lm_example_from_named(batch) if isinstance(batch, LmExample) else batch
        losses = model.next_token_loss(
            example.tokens,
            example.loss_weight,
            mask=example.attn_mask,
            reduction="none",
            logsumexp_weight=None,
        )
        return (
            jax.sharding.reshard(losses, per_position_sharding),
            jax.sharding.reshard(example.loss_weight, per_position_sharding),
            jnp.roll(example.tokens, -1, axis=-1),
        )

    return TaggedEvaluator(
        EvalBatch=Axis("batch", eval_cfg.eval_batch_size),
        tagged_eval_sets=tagged_eval_sets,
        loss_fn=eval_loss_fn,
        tokenizer=tokenizer,
        device_mesh=mesh,
        axis_mapping={"batch": batch_resource},
        max_examples_per_dataset=example_cap,
    )
def _compute_flops(
    *,
    model_config: GrugModelConfig,
) -> tuple[float, dict[str, float]]:
    """Compute analytic FLOP estimates for the configured model.

    Args:
        model_config: Model configuration supplying the dimensions, expert
            counts, sequence length and vocabulary size.

    Returns:
        ``(flops_per_example, summary)`` where ``flops_per_example`` is
        ``3 * flops_per_token * max_seq_len`` and ``summary`` maps the
        ``throughput/...`` keys to the per-token and per-example values.
    """
    # A shared expert is counted only when it has a positive intermediate dimension.
    shared_expert_count = 1 if model_config.shared_expert_intermediate_dim > 0 else 0
    per_token = lm_flops_per_token(
        hidden_dim=model_config.hidden_dim,
        intermediate_dim=model_config.intermediate_dim,
        shared_intermediate_dim=model_config.shared_expert_intermediate_dim,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        num_heads=model_config.num_heads,
        seq_len=model_config.max_seq_len,
        vocab_size=model_config.vocab_size,
        glu=True,
        num_experts=model_config.num_experts,
        num_shared_experts=shared_expert_count,
        num_experts_per_tok=model_config.num_experts_per_token,
    )
    # NOTE(review): the factor 3 presumably accounts for forward + backward passes — confirm.
    per_example = 3 * per_token * model_config.max_seq_len
    summary: dict[str, float] = {
        "throughput/flops_per_token_analytic": per_token,
        "throughput/flops_per_example_analytic": per_example,
    }
    return per_example, summary
def _make_mixture_stage_callback(train_dataset: MixtureDataset, batch_schedule: BatchSchedule):
    """Create a step hook that logs mixture weights whenever the mixture stage changes.

    Args:
        train_dataset: Mixture dataset providing ``block_size``, ``weight_stages``
            and the block-to-stage lookup.
        batch_schedule: Maps a step to its global data offset.

    Returns:
        A callable taking a step-info object (with a ``step`` attribute). It
        logs ``mixture/weight/<name>`` and ``mixture/stage`` to the tracker the
        first time each stage is observed and does nothing otherwise.
    """
    previously_logged_stage = -1

    def log_mixture_stage(step_info):
        nonlocal previously_logged_stage
        data_offset = batch_schedule.global_data_offset_by_step(step_info.step)
        current_stage = train_dataset._get_stage_for_block(data_offset // train_dataset.block_size)
        if current_stage != previously_logged_stage:
            stage_weights = train_dataset.weight_stages[current_stage][1]
            payload = {f"mixture/weight/{name}": weight for name, weight in stage_weights.items()}
            payload["mixture/stage"] = current_stage
            levanter.tracker.log(payload, step=step_info.step)
            previously_logged_stage = current_stage

    return log_mixture_stage
@register_dataclass
@dataclass(frozen=True)
class GrugTrainState:
step: jax.Array
params: Transformer
opt_state: optax.OptState
ema_params: Transformer | None
def _apply_qb_betas(model: Transformer, qb_betas: jax.Array) -> Transformer:
    """Set router biases from QB betas (computed on previous step, applied on host).

    For every block whose ``mlp`` is not ``None``, the router bias is replaced
    by the negated beta row for that block, shifted to have zero mean. Blocks
    without an ``mlp`` are left untouched and do not consume a beta row.

    Args:
        model: Model whose blocks are updated; it is not mutated.
        qb_betas: Indexed along its first axis by MoE-block order.

    Returns:
        A copy of ``model`` with the updated blocks.
    """
    updated_blocks = list(model.blocks)
    moe_positions = (idx for idx, blk in enumerate(model.blocks) if blk.mlp is not None)
    # ``beta_row`` counts MoE blocks only, so it indexes qb_betas densely.
    for beta_row, block_idx in enumerate(moe_positions):
        block = updated_blocks[block_idx]
        bias = -qb_betas[beta_row]
        bias = bias - jnp.mean(bias)
        mlp = eqx.tree_at(lambda m: m.router_bias, block.mlp, bias)
        updated_blocks[block_idx] = eqx.tree_at(lambda b: b.mlp, block, mlp)
    return eqx.tree_at(lambda t: t.blocks, model, tuple(updated_blocks))
def initial_state(
    model_config: GrugModelConfig,
    *,
    optimizer: optax.GradientTransformation,
    mp: jmp.Policy,
    key: PRNGKeyArray,
    ema_beta: float | None,
) -> GrugTrainState:
    """Create the step-0 training state.

    Args:
        model_config: Configuration used to initialise the ``Transformer``.
        optimizer: Optimizer whose ``init`` produces the optimizer state.
        mp: Mixed-precision policy; parameters are cast with ``cast_to_param``.
        key: PRNG key for model initialisation.
        ema_beta: When not ``None``, the EMA parameters start as the initial
            parameters; otherwise EMA parameters are ``None``.

    Returns:
        A ``GrugTrainState`` with an int32 step of 0.
    """
    model_params = mp.cast_to_param(Transformer.init(model_config, key=key))
    ema_init = None if ema_beta is None else model_params
    return GrugTrainState(
        step=jnp.array(0, dtype=jnp.int32),
        params=model_params,
        opt_state=optimizer.init(model_params),
        ema_params=ema_init,
    )
def _make_train_step(
    optimizer: optax.GradientTransformation,
    mp: jmp.Policy,
    *,
    z_loss_weight: float,
    ema_beta: float | None,
    watch_config: WatchConfig | None = None,
):
    """Build the jitted train-step function.

    Args:
        optimizer: Gradient transformation used to compute parameter updates.
        mp: Mixed-precision policy; parameters are cast with ``cast_to_compute``
            before the loss is evaluated.
        z_loss_weight: Weight of the logsumexp term; a non-positive value
            disables it (``None`` is passed to the loss).
        ema_beta: EMA coefficient, or ``None`` to disable EMA tracking.
        watch_config: Optional watch settings; ``None`` disables watch stats.

    Returns:
        ``train_step(state, batch, *, compute_watch=False)``, jitted with
        ``state`` donated and ``compute_watch`` static. It returns
        ``(next_state, metrics, watch_stats)``; ``watch_stats`` is ``None``
        unless watching is configured and ``compute_watch`` is true.

    Raises:
        ValueError: From ``train_step`` when ``ema_beta`` is set but the
            incoming state has no EMA parameters.
    """
    step_increment = jnp.array(1, dtype=jnp.int32)
    logsumexp_weight = z_loss_weight if z_loss_weight > 0 else None

    # Normalise watch targets to a tuple; a string is treated as comma-separated.
    if watch_config is None:
        watch_targets = ()
    elif isinstance(watch_config.watch_targets, str):
        watch_targets = tuple(t.strip() for t in watch_config.watch_targets.split(","))
    else:
        watch_targets = tuple(watch_config.watch_targets)

    @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",))
    def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False):
        def loss_fn(params):
            return mp.cast_to_compute(params).next_token_loss(
                batch.tokens,
                batch.loss_weight,
                mask=batch.attn_mask,
                reduction="mean",
                logsumexp_weight=logsumexp_weight,
                return_router_metrics=True,
            )

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, summarized_metrics), grads = grad_fn(state.params)
        metrics = {"train/loss": loss, **summarized_metrics}

        updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
        new_params = optax.apply_updates(state.params, updates)

        if ema_beta is None:
            new_ema_params = None
        elif state.ema_params is None:
            raise ValueError("ema_params must be initialized when ema_beta is set.")
        else:
            # EMA is taken against the freshly updated parameters.
            new_ema_params = jax.tree_util.tree_map(
                lambda old, new: ema_beta * old + (1.0 - ema_beta) * new,
                state.ema_params,
                new_params,
            )

        # Watch stats describe the pre-update params/opt_state together with this step's grads/updates.
        if watch_config is not None and compute_watch:
            watch_stats = compute_watch_stats(
                watch_targets=watch_targets,
                include_norms=watch_config.include_norms,
                include_per_parameter_norms=watch_config.include_per_parameter_norms,
                include_histogram=watch_config.include_histograms,
                split_scan_layers=watch_config.split_scan_layers,
                params=state.params,
                grads=grads,
                updates=updates,
                opt_state=state.opt_state,
                model_tree_type=type(state.params),
            )
        else:
            watch_stats = None

        next_state = dataclasses.replace(
            state,
            step=state.step + step_increment,
            params=new_params,
            opt_state=new_opt_state,
            ema_params=new_ema_params,
        )
        return next_state, metrics, watch_stats

    return train_step
def _run_grug_local(config: GrugRunConfig) -> None:
    """Entry point for the grug template training loop.

    Initializes the trainer and tracker, then (under the trainer's device mesh)
    builds the data loader, model state, evaluator and callbacks, restores
    state through ``restore_grug_state_from_checkpoint``, and runs the
    optimization loop until ``trainer.num_train_steps`` is reached or the loss
    becomes NaN.

    When training stops on a NaN loss, the forced end-of-run callbacks and the
    forced final checkpoint are skipped so that the NaN-contaminated state is
    not persisted; pending checkpoint writes are still awaited. The function
    returns normally in that case.

    Args:
        config: Top-level run configuration.

    Raises:
        ValueError: If ``trainer.id`` is ``None`` after ``trainer.initialize()``.
        BaseException: Anything raised inside the training loop is logged and
            re-raised without running final callbacks or a final checkpoint.
    """
    trainer = config.trainer.trainer
    trainer.initialize()
    levanter.tracker.log_configuration(config)

    run_id = trainer.id
    if run_id is None:
        raise ValueError("trainer.id was not initialized")

    optimizer = config.optimizer.build(trainer.num_train_steps)
    watch_config = trainer.watch
    train_step = _make_train_step(
        optimizer,
        trainer.mp,
        z_loss_weight=config.trainer.z_loss_weight,
        ema_beta=config.trainer.ema_beta,
        watch_config=watch_config if watch_config.is_enabled else None,
    )

    data_key, model_key = jax.random.split(jax.random.PRNGKey(trainer.seed), 2)
    # An explicit data seed overrides the data key derived from the trainer seed.
    if config.trainer.data_seed is not None:
        data_key = jax.random.PRNGKey(config.trainer.data_seed)

    # Build data/model state under the trainer mesh so all arrays are sharded consistently.
    with trainer.use_device_mesh():
        mesh = trainer.device_mesh
        batch_schedule = trainer.batch_schedule

        train_dataset = build_train_dataset(
            config.data,
            max_seq_len=config.model.max_seq_len,
            batch_schedule=batch_schedule,
            key=data_key,
        )
        train_loader = build_train_loader(
            train_dataset,
            batch_schedule=batch_schedule,
            mesh=mesh,
            batch_pspec=config.trainer.train_batch_pspec,
        )

        @jax.jit
        def _init_state(model_rng):
            return initial_state(
                config.model,
                optimizer=optimizer,
                mp=trainer.mp,
                key=model_rng,
                ema_beta=config.trainer.ema_beta,
            )

        state = _init_state(model_key)

        checkpointer = trainer.checkpointer.create(run_id)
        # Fall back to this run's own checkpoint directory when no explicit path is given.
        checkpoint_path = trainer.load_checkpoint_path
        if checkpoint_path is None and checkpointer is not None:
            checkpoint_path = trainer.checkpointer.expanded_path(run_id)
        additional_checkpoint_paths = []
        temp_path = trainer.checkpointer.expanded_temporary_path(run_id)
        if temp_path is not None:
            additional_checkpoint_paths.append(temp_path)
        state = restore_grug_state_from_checkpoint(
            state,
            checkpoint_path=checkpoint_path,
            load_checkpoint_setting=trainer.load_checkpoint,
            mesh=mesh,
            allow_partial=trainer.allow_partial_checkpoint,
            additional_checkpoint_paths=additional_checkpoint_paths,
        )

        levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)})
        flops_per_example, flops_summary = _compute_flops(model_config=config.model)
        levanter.tracker.log_summary(flops_summary)

        eval_cfg = config.eval
        evaluator = None
        if eval_cfg is not None:
            evaluator = build_tagged_evaluator(
                data_config=config.data,
                max_seq_len=config.model.max_seq_len,
                mesh=mesh,
                eval_cfg=eval_cfg,
            )

        profiler_cfg = trainer.profiler
        profiler_num_steps = profiler_cfg.resolve_num_profile_steps(num_train_steps=trainer.num_train_steps)
        profiler_enabled = profiler_cfg.is_enabled and profiler_num_steps > 0

        log_every = max(1, config.trainer.log_every)
        # Resume the data iterator at the (possibly restored) step.
        iterator = LoadingTimeTrackerIterator(train_loader.iter_from_step(int(state.step)))

        state_callbacks = StateCallbackRunner[GrugTrainState](
            step_getter=lambda s: s.step,
            model_getter=lambda s: s.params,
            eval_model_getter=lambda s: s.ema_params if s.ema_params is not None else s.params,
            opt_state_getter=lambda s: s.opt_state,
        )
        state_callbacks.add_hook(
            callbacks.log_performance_stats(config.model.max_seq_len, batch_schedule, flops_per_example),
            every=log_every,
        )
        state_callbacks.add_hook(callbacks.pbar_logger(total=trainer.num_train_steps), every=log_every)
        state_callbacks.add_hook(callbacks.log_step_info(trainer.num_train_steps), every=log_every)
        if profiler_enabled:
            state_callbacks.add_hook(
                callbacks.profile(
                    str(trainer.log_dir / run_id / "profiler"),
                    profiler_cfg.start_step,
                    profiler_num_steps,
                    profiler_cfg.perfetto_link,
                ),
                every=1,
            )
        state_callbacks.add_hook(_make_mixture_stage_callback(train_dataset, batch_schedule), every=1)

        if evaluator is not None and eval_cfg is not None:
            interval = eval_cfg.steps_per_eval
            # EMA evaluation is only meaningful when EMA tracking is enabled.
            eval_ema = eval_cfg.eval_ema and config.trainer.ema_beta is not None
            if interval is not None and interval > 0 and (eval_cfg.eval_current or eval_ema):
                state_callbacks.add_hook(
                    cb_tagged_evaluate(
                        evaluator,
                        prefix=eval_cfg.prefix,
                        eval_current=eval_cfg.eval_current,
                        eval_ema=eval_ema,
                    ),
                    every=interval,
                )

        last_loss: float | jax.Array = 0.0
        last_step_duration = 0.0
        pending_qb_betas: jax.Array | None = None
        # Set when the loop exits because the loss became NaN.
        stopped_on_nan = False

        # Main optimization loop.
        try:
            while int(state.step) < trainer.num_train_steps:
                # QB: apply router bias updates from previous step (on host).
                if pending_qb_betas is not None:
                    state = dataclasses.replace(
                        state,
                        params=_apply_qb_betas(state.params, pending_qb_betas),
                        ema_params=(
                            _apply_qb_betas(state.ema_params, pending_qb_betas)
                            if state.ema_params is not None
                            else None
                        ),
                    )
                    pending_qb_betas = None

                with jax.profiler.TraceAnnotation("load_batch"):
                    batch = next(iterator)

                step_start = time.perf_counter()
                current_step = int(state.step)
                # grad_watch runs only on its configured interval.
                compute_watch = (
                    watch_config.is_enabled and watch_config.interval > 0 and current_step % watch_config.interval == 0
                )
                state, metrics, watch_stats = train_step(state, batch, compute_watch=compute_watch)
                # ``state.step`` was just incremented; ``step`` is the step that produced ``metrics``.
                step = int(state.step) - 1
                jax.block_until_ready(metrics["train/loss"])
                pending_qb_betas = metrics["qb_beta_per_layer"]

                if jnp.isnan(metrics["train/loss"]):
                    logger.error(f"NaN loss at step {int(state.step)}. Stopping training.")
                    stopped_on_nan = True
                    break

                duration = time.perf_counter() - step_start
                hook_start = time.perf_counter()
                with jax.profiler.TraceAnnotation("callbacks"):
                    state_callbacks.run(state, loss=metrics["train/loss"], step_duration=duration)
                last_loss = metrics["train/loss"]
                last_step_duration = duration

                levanter.tracker.log({"throughput/hook_time": time.perf_counter() - hook_start}, step=step)
                levanter.tracker.log({"throughput/loading_time": iterator.this_load_time}, step=step)

                router_metrics = {
                    key: value
                    for key, value in metrics.items()
                    if (key.startswith("train/router/") or key.startswith("moe_bias/"))
                    and key not in ("train/router/routing_counts_per_layer", "qb_beta_per_layer")
                }
                if router_metrics:
                    levanter.tracker.log(router_metrics, step=step)
                if "train/cross_entropy_loss" in metrics:
                    levanter.tracker.log(
                        {"train/cross_entropy_loss": metrics["train/cross_entropy_loss"]},
                        step=step,
                    )
                if watch_stats is not None:
                    levanter.tracker.log(watch_stats, step=step)
                if checkpointer is not None:
                    checkpointer.on_step(tree=state, step=int(state.step))
        except BaseException:
            logger.exception(
                "Fatal error in grug training loop; skipping final callbacks/checkpoint to preserve root cause"
            )
            raise
        else:
            # A ``break`` out of the while loop still reaches this branch, so the
            # NaN case must be handled explicitly: forcing callbacks and a
            # checkpoint here would persist the NaN-contaminated state.
            if stopped_on_nan:
                logger.error("Training stopped on NaN loss; skipping final callbacks and forced checkpoint.")
                if checkpointer is not None:
                    # Let any checkpoint writes started on earlier steps complete.
                    checkpointer.wait_until_finished()
            else:
                # Mirror classic trainer behavior: force callbacks on the last completed step.
                state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True)
                if checkpointer is not None:
                    checkpointer.on_step(tree=state, step=int(state.step), force=True)
                    checkpointer.wait_until_finished()

    levanter.tracker.current_tracker().finish()
def run_grug(config: GrugRunConfig) -> None:
    """Dispatch grug training through Fray jobs.

    Args:
        config: Top-level run configuration; its trainer id must already be set.

    Raises:
        ValueError: If ``config.trainer.trainer.id`` is ``None``.
    """
    run_id = config.trainer.trainer.id
    if run_id is None:
        raise ValueError("trainer.id must be set before dispatching grug training.")

    dispatch_grug_training_run(
        run_id=run_id,
        config=config,
        local_entrypoint=_run_grug_local,
        resources=config.resources,
    )
# Names exported by a star-import of this module.
__all__ = [
    "GrugEvalConfig",
    "GrugRunConfig",
    "GrugTrainState",
    "GrugTrainerConfig",
    "initial_state",
    "run_grug",
]