38 changes: 38 additions & 0 deletions docs/tutorials/train-dpo.md
@@ -79,6 +79,44 @@ dpo_config = SimpleDPOConfig(
| `model_name_or_path` | HuggingFace model to initialize the policy from. Also used as the reference model unless `reference_model_path` is set separately. |
| `reference_model_path` | Path to the reference model. Defaults to `model_name_or_path`. |
| `validation_split_fraction` | Fraction of training data to hold out for validation (default 0.1). Set to `None` to use a separate validation set. |
| `hf_generation_eos_token_ids` | List of token IDs to write to `generation_config.json` for inference stop conditions. See below. |

### Setting Generation Stop Tokens

Chat models use a turn-boundary token (e.g. `<|eot_id|>`) to end assistant
responses, but the tokenizer's `eos_token` is typically the pre-training
document boundary (`<|end_of_text|>`). Inference tools like vLLM need both
tokens as stop conditions.

Set `hf_generation_eos_token_ids` to write a `generation_config.json` alongside
each saved checkpoint. The tokenizer's `eos_token_id` is auto-added if not
already in the list.

For Llama 3 models, use the predefined constant:

```python
from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS

dpo_config = SimpleDPOConfig(
    ...,
hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS, # [128001, 128009]
)
```

For other model families, determine the correct stop token by applying the
chat template and checking the last token of the assistant turn:

```python
tokens = tokenizer.apply_chat_template(
[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
tokenize=True,
)
print(f"Chat stop token: {tokens[-1]}") # e.g. 128009 for <|eot_id|>
print(f"Tokenizer EOS: {tokenizer.eos_token_id}") # e.g. 128001 for <|end_of_text|>
```

If the two differ, pass both: `hf_generation_eos_token_ids=[eos_token_id, chat_stop_token]`.
If they match, you don't need to set this field.
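
For a model family without a predefined constant, the two snippets above combine as in the sketch below. Here `chat_stop_token` and `eos_token_id` are just local names for the values found by the probe, not fields defined anywhere in the codebase:

```python
chat_stop_token = tokens[-1]           # last token from the chat-template probe above
eos_token_id = tokenizer.eos_token_id  # the tokenizer's document-boundary EOS

dpo_config = SimpleDPOConfig(
    ...,  # other fields as in your existing config
    hf_generation_eos_token_ids=[eos_token_id, chat_stop_token],
)
```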

## Creating the Training Pipeline

3 changes: 3 additions & 0 deletions experiments/defaults.py
@@ -467,6 +467,7 @@ def default_train(
)
),
hf_save_steps=steps_per_export_hf,
hf_generation_eos_token_ids=train_config.hf_generation_eos_token_ids,
data_seed=train_config.data_seed,
eval_harness_steps=train_config.steps_per_task_eval or 10000,
eval_harness=harness_config,
@@ -557,6 +558,7 @@ def default_sft(
beta2=sft_config.beta2,
pad_tokenizer_to_match_model=sft_config.pad_tokenizer_to_match_model,
per_device_parallelism=sft_config.per_device_parallelism,
hf_generation_eos_token_ids=sft_config.hf_generation_eos_token_ids,
)

if sft_config.reinit_tokens:
@@ -672,6 +674,7 @@ def default_dpo(
validation_split_fraction=dpo_config.validation_split_fraction,
hf_save_steps=steps_per_export_hf,
hf_save_dtype=dpo_config.hf_save_dtype,
hf_generation_eos_token_ids=dpo_config.hf_generation_eos_token_ids,
data_seed=dpo_config.seed,
)

3 changes: 2 additions & 1 deletion experiments/dpo_ultrafeedback.py
@@ -8,7 +8,7 @@
from levanter.data.text import PreferenceChatLmDatasetFormat

from experiments.defaults import default_dpo, default_tokenize
from experiments.llama import llama_8b
from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS, llama_8b
from experiments.marin_models import marin_tokenizer
from experiments.models import llama_3_1_8b
from experiments.posttrain.preference_datasets import get_preference_dataset
@@ -61,6 +61,7 @@
steps_per_eval=200,
steps_per_checkpoint=1000,
steps_per_hf_export=1000,
hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS,
seed=0,
)

8 changes: 8 additions & 0 deletions experiments/llama.py
@@ -18,6 +18,14 @@
llama3_tokenizer_vocab_size = 128_256
llama3_instruct_tokenizer = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Llama 3 chat stop token IDs for generation_config.json.
# The chat template ends every turn (user, assistant, system) with <|eot_id|> (128009),
# but the tokenizer's eos_token is <|end_of_text|> (128001), which is the pre-training
# document boundary. Both must be listed as stop tokens so vLLM stops on either.
# Determined by running: tokenizer.apply_chat_template([...], tokenize=True)
# and observing the last token of the assistant turn is 128009.
LLAMA3_CHAT_STOP_TOKEN_IDS = [128001, 128009]

# Llama3 instruct trainable chat template
# Slight modification of https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/tokenizer_config.json
# to add {% generation %} so we can create the assistant_mask
3 changes: 3 additions & 0 deletions experiments/simple_dpo_config.py
@@ -43,6 +43,9 @@ class SimpleDPOConfig:
steps_per_checkpoint: int = 1000
steps_per_hf_export: int = 500
hf_save_dtype: str | None = None
hf_generation_eos_token_ids: list[int] | None = None
"""EOS token IDs to write to generation_config.json. None means no generation config.
For chat models, include the turn-boundary token (e.g. [128001, 128009])."""

seed: int = 0
initialize_from_hf: bool | None = None
4 changes: 4 additions & 0 deletions experiments/simple_sft_config.py
@@ -123,6 +123,10 @@ class SimpleSFTConfig:
steps_per_hf_export: int = 500
"""How often to save HuggingFace checkpoints."""

hf_generation_eos_token_ids: list[int] | None = None
"""EOS token IDs to write to generation_config.json. None means no generation config.
For chat models, include the turn-boundary token (e.g. [128001, 128009])."""

# Mixture-specific parameters
mixture_block_size: int = 2048
"""Block size for dataset mixing (only used with mixture training)."""
2 changes: 2 additions & 0 deletions experiments/simple_train_config.py
@@ -50,6 +50,8 @@ class SimpleTrainConfig:
"""how often to run task evaluations"""
steps_per_hf_export: int | None = None
"""None means match steps_per_export, -1 disables"""
hf_generation_eos_token_ids: list[int] | None = None
"""EOS token IDs to write to generation_config.json. None means no generation config."""
per_device_parallelism: int = -1
"""How many examples to process in parallel on each device. -1 (default) means
train_batch_size/num_devices (no gradient accumulation). Set to a positive value
89 changes: 89 additions & 0 deletions experiments/test_dpo_generation_config.py
@@ -0,0 +1,89 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""
Smoke test: verify that hf_generation_eos_token_ids writes generation_config.json
in DPO checkpoint saves. Uses marin-8b-instruct on v5p-8 with 2 training steps.

After the run completes, check the HF checkpoint output for generation_config.json:
gcloud storage cat gs://marin-us-east5/checkpoints/dpo/test_generation_config_smoke-<hash>/hf/step-2/generation_config.json

Expected content: {"eos_token_id": [128001, 128009], "bos_token_id": 128000}
"""

from levanter.data.text import PreferenceChatLmDatasetFormat

from experiments.defaults import default_dpo, default_tokenize
from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS, llama_8b
from experiments.marin_models import marin_tokenizer
from experiments.posttrain.preference_datasets import get_preference_dataset
from experiments.simple_dpo_config import SimpleDPOConfig
from fray.cluster import ResourceConfig
from marin.execution.executor import executor_main
from marin.processing.tokenize import lm_data_config

DATASET_NAME = "HuggingFaceH4/ultrafeedback_binarized"

preference_dataset = get_preference_dataset(DATASET_NAME, splits=["train_prefs", "test_prefs"])

tokenized_train = default_tokenize(
name="ultrafeedback_binarized_train_prefs_marin_tokenizer",
dataset=preference_dataset / "train_prefs/*.jsonl.gz",
tokenizer=marin_tokenizer,
format=PreferenceChatLmDatasetFormat(),
)

tokenized_val = default_tokenize(
name="ultrafeedback_binarized_test_prefs_marin_tokenizer",
dataset=preference_dataset / "test_prefs/*.jsonl.gz",
tokenizer=marin_tokenizer,
format=PreferenceChatLmDatasetFormat(),
is_validation=True,
)

tokenized_preferences = lm_data_config(
training_set=tokenized_train,
validation_sets={"ultrafeedback_test_prefs": tokenized_val},
)

dpo_config = SimpleDPOConfig(
resources=ResourceConfig.with_tpu("v5p-8", ram="400g"),
train_batch_size=8,
num_train_steps=2,
learning_rate=5e-7,
lr_schedule="cosine",
warmup=0,
wandb_project="dpo",
tokenizer=marin_tokenizer,
model_name_or_path="marin-community/marin-8b-instruct",
reference_model_path="marin-community/marin-8b-instruct",
reference_is_hf=True,
train_seq_len=4096,
max_seq_len=4096,
beta=0.1,
validation_split_fraction=None,
steps_per_eval=2,
steps_per_checkpoint=2,
steps_per_hf_export=2,
hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS,
seed=0,
)

training_step = default_dpo(
name="dpo/test_generation_config_smoke",
tokenized=tokenized_preferences,
model_config=llama_8b,
dpo_config=dpo_config,
tags=["dpo", "smoke-test", "generation-config"],
)


if __name__ == "__main__":
executor_main(
steps=[
preference_dataset,
tokenized_train,
tokenized_val,
training_step,
]
)
29 changes: 29 additions & 0 deletions lib/levanter/docs/guides/DPO-Training.md
@@ -83,6 +83,35 @@ validation_split_fraction: 0.1 # Auto-split from training data; null to disable
training data for validation. Set to `null` to use separately configured
validation sets.

### Generation Stop Tokens

Chat models typically use a turn-boundary token (e.g. `<|eot_id|>`, ID 128009)
to end assistant responses, while the tokenizer's `eos_token` remains the
pre-training document boundary (e.g. `<|end_of_text|>`, ID 128001). Inference
tools like vLLM use `eos_token_id` from `config.json` to decide when to stop,
so they will miss the chat stop token unless told otherwise.

Set `hf_generation_eos_token_ids` to write a `generation_config.json` alongside
each HF checkpoint with the correct stop tokens:

```yaml
hf_generation_eos_token_ids: [128001, 128009] # <|end_of_text|> + <|eot_id|>
```

The tokenizer's `eos_token_id` is auto-added if not in the list. This field
defaults to `null` (no `generation_config.json` written), preserving backward
compatibility with pretraining checkpoints.

To determine the right stop token for your model's chat template:

```python
tokens = tokenizer.apply_chat_template(
[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
tokenize=True,
)
# The last token is the chat stop token (e.g. 128009 for Llama 3)
```
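
To preview what will be written, you can call the helper this change adds directly; for the Llama 3 IDs above it should produce something like the dict below (the exact `bos_token_id` depends on the tokenizer):

```python
from levanter.compat.hf_checkpoints import build_generation_config

gen_config = build_generation_config(tokenizer, [128001, 128009])
# For a Llama 3 tokenizer this yields roughly:
#   {"eos_token_id": [128001, 128009], "bos_token_id": 128000}
# which is the content written to generation_config.json next to each HF checkpoint.
```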

## Running

```bash
67 changes: 67 additions & 0 deletions lib/levanter/src/levanter/compat/hf_checkpoints.py
@@ -117,6 +117,63 @@ def _is_hf_model_id(path: str) -> bool:
PYTORCH_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
SAFE_TENSORS_INDEX_NAME = "model.safetensors.index.json"

GenerationConfigDict = dict[str, int | list[int]]


def build_generation_config(
tokenizer: "PreTrainedTokenizerBase",
eos_token_ids: list[int] | None,
) -> GenerationConfigDict | None:
"""Build a validated generation_config dict from explicit EOS token IDs.

The returned dict is suitable for writing as ``generation_config.json``
alongside an HF checkpoint. It tells inference tools like vLLM which
tokens should stop generation (e.g. both ``<|end_of_text|>`` and
``<|eot_id|>`` for chat models).

Normalization guarantees:
- Output ``eos_token_id`` is always sorted and deduplicated.
- The tokenizer's own ``eos_token_id`` is auto-added if not already present.

Args:
tokenizer: The tokenizer that will be saved with the checkpoint.
eos_token_ids: Explicit list of EOS token IDs, or ``None`` to skip.

Returns:
A config dict ready for JSON serialization, or ``None`` if
*eos_token_ids* is ``None``.

Raises:
ValueError: If the list is empty, contains non-ints, or contains
IDs outside the tokenizer's vocabulary range.
"""
if eos_token_ids is None:
return None

if not eos_token_ids:
raise ValueError("hf_generation_eos_token_ids must be non-empty when set")

vocab_size = len(tokenizer)
for tid in eos_token_ids:
if not isinstance(tid, int):
raise ValueError(f"hf_generation_eos_token_ids contains non-int: {tid!r}")
if not (0 <= tid < vocab_size):
raise ValueError(f"Token ID {tid} out of range [0, {vocab_size})")

ids = set(eos_token_ids)

tok_eos = tokenizer.eos_token_id
if tok_eos is None:
logger.warning("Tokenizer has no eos_token_id; generation config will use only the provided IDs")
elif tok_eos not in ids:
logger.info("Auto-adding tokenizer eos_token_id=%d to generation config", tok_eos)
ids.add(tok_eos)

gen_config: GenerationConfigDict = {"eos_token_id": sorted(ids)}
if tokenizer.bos_token_id is not None:
gen_config["bos_token_id"] = tokenizer.bos_token_id
return gen_config


@dataclass(frozen=True)
class RepoRef:
@@ -881,6 +938,7 @@ def save_pretrained(
max_shard_size: int = DEFAULT_MAX_SHARD_SIZE,
save_feature_extractor: bool = False,
dtype: Optional[jnp.dtype] = None,
generation_config: Optional[GenerationConfigDict] = None,
**hf_upload_kwargs,
):
"""
@@ -1055,6 +1113,13 @@ def _list_relative_files(directory: str) -> set[str]:
with open(os.path.join(local_path, "config.json"), "w") as f:
json.dump(dict_config, f, cls=ConfigJSONEncoder)

if generation_config is not None:
logger.info(
"Writing generation_config.json with eos_token_id=%s", generation_config.get("eos_token_id")
)
with open(os.path.join(local_path, "generation_config.json"), "w") as f:
json.dump(generation_config, f)

if index is not None:
with open(os.path.join(local_path, SAFE_TENSORS_INDEX_NAME), "w") as f:
json.dump(index, f)
@@ -1149,6 +1214,7 @@ def save_hf_checkpoint_callback(
converter: HFCheckpointConverter,
upload_to_hf: Union[bool, str, RepoRef] = False,
save_dtype: Optional[jnp.dtype] = None,
generation_config: Optional[GenerationConfigDict] = None,
**hf_upload_kwargs,
):
"""
@@ -1176,6 +1242,7 @@ def cb(step: StepInfo):
os.path.join(base_path, f"step-{step.step}"),
upload_to_hf=upload_to_hf,
dtype=save_dtype,
generation_config=generation_config,
**my_upload_kwargs,
)

6 changes: 5 additions & 1 deletion lib/levanter/src/levanter/main/train_dpo.py
@@ -18,7 +18,7 @@
import levanter.callbacks
from levanter import callbacks
from levanter.checkpoint import load_checkpoint
from levanter.compat.hf_checkpoints import HFCompatConfig
from levanter.compat.hf_checkpoints import HFCompatConfig, build_generation_config
from levanter.data.dataset import AsyncDataset
from levanter.data.mixture import MixtureDataset
from levanter.data.text import (
@@ -231,6 +231,7 @@ class TrainDpoConfig:
hf_upload: Optional[str] = None
hf_save_steps: int = 10000
hf_save_dtype: Optional[str] = None
hf_generation_eos_token_ids: Optional[list[int]] = None

data_seed: Optional[int] = None
initialize_from_checkpoint_path: Optional[str] = None
@@ -242,6 +243,8 @@ def main(config: TrainDpoConfig):

tokenizer = config.data.the_tokenizer

_generation_config = build_generation_config(tokenizer, config.hf_generation_eos_token_ids)
Review comment (P2): Gate DPO generation-config validation on HF export

build_generation_config(...) is executed at startup even when HF checkpoint export is disabled (hf_save_path is None or hf_save_steps is None), so a malformed hf_generation_eos_token_ids value can fail the entire DPO run even though no generation_config.json will be written. This is an avoidable regression in behavior (and inconsistent with train_lm, which computes this only inside the HF-save block), so the call should be deferred until export is actually enabled.
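
A minimal sketch of the deferral the comment suggests, using the field names the comment itself mentions (the exact guard condition is an assumption, not part of this diff):

```python
# Build (and validate) the generation config only when HF export will actually run,
# mirroring how train_lm computes it inside the HF-save block.
_generation_config = None
if config.hf_save_path is not None and config.hf_save_steps is not None:
    _generation_config = build_generation_config(tokenizer, config.hf_generation_eos_token_ids)
```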



if config.initialize_from_hf:
if config.trainer.initialize_from is not None:
raise ValueError("Cannot specify both initialize_from_hf and initialize_from")
@@ -472,6 +475,7 @@ def save_policy_hf_checkpoint(step):
os.path.join(full_save_path, f"step-{step.step}"),
upload_to_hf=upload_to_hf,
dtype=save_dtype,
generation_config=_generation_config,
**hf_upload_kwargs,
)
