Merged
17 changes: 14 additions & 3 deletions experiments/defaults.py
@@ -116,7 +116,7 @@ def _truncate_wandb_name(name: str) -> str:
     return name


-def _resolve_hf_export_steps(steps_per_hf_export: int | None, steps_per_export: int) -> int | None:
+def _resolve_hf_export_steps(steps_per_hf_export: int | None, steps_per_export: int | None) -> int | None:
     """Resolve the HF export step interval: None means same as checkpoint, -1 means disabled."""
     if steps_per_hf_export is None:
         return steps_per_export
@@ -125,6 +125,17 @@ def _resolve_hf_export_steps(steps_per_hf_export: int | None, steps_per_export:
     return steps_per_hf_export


+def _checkpoint_keep(steps_per_export: int | None) -> list[dict]:
+    """Build the `keep` list for `CheckpointerConfig`.
+
+    None means keep no permanent intermediate checkpoints (only the final checkpoint
+    is saved at end-of-training, plus a rolling temporary checkpoint for resumption).
+    """
+    if steps_per_export is None:
+        return []
+    return [dict(every=steps_per_export)]
+
+
 def _validate_train_length(train_seq_len: int | None, model_config: LmConfig) -> int:
     """Resolve and validate the training sequence length against the model's max."""
     actual = unwrap_versioned_value(model_config)
@@ -409,7 +420,7 @@ def default_train(
         steps_per_eval=train_config.steps_per_eval if train_config.steps_per_eval is not None else 1000,
         checkpointer=CheckpointerConfig(
             save_interval=timedelta(minutes=10),
-            keep=[dict(every=steps_per_export)],
+            keep=_checkpoint_keep(steps_per_export),
         ),
         model_averaging=model_averaging,
         mesh=MeshConfig(
@@ -641,7 +652,7 @@ def default_dpo(
         steps_per_eval=dpo_config.steps_per_eval,
         checkpointer=CheckpointerConfig(
             save_interval=timedelta(minutes=10),
-            keep=[dict(every=steps_per_export)],
+            keep=_checkpoint_keep(steps_per_export),
         ),
         model_averaging=None,
         mesh=MeshConfig(
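Taken together, the two helpers give None a consistent meaning for retention. A minimal, runnable sketch of the semantics follows; the -1 branch of _resolve_hf_export_steps is collapsed in the diff above, so its body here is an assumption based on the docstring:

# Semantics sketch, not part of the diff. The -1 branch is assumed from the
# docstring ("-1 means disabled"); the rest mirrors the lines shown above.

def _resolve_hf_export_steps(steps_per_hf_export: int | None, steps_per_export: int | None) -> int | None:
    if steps_per_hf_export is None:  # follow the checkpoint cadence
        return steps_per_export
    if steps_per_hf_export == -1:  # assumed: explicit opt-out
        return None
    return steps_per_hf_export

def _checkpoint_keep(steps_per_export: int | None) -> list[dict]:
    if steps_per_export is None:  # keep only the final checkpoint
        return []
    return [dict(every=steps_per_export)]

assert _checkpoint_keep(None) == []
assert _checkpoint_keep(10_000) == [dict(every=10_000)]
assert _resolve_hf_export_steps(None, 10_000) == 10_000
assert _resolve_hf_export_steps(-1, 10_000) is None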
4 changes: 3 additions & 1 deletion experiments/simple_dpo_config.py
@@ -40,7 +40,9 @@ class SimpleDPOConfig:
     max_grad_norm: float | None = 1

     steps_per_eval: int = 1000
-    steps_per_checkpoint: int = 1000
+    steps_per_checkpoint: int | None = None
P2: Disable default HF exports in DPO no-checkpoint mode

This default now disables periodic Levanter checkpoints, but leaves steps_per_hf_export at 500, and default_dpo still resolves and passes that interval through (experiments/defaults.py), with HF save hooks enabled whenever hf_save_steps is set (lib/levanter/src/levanter/main/train_dpo.py). As a result, DPO runs with defaults continue to emit frequent permanent HF checkpoints, which undercuts the new conservative-retention behavior and can cause avoidable artifact bloat.

"""How often to keep a permanent checkpoint. None (default) keeps only the final
checkpoint; rolling temporary checkpoints are still written for resumption."""
steps_per_hf_export: int = 500
hf_save_dtype: str | None = None
hf_generation_eos_token_ids: list[int] | None = None
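To make the comment above concrete, here is a rough trace of what the defaults resolve to; it assumes the call pattern the reviewer describes in default_dpo and is not code from the PR:

# Hypothetical trace for an out-of-the-box SimpleDPOConfig under this change:
steps_per_checkpoint = None  # new default: no permanent intermediate checkpoints
steps_per_hf_export = 500  # unchanged default

keep = _checkpoint_keep(steps_per_checkpoint)  # [] -> Levanter retention is conservative
hf_steps = _resolve_hf_export_steps(steps_per_hf_export, steps_per_checkpoint)
print(keep, hf_steps)  # [] 500 -> permanent HF snapshots still land every 500 steps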
5 changes: 3 additions & 2 deletions experiments/simple_sft_config.py
@@ -117,8 +117,9 @@ class SimpleSFTConfig:
     steps_per_eval: int = 1000
     """How often to run validation losses."""

-    steps_per_checkpoint: int = 1000
-    """How often to save checkpoints."""
+    steps_per_checkpoint: int | None = None
P2: Disable default HF exports in SFT no-checkpoint mode

Setting steps_per_checkpoint to None here makes SFT look like it will keep only a final checkpoint, but steps_per_hf_export still defaults to 500, and default_sft forwards that value into default_train (experiments/defaults.py), which always gets an hf_save_path (lib/marin/src/marin/training/training.py) and registers periodic HF saves when hf_save_steps is non-null (lib/levanter/src/levanter/main/train_lm.py). In long SFT runs this still creates many permanent HF snapshots, so retention is not actually conservative and storage usage can grow unexpectedly.

"""How often to keep a permanent checkpoint. None (default) keeps only the final
checkpoint; rolling temporary checkpoints are still written for resumption."""

steps_per_hf_export: int = 500
"""How often to save HuggingFace checkpoints."""
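A conservative follow-up, sketched below as an assumption rather than part of this PR, would be to default steps_per_hf_export to None as well, since _resolve_hf_export_steps already treats None as "follow the checkpoint cadence":

# Hypothetical default change (not in this diff):
#     steps_per_hf_export: int | None = None
# With both fields None, periodic HF exports would be disabled alongside checkpoints:
assert _resolve_hf_export_steps(None, None) is None  # no periodic HF exports
assert _resolve_hf_export_steps(None, 10_000) == 10_000  # follows checkpoints when set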
4 changes: 3 additions & 1 deletion experiments/simple_train_config.py
@@ -45,7 +45,9 @@ class SimpleTrainConfig:

     steps_per_eval: int | None = None
     """how often to run validation losses"""
-    steps_per_export: int = 10000
+    steps_per_export: int | None = None
+    """How often to keep a permanent checkpoint. None (default) keeps only the final
+    checkpoint; rolling temporary checkpoints are still written for resumption."""
     steps_per_task_eval: int | None = None
     """how often to run task evaluations"""
     steps_per_hf_export: int | None = None
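For callers, restoring the previous retention is a one-field change. A usage sketch, assuming SimpleTrainConfig's remaining fields have defaults or are supplied elsewhere:

conservative = SimpleTrainConfig()  # new default: final checkpoint only, plus rolling temp checkpoints
legacy = SimpleTrainConfig(steps_per_export=10_000)  # restore the pre-PR every-10k retention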