Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 127 additions & 24 deletions experiments/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from marin.execution.remote import remote
from haliax.partitioning import ResourceAxis
from haliax.quantization import QuantizationConfig
from levanter.adaptation import NoAdaptationConfig
from levanter.checkpoint import CheckpointerConfig
from levanter.data.text import (
BlockShuffleConfig,
Expand All @@ -27,7 +28,7 @@
TextLmDatasetFormat,
)
from levanter.eval_harness import LmEvalHarnessConfig
from levanter.main.train_dpo import TrainDpoConfig
from levanter.main.train_dpo import SeparateReferenceConfig, TrainDpoConfig
from levanter.main.train_lm import TrainLmConfig
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig
Expand Down Expand Up @@ -61,6 +62,7 @@
TokenizerStep,
add_validation_sets_to_mixture,
lm_data_config,
lm_mixture_data_config,
tokenize,
)
from marin.processing.tokenize.tokenize import HfTokenizeConfig, TokenizeConfigBase
Expand All @@ -76,6 +78,10 @@

HF_BUCKET_URI_PREFIX = "hf://buckets/"
HF_BUCKET_PATH_PREFIX = "buckets/"
# W&B rejects run names longer than 64 characters; longer names are truncated to fit.
_WANDB_NAME_MAX_LENGTH = 64
# Minimum readable prefix to retain when truncating; a shorter retained prefix triggers a warning.
_WANDB_NAME_MIN_PREFIX_LENGTH = 24
# Dropping at least this many characters counts as "aggressive" truncation and is warned about.
_WANDB_NAME_AGGRESSIVE_TRUNCATION_CHARS = 16
# Semantic suffix markers (seed/step tails) that truncation tries to preserve intact.
_WANDB_SUFFIX_MARKERS = ("_seed", "-seed", "_step", "-step")


def _is_hf_bucket_path(path: str) -> bool:
Expand All @@ -96,21 +102,75 @@ def _normalize_hf_bucket_path(path: str) -> str:
"""Hierarchical block-shuffle default for newly constructed training runs."""


def _preferred_wandb_suffix_start(name: str) -> int | None:
"""Return the preferred suffix start for a truncated W&B run name.

We prefer underscore-delimited semantic tails like ``lr7.5e-7_seed2`` or
``foo_step400``. This avoids splitting on the ``-`` inside scientific
notation, which would corrupt names like ``lr7.5e-7``.
"""
for marker in _WANDB_SUFFIX_MARKERS:
marker_start = name.rfind(marker)
if marker_start == -1:
continue

prior_slash = name.rfind("/", 0, marker_start)
prior_underscore = name.rfind("_", 0, marker_start)
if prior_underscore > prior_slash:
return prior_underscore
return marker_start

last_underscore = name.rfind("_")
if last_underscore == -1:
return None

prior_slash = name.rfind("/", 0, last_underscore)
second_last_underscore = name.rfind("_", 0, last_underscore)
if second_last_underscore > prior_slash:
return second_last_underscore

return last_underscore


def _truncate_wandb_name(name: str) -> str:
"""Truncate a run name to fit WANDB's 64-character limit, preserving the trailing suffix."""
if len(name) <= 64:
"""Truncate a run name to fit W&B's 64-character limit without mangling semantic suffixes."""
if len(name) <= _WANDB_NAME_MAX_LENGTH:
return name

old_name = name
if "-" not in name:
name = name[:64]
suffix_start = _preferred_wandb_suffix_start(name)

if suffix_start is None:
name = name[:_WANDB_NAME_MAX_LENGTH]
preserved_suffix = ""
else:
prefix, suffix = name.rsplit("-", 1)
if len(suffix) >= 64:
suffix = suffix[:64]
name = suffix
suffix = name[suffix_start:]
preserved_suffix = suffix
if len(suffix) >= _WANDB_NAME_MAX_LENGTH:
name = name[:_WANDB_NAME_MAX_LENGTH]
preserved_suffix = ""
else:
name = prefix[: 63 - len(suffix)] + "-" + suffix
prefix_budget = _WANDB_NAME_MAX_LENGTH - len(suffix)
prefix = name[:prefix_budget]

