
Commit 180cb79

ahmeda14960 and claude authored
[levanter] Add generation_config.json support for chat model checkpoints (#4160)
## Summary

- Add `hf_generation_eos_token_ids` config field to `SimpleDPOConfig`, `SimpleSFTConfig`, `SimpleTrainConfig`, `TrainDpoConfig`, and `TrainLmConfig`
- When set (e.g. `[128001, 128009]`), write a validated `generation_config.json` alongside HF checkpoints so vLLM stops on the right tokens for chat models
- `config.json` is unchanged — pretraining checkpoints are unaffected
- New shared helper `levanter/utils/hf_export.py` with `build_generation_config()` for validation/normalization
- `LLAMA3_CHAT_STOP_TOKEN_IDS` constant in `experiments/llama.py`

Replaces #4154 (closed). Does **not** modify the tokenizer's `eos_token` or override `eos_token_id` in `config.json`.

Fixes #4153
Fixes #4159

## Test plan

- [x] 14 unit tests in `test_hf_export.py` — validation, dedup, sort, auto-add EOS, error cases
- [x] `./infra/pre-commit.py --all-files --fix` passes
- [x] Pre-commit hooks pass on commit
- [x] Verify `generation_config.json` is written when `hf_generation_eos_token_ids=[128001, 128009]` is set on a DPO run
- [ ] Verify no `generation_config.json` when field is `None` (default)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3cdfeb4 commit 180cb79

13 files changed

Lines changed: 365 additions & 4 deletions
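Before the per-file diffs, a concrete picture of the artifact this commit adds: a minimal sketch of reading the file an HF export gains when the new field is set to the Llama 3 values. The checkpoint path is hypothetical; the expected contents come from the smoke test included below.

```python
import json

# Hypothetical export directory produced with
# hf_generation_eos_token_ids=[128001, 128009] on a Llama 3 model.
with open("hf/step-2/generation_config.json") as f:
    gen = json.load(f)

print(gen)
# {"eos_token_id": [128001, 128009], "bos_token_id": 128000}
```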


docs/tutorials/train-dpo.md

Lines changed: 38 additions & 0 deletions
```diff
@@ -79,6 +79,44 @@ dpo_config = SimpleDPOConfig(
 | `model_name_or_path` | HuggingFace model to initialize the policy from. Also used as the reference model unless `reference_model_path` is set separately. |
 | `reference_model_path` | Path to the reference model. Defaults to `model_name_or_path`. |
 | `validation_split_fraction` | Fraction of training data to hold out for validation (default 0.1). Set to `None` to use a separate validation set. |
+| `hf_generation_eos_token_ids` | List of token IDs to write to `generation_config.json` for inference stop conditions. See below. |
+
+### Setting Generation Stop Tokens
+
+Chat models use a turn-boundary token (e.g. `<|eot_id|>`) to end assistant
+responses, but the tokenizer's `eos_token` is typically the pre-training
+document boundary (`<|end_of_text|>`). Inference tools like vLLM need both
+tokens as stop conditions.
+
+Set `hf_generation_eos_token_ids` to write a `generation_config.json` alongside
+each saved checkpoint. The tokenizer's `eos_token_id` is auto-added if not
+already in the list.
+
+For Llama 3 models, use the predefined constant:
+
+```python
+from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS
+
+dpo_config = SimpleDPOConfig(
+    ...
+    hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS,  # [128001, 128009]
+)
+```
+
+For other model families, determine the correct stop token by applying the
+chat template and checking the last token of the assistant turn:
+
+```python
+tokens = tokenizer.apply_chat_template(
+    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
+    tokenize=True,
+)
+print(f"Chat stop token: {tokens[-1]}")  # e.g. 128009 for <|eot_id|>
+print(f"Tokenizer EOS: {tokenizer.eos_token_id}")  # e.g. 128001 for <|end_of_text|>
+```
+
+If the two differ, pass both: `hf_generation_eos_token_ids=[eos_token_id, chat_stop_token]`.
+If they match, you don't need to set this field.
 
 ## Creating the Training Pipeline
```
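The probe shown in the tutorial text above also runs standalone. A minimal sketch, assuming `transformers` is installed and a Llama 3 tokenizer is accessible (exact IDs can vary across tokenizer revisions):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
tokens = tok.apply_chat_template(
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
    tokenize=True,
)
print(f"Chat stop token: {tokens[-1]}")      # e.g. 128009 (<|eot_id|>)
print(f"Tokenizer EOS: {tok.eos_token_id}")  # e.g. 128001 (<|end_of_text|>)
```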

experiments/defaults.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -467,6 +467,7 @@ def default_train(
             )
         ),
         hf_save_steps=steps_per_export_hf,
+        hf_generation_eos_token_ids=train_config.hf_generation_eos_token_ids,
         data_seed=train_config.data_seed,
         eval_harness_steps=train_config.steps_per_task_eval or 10000,
         eval_harness=harness_config,
@@ -557,6 +558,7 @@ def default_sft(
         beta2=sft_config.beta2,
         pad_tokenizer_to_match_model=sft_config.pad_tokenizer_to_match_model,
         per_device_parallelism=sft_config.per_device_parallelism,
+        hf_generation_eos_token_ids=sft_config.hf_generation_eos_token_ids,
     )
 
     if sft_config.reinit_tokens:
@@ -672,6 +674,7 @@ def default_dpo(
         validation_split_fraction=dpo_config.validation_split_fraction,
         hf_save_steps=steps_per_export_hf,
         hf_save_dtype=dpo_config.hf_save_dtype,
+        hf_generation_eos_token_ids=dpo_config.hf_generation_eos_token_ids,
         data_seed=dpo_config.seed,
     )
```

experiments/dpo_ultrafeedback.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -8,7 +8,7 @@
 from levanter.data.text import PreferenceChatLmDatasetFormat
 
 from experiments.defaults import default_dpo, default_tokenize
-from experiments.llama import llama_8b
+from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS, llama_8b
 from experiments.marin_models import marin_tokenizer
 from experiments.models import llama_3_1_8b
 from experiments.posttrain.preference_datasets import get_preference_dataset
@@ -61,6 +61,7 @@
     steps_per_eval=200,
     steps_per_checkpoint=1000,
     steps_per_hf_export=1000,
+    hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS,
     seed=0,
 )
```

experiments/llama.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -18,6 +18,14 @@
 llama3_tokenizer_vocab_size = 128_256
 llama3_instruct_tokenizer = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
+# Llama 3 chat stop token IDs for generation_config.json.
+# The chat template ends every turn (user, assistant, system) with <|eot_id|> (128009),
+# but the tokenizer's eos_token is <|end_of_text|> (128001), which is the pre-training
+# document boundary. Both must be listed as stop tokens so vLLM stops on either.
+# Determined by running: tokenizer.apply_chat_template([...], tokenize=True)
+# and observing the last token of the assistant turn is 128009.
+LLAMA3_CHAT_STOP_TOKEN_IDS = [128001, 128009]
+
 # Llama3 instruct trainable chat template
 # Slight modification of https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/tokenizer_config.json
 # to add {% generation %} so we can create the assistant_mask
```
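As a cross-check on the constant above, the two IDs can be decoded back to their token strings. A sketch assuming the same Llama 3 instruct tokenizer named in this file:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
# Expect the document boundary and the chat turn boundary, in that order.
print(tok.convert_ids_to_tokens([128001, 128009]))
# ['<|end_of_text|>', '<|eot_id|>']
```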

experiments/simple_dpo_config.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -43,6 +43,9 @@ class SimpleDPOConfig:
     steps_per_checkpoint: int = 1000
     steps_per_hf_export: int = 500
     hf_save_dtype: str | None = None
+    hf_generation_eos_token_ids: list[int] | None = None
+    """EOS token IDs to write to generation_config.json. None means no generation config.
+    For chat models, include the turn-boundary token (e.g. [128001, 128009])."""
 
     seed: int = 0
     initialize_from_hf: bool | None = None
```

experiments/simple_sft_config.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -123,6 +123,10 @@ class SimpleSFTConfig:
     steps_per_hf_export: int = 500
     """How often to save HuggingFace checkpoints."""
 
+    hf_generation_eos_token_ids: list[int] | None = None
+    """EOS token IDs to write to generation_config.json. None means no generation config.
+    For chat models, include the turn-boundary token (e.g. [128001, 128009])."""
+
     # Mixture-specific parameters
     mixture_block_size: int = 2048
     """Block size for dataset mixing (only used with mixture training)."""
```

experiments/simple_train_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -50,6 +50,8 @@ class SimpleTrainConfig:
     """how often to run task evaluations"""
     steps_per_hf_export: int | None = None
     """None means match steps_per_export, -1 disables"""
+    hf_generation_eos_token_ids: list[int] | None = None
+    """EOS token IDs to write to generation_config.json. None means no generation config."""
     per_device_parallelism: int = -1
     """How many examples to process in parallel on each device. -1 (default) means
     train_batch_size/num_devices (no gradient accumulation). Set to a positive value
```
Lines changed: 89 additions & 0 deletions
```diff
@@ -0,0 +1,89 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Smoke test: verify that hf_generation_eos_token_ids writes generation_config.json
+in DPO checkpoint saves. Uses marin-8b-instruct on v5p-8 with 2 training steps.
+
+After the run completes, check the HF checkpoint output for generation_config.json:
+    gcloud storage cat gs://marin-us-east5/checkpoints/dpo/test_generation_config_smoke-<hash>/hf/step-2/generation_config.json
+
+Expected content: {"eos_token_id": [128001, 128009], "bos_token_id": 128000}
+"""
+
+from levanter.data.text import PreferenceChatLmDatasetFormat
+
+from experiments.defaults import default_dpo, default_tokenize
+from experiments.llama import LLAMA3_CHAT_STOP_TOKEN_IDS, llama_8b
+from experiments.marin_models import marin_tokenizer
+from experiments.posttrain.preference_datasets import get_preference_dataset
+from experiments.simple_dpo_config import SimpleDPOConfig
+from fray.cluster import ResourceConfig
+from marin.execution.executor import executor_main
+from marin.processing.tokenize import lm_data_config
+
+DATASET_NAME = "HuggingFaceH4/ultrafeedback_binarized"
+
+preference_dataset = get_preference_dataset(DATASET_NAME, splits=["train_prefs", "test_prefs"])
+
+tokenized_train = default_tokenize(
+    name="ultrafeedback_binarized_train_prefs_marin_tokenizer",
+    dataset=preference_dataset / "train_prefs/*.jsonl.gz",
+    tokenizer=marin_tokenizer,
+    format=PreferenceChatLmDatasetFormat(),
+)
+
+tokenized_val = default_tokenize(
+    name="ultrafeedback_binarized_test_prefs_marin_tokenizer",
+    dataset=preference_dataset / "test_prefs/*.jsonl.gz",
+    tokenizer=marin_tokenizer,
+    format=PreferenceChatLmDatasetFormat(),
+    is_validation=True,
+)
+
+tokenized_preferences = lm_data_config(
+    training_set=tokenized_train,
+    validation_sets={"ultrafeedback_test_prefs": tokenized_val},
+)
+
+dpo_config = SimpleDPOConfig(
+    resources=ResourceConfig.with_tpu("v5p-8", ram="400g"),
+    train_batch_size=8,
+    num_train_steps=2,
+    learning_rate=5e-7,
+    lr_schedule="cosine",
+    warmup=0,
+    wandb_project="dpo",
+    tokenizer=marin_tokenizer,
+    model_name_or_path="marin-community/marin-8b-instruct",
+    reference_model_path="marin-community/marin-8b-instruct",
+    reference_is_hf=True,
+    train_seq_len=4096,
+    max_seq_len=4096,
+    beta=0.1,
+    validation_split_fraction=None,
+    steps_per_eval=2,
+    steps_per_checkpoint=2,
+    steps_per_hf_export=2,
+    hf_generation_eos_token_ids=LLAMA3_CHAT_STOP_TOKEN_IDS,
+    seed=0,
+)
+
+training_step = default_dpo(
+    name="dpo/test_generation_config_smoke",
+    tokenized=tokenized_preferences,
+    model_config=llama_8b,
+    dpo_config=dpo_config,
+    tags=["dpo", "smoke-test", "generation-config"],
+)
+
+
+if __name__ == "__main__":
+    executor_main(
+        steps=[
+            preference_dataset,
+            tokenized_train,
+            tokenized_val,
+            training_step,
+        ]
+    )
```
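The docstring's gcloud check can also be scripted. A hypothetical follow-up, not part of the commit, assuming `fsspec` with `gcsfs` is available and `<hash>` is filled in from the run output:

```python
import json

import fsspec

# Replace <hash> with the actual run hash printed by the executor.
path = (
    "gs://marin-us-east5/checkpoints/dpo/"
    "test_generation_config_smoke-<hash>/hf/step-2/generation_config.json"
)
with fsspec.open(path) as f:
    gen = json.load(f)

assert gen["eos_token_id"] == [128001, 128009]
assert gen["bos_token_id"] == 128000
print("generation_config.json looks correct:", gen)
```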

lib/levanter/docs/guides/DPO-Training.md

Lines changed: 29 additions & 0 deletions
```diff
@@ -83,6 +83,35 @@ validation_split_fraction: 0.1 # Auto-split from training data; null to disable
 training data for validation. Set to `null` to use separately configured
 validation sets.
 
+### Generation Stop Tokens
+
+Chat models typically use a turn-boundary token (e.g. `<|eot_id|>`, ID 128009)
+to end assistant responses, while the tokenizer's `eos_token` remains the
+pre-training document boundary (e.g. `<|end_of_text|>`, ID 128001). Inference
+tools like vLLM use `eos_token_id` from `config.json` to decide when to stop,
+so they will miss the chat stop token unless told otherwise.
+
+Set `hf_generation_eos_token_ids` to write a `generation_config.json` alongside
+each HF checkpoint with the correct stop tokens:
+
+```yaml
+hf_generation_eos_token_ids: [128001, 128009]  # <|end_of_text|> + <|eot_id|>
+```
+
+The tokenizer's `eos_token_id` is auto-added if not in the list. This field
+defaults to `null` (no `generation_config.json` written), preserving backward
+compatibility with pretraining checkpoints.
+
+To determine the right stop token for your model's chat template:
+
+```python
+tokens = tokenizer.apply_chat_template(
+    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
+    tokenize=True,
+)
+# The last token is the chat stop token (e.g. 128009 for Llama 3)
+```
+
 ## Running
 
 ```bash
```
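For comparison, the same stop behavior can be forced per-request at inference time instead of via the checkpoint. A hedged sketch of the vLLM side (model path hypothetical; check your vLLM version's docs for whether `generation_config.json` is picked up automatically):

```python
from vllm import LLM, SamplingParams

# Point vLLM at an exported HF checkpoint directory (hypothetical path).
llm = LLM(model="/path/to/exported/hf/step-1000")

# Explicit stop_token_ids mirror the exported generation_config.json.
params = SamplingParams(max_tokens=256, stop_token_ids=[128001, 128009])
outputs = llm.generate(["Hello!"], params)
print(outputs[0].outputs[0].text)
```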

lib/levanter/src/levanter/compat/hf_checkpoints.py

Lines changed: 67 additions & 0 deletions
```diff
@@ -117,6 +117,63 @@ def _is_hf_model_id(path: str) -> bool:
 PYTORCH_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 SAFE_TENSORS_INDEX_NAME = "model.safetensors.index.json"
 
+GenerationConfigDict = dict[str, int | list[int]]
+
+
+def build_generation_config(
+    tokenizer: "PreTrainedTokenizerBase",
+    eos_token_ids: list[int] | None,
+) -> GenerationConfigDict | None:
+    """Build a validated generation_config dict from explicit EOS token IDs.
+
+    The returned dict is suitable for writing as ``generation_config.json``
+    alongside an HF checkpoint. It tells inference tools like vLLM which
+    tokens should stop generation (e.g. both ``<|end_of_text|>`` and
+    ``<|eot_id|>`` for chat models).
+
+    Normalization guarantees:
+    - Output ``eos_token_id`` is always sorted and deduplicated.
+    - The tokenizer's own ``eos_token_id`` is auto-added if not already present.
+
+    Args:
+        tokenizer: The tokenizer that will be saved with the checkpoint.
+        eos_token_ids: Explicit list of EOS token IDs, or ``None`` to skip.
+
+    Returns:
+        A config dict ready for JSON serialization, or ``None`` if
+        *eos_token_ids* is ``None``.
+
+    Raises:
+        ValueError: If the list is empty, contains non-ints, or contains
+            IDs outside the tokenizer's vocabulary range.
+    """
+    if eos_token_ids is None:
+        return None
+
+    if not eos_token_ids:
+        raise ValueError("hf_generation_eos_token_ids must be non-empty when set")
+
+    vocab_size = len(tokenizer)
+    for tid in eos_token_ids:
+        if not isinstance(tid, int):
+            raise ValueError(f"hf_generation_eos_token_ids contains non-int: {tid!r}")
+        if not (0 <= tid < vocab_size):
+            raise ValueError(f"Token ID {tid} out of range [0, {vocab_size})")
+
+    ids = set(eos_token_ids)
+
+    tok_eos = tokenizer.eos_token_id
+    if tok_eos is None:
+        logger.warning("Tokenizer has no eos_token_id; generation config will use only the provided IDs")
+    elif tok_eos not in ids:
+        logger.info("Auto-adding tokenizer eos_token_id=%d to generation config", tok_eos)
+        ids.add(tok_eos)
+
+    gen_config: GenerationConfigDict = {"eos_token_id": sorted(ids)}
+    if tokenizer.bos_token_id is not None:
+        gen_config["bos_token_id"] = tokenizer.bos_token_id
+    return gen_config
+
 
 @dataclass(frozen=True)
 class RepoRef:
@@ -881,6 +938,7 @@ def save_pretrained(
     max_shard_size: int = DEFAULT_MAX_SHARD_SIZE,
     save_feature_extractor: bool = False,
     dtype: Optional[jnp.dtype] = None,
+    generation_config: Optional[GenerationConfigDict] = None,
     **hf_upload_kwargs,
 ):
     """
@@ -1055,6 +1113,13 @@ def _list_relative_files(directory: str) -> set[str]:
     with open(os.path.join(local_path, "config.json"), "w") as f:
         json.dump(dict_config, f, cls=ConfigJSONEncoder)
 
+    if generation_config is not None:
+        logger.info(
+            "Writing generation_config.json with eos_token_id=%s", generation_config.get("eos_token_id")
+        )
+        with open(os.path.join(local_path, "generation_config.json"), "w") as f:
+            json.dump(generation_config, f)
+
     if index is not None:
         with open(os.path.join(local_path, SAFE_TENSORS_INDEX_NAME), "w") as f:
             json.dump(index, f)
@@ -1149,6 +1214,7 @@ def save_hf_checkpoint_callback(
     converter: HFCheckpointConverter,
     upload_to_hf: Union[bool, str, RepoRef] = False,
     save_dtype: Optional[jnp.dtype] = None,
+    generation_config: Optional[GenerationConfigDict] = None,
    **hf_upload_kwargs,
 ):
     """
@@ -1176,6 +1242,7 @@ def cb(step: StepInfo):
         os.path.join(base_path, f"step-{step.step}"),
         upload_to_hf=upload_to_hf,
         dtype=save_dtype,
+        generation_config=generation_config,
         **my_upload_kwargs,
     )
```
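To see the helper's normalization rules in action, a usage sketch against `build_generation_config` as defined in the diff above. The import path follows this file (the PR summary instead names `levanter/utils/hf_export.py`, so adjust if the helper lives there); the tokenizer repo and the printed values are illustrative:

```python
from transformers import AutoTokenizer

from levanter.compat.hf_checkpoints import build_generation_config

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# Duplicates are dropped, IDs are sorted, and the tokenizer's own EOS is
# auto-added when it is missing from the list.
print(build_generation_config(tok, [128009, 128009]))
# e.g. {'eos_token_id': [128001, 128009], 'bos_token_id': 128000}

# None disables generation_config.json entirely (the pretraining default).
assert build_generation_config(tok, None) is None

# An empty list or an out-of-range ID raises ValueError.
```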