# Prefer trimming at a token boundary so the retained prefix stays readable.
boundary = max(prefix.rfind("_"), prefix.rfind("/"))
if boundary >= _WANDB_NAME_MIN_PREFIX_LENGTH:
prefix = prefix[:boundary]

name = prefix + suffix

logger.warning(f"Truncated name from {old_name} to {name} to fit within WANDB limits.")

removed_chars = len(old_name) - len(name)
retained_prefix_len = len(name) - len(preserved_suffix)
if removed_chars >= _WANDB_NAME_AGGRESSIVE_TRUNCATION_CHARS or retained_prefix_len < _WANDB_NAME_MIN_PREFIX_LENGTH:
logger.warning(
"W&B run name %r required aggressive truncation to %r. Consider shortening the explicit name.",
old_name,
name,
)

return name


Expand Down Expand Up @@ -622,6 +682,12 @@ def default_dpo(
preference_data = PreferenceLmDataConfig.from_lm_data_config(pretraining_data)
preference_data = dataclasses.replace(preference_data, permutation_type="feistel")
dpo_tokenizer_name = unwrap_versioned_value(preference_data.tokenizer)
lm_validation_data = lm_mixture_data_config(
default_validation_sets(tokenizer=dpo_tokenizer_name),
{},
missing_weights_are_validation=True,
include_raw_paths=False,
)

name = _truncate_wandb_name(name)

Expand All @@ -630,12 +696,35 @@ def default_dpo(

train_length = _validate_train_length(dpo_config.train_seq_len, model_config)

requested_num_train_steps = dpo_config.num_train_steps
auto_num_epochs = None
if requested_num_train_steps is None:
requested_num_train_steps = 1
auto_num_epochs = dpo_config.num_epochs

requested_steps_per_eval = dpo_config.steps_per_eval
auto_validation_runs = None
if requested_steps_per_eval is None:
requested_steps_per_eval = 1
auto_validation_runs = 5

schedule = BatchSchedule(unwrap_versioned_value(dpo_config.train_batch_size))
total_examples = schedule.global_data_offset_by_step(dpo_config.num_train_steps)
total_examples = schedule.global_data_offset_by_step(requested_num_train_steps)

reference = dpo_config.reference
if isinstance(reference, SeparateReferenceConfig) and not reference.model_path:
reference_model_path = dpo_config.reference_model_path or dpo_config.model_name_or_path
if reference_model_path is None:
raise ValueError("reference_model_path must be set for DPO training when using a separate reference.")
reference = dataclasses.replace(
reference,
model_path=reference_model_path,
is_hf=dpo_config.reference_is_hf,
)

reference_model_path = dpo_config.reference_model_path or dpo_config.model_name_or_path
if reference_model_path is None:
raise ValueError("reference_model_path must be set for DPO training.")
hf_save_dtype = dpo_config.hf_save_dtype
if not isinstance(dpo_config.adapter, NoAdaptationConfig) and hf_save_dtype is not None:
raise ValueError("hf_save_dtype is not supported with adapter-based DPO exports.")

inner_config = TrainDpoConfig(
data=preference_data,
Expand All @@ -646,8 +735,8 @@ def default_dpo(
),
mp=jmp.get_policy("p=f32,c=bfloat16"),
train_batch_size=dpo_config.train_batch_size,
num_train_steps=dpo_config.num_train_steps,
steps_per_eval=dpo_config.steps_per_eval,
num_train_steps=requested_num_train_steps,
steps_per_eval=requested_steps_per_eval,
checkpointer=CheckpointerConfig(
save_interval=timedelta(minutes=10),
keep=_checkpoint_keep(steps_per_export),
Expand All @@ -659,6 +748,8 @@ def default_dpo(
"token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA),
}
),
per_device_eval_parallelism=dpo_config.per_device_eval_parallelism,
profiler=dpo_config.profiler,
allow_partial_checkpoint=dpo_config.allow_partial_checkpoint,
allow_nondivisible_batch_size=True,
quantization=QuantizationConfig(int8=dpo_config.int8) if dpo_config.int8 else None,
Expand All @@ -668,6 +759,7 @@ def default_dpo(
initialize_from_hf=dpo_config.model_name_or_path if initialize_from_hf else False,
train_seq_len=train_length,
model=model_config,
adapter=dpo_config.adapter,
optimizer=AdamConfig(
learning_rate=dpo_config.learning_rate,
weight_decay=dpo_config.weight_decay,
Expand All @@ -677,12 +769,13 @@ def default_dpo(
min_lr_ratio=dpo_config.min_lr_ratio,
max_grad_norm=dpo_config.max_grad_norm,
),
reference_model_path=reference_model_path,
reference_is_hf=dpo_config.reference_is_hf,
reference=reference,
beta=dpo_config.beta,
validation_split_fraction=dpo_config.validation_split_fraction,
reference_eval_cache=dpo_config.reference_eval_cache,
lm_validation_data=lm_validation_data,
hf_save_steps=steps_per_export_hf,
hf_save_dtype=dpo_config.hf_save_dtype,
hf_save_dtype=hf_save_dtype,
hf_generation_eos_token_ids=dpo_config.hf_generation_eos_token_ids,
data_seed=dpo_config.seed,
)
Expand All @@ -691,18 +784,28 @@ def default_dpo(
train_config=inner_config,
resources=dpo_config.resources,
output_path=this_output_path(),
auto_num_epochs=auto_num_epochs,
auto_validation_runs=auto_validation_runs,
)

model_config = unwrap_versioned_value(model_config)

return ExecutorStep(
name=os.path.join("checkpoints", name),
description=(
f"Train a model (tokenizer={dpo_tokenizer_name}) for "
f"{dpo_config.num_train_steps} (steps) * "
f"{dpo_config.train_batch_size} (batch_size) * "
f"{train_length} (train_seq_len) "
f"= {total_examples * train_length} tokens."
(
f"Train a model (tokenizer={dpo_tokenizer_name}) for "
f"{requested_num_train_steps} (steps) * "
f"{dpo_config.train_batch_size} (batch_size) * "
f"{train_length} (train_seq_len) "
f"= {total_examples * train_length} tokens."
)
if auto_num_epochs is None
else (
f"Train a model (tokenizer={dpo_tokenizer_name}) for "
f"{dpo_config.num_epochs:g} epoch(s) with runtime-resolved step count "
f"and train_seq_len={train_length}."
)
),
fn=run_levanter_train_dpo,
config=config,
Expand Down
7 changes: 7 additions & 0 deletions experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ def download_model_step(model_config: ModelConfig) -> ExecutorStep:
)
)

# Download step for the marin-community/marin-8b-instruct HF checkpoint,
# pinned to a specific revision so downstream runs resolve the same weights.
marin_8b_instruct = download_model_step(
    ModelConfig(
        hf_repo_id="marin-community/marin-8b-instruct",
        hf_revision="0378f9c",
    )
)

llama_3_2_1b = download_model_step(
ModelConfig(
hf_repo_id="meta-llama/Llama-3.2-1B",
Expand Down
44 changes: 41 additions & 3 deletions experiments/simple_dpo_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from dataclasses import dataclass, field

from fray.cluster import ResourceConfig
from levanter.adaptation import AdaptationConfig, NoAdaptationConfig
from levanter.callbacks.profiler import ProfilerConfig
from levanter.dpo import ReferenceEvalCacheConfig
from levanter.main.train_dpo import DpoReferenceConfig, SeparateReferenceConfig
from levanter.schedule import IntSchedule

# DPO runs two models (policy + reference) but eval doesn't need gradients/optimizer,
# so we can fit more examples per device during eval than training.
# Keyed by TPU variant string from ResourceConfig.
# NOTE(review): variants absent from this table presumably fall back to the
# per_device_eval_parallelism default (-1) on SimpleDPOConfig — confirm at the call site.
DPO_EVAL_PARALLELISM: dict[str, int] = {
    "v5p-8": 16,
    "v5p-16": 16,
    "v5p-32": 32,
    "v5p-64": 32,
    "v5p-128": 32,
    "v5p-256": 64,
}


@dataclass(frozen=True)
class SimpleDPOConfig:
Expand All @@ -16,18 +32,25 @@ class SimpleDPOConfig:
resources: ResourceConfig

train_batch_size: int | IntSchedule = 128
num_train_steps: int = 10000
num_train_steps: int | None = None
num_epochs: float = 1.0
"""Approximate number of passes over the DPO train set when num_train_steps is unset."""
learning_rate: float = 1e-6
wandb_project: str | None = None

tokenizer: str | None = None
model_name_or_path: str | None = None
initialize_from_checkpoint_path: str | None = None

adapter: AdaptationConfig = field(default_factory=NoAdaptationConfig)
reference: DpoReferenceConfig = field(default_factory=SeparateReferenceConfig)
reference_model_path: str | None = None
reference_is_hf: bool = True
beta: float = 0.1
validation_split_fraction: float | None = 0.1
reference_eval_cache: ReferenceEvalCacheConfig = field(
default_factory=lambda: ReferenceEvalCacheConfig(mode="build_or_load")
)

train_seq_len: int | None = None
max_seq_len: int = 4096
Expand All @@ -39,7 +62,8 @@ class SimpleDPOConfig:
min_lr_ratio: float = 0.0
max_grad_norm: float | None = 1

steps_per_eval: int = 1000
steps_per_eval: int | None = None
"""None auto-schedules validation five times: before training, three interior points, and at the end."""
steps_per_checkpoint: int | None = None
"""How often to keep a permanent checkpoint. None (default) keeps only the final
checkpoint; rolling temporary checkpoints are still written for resumption."""
Expand All @@ -49,8 +73,22 @@ class SimpleDPOConfig:
"""EOS token IDs to write to generation_config.json. None means no generation config.
For chat models, include the turn-boundary token (e.g. [128001, 128009])."""

per_device_eval_parallelism: int = -1

seed: int = 0
initialize_from_hf: bool | None = None

profiler: ProfilerConfig = field(default_factory=ProfilerConfig)

allow_partial_checkpoint: bool = False
int8: bool = False

def __post_init__(self):
    """Reject non-positive step/epoch settings as soon as the config is constructed."""
    # Each entry: (field name, current value, whether None means "unset" and is accepted).
    checks = (
        ("num_train_steps", self.num_train_steps, True),
        ("num_epochs", self.num_epochs, False),
        ("steps_per_eval", self.steps_per_eval, True),
        ("steps_per_checkpoint", self.steps_per_checkpoint, True),
    )
    for field_name, value, allows_none in checks:
        if allows_none and value is None:
            continue
        if value <= 0:
            raise ValueError(f"{field_name} must be positive, got {value}")
10 changes: 8 additions & 2 deletions lib/levanter/config/dpo_tiny_gpt2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,14 @@ optimizer:
weight_decay: 0.0
warmup: 0.0

reference_model_path: hf-internal-testing/tiny-random-gpt2
reference_is_hf: true
adapter:
type: none

reference:
type: separate
model_path: hf-internal-testing/tiny-random-gpt2
is_hf: true

beta: 0.1
validation_split_fraction: 0.1

Expand Down
9 changes: 7 additions & 2 deletions lib/levanter/config/dpo_ultrafeedback_llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,13 @@ optimizer:
warmup: 0.1
max_grad_norm: 1.0

reference_model_path: gs://marin-us-central1/gcsfuse_mount/models/meta-llama--Llama-3-1-8B--main
reference_is_hf: true
adapter:
type: none

reference:
type: separate
model_path: gs://marin-us-central1/gcsfuse_mount/models/meta-llama--Llama-3-1-8B--main
is_hf: true
beta: 0.01
validation_split_fraction: null

Expand Down
Loading
Loading