Skip to content

Commit bb19795

Browse files
ahmeda14960 and claude
committed
Add generation_config.json support for chat model HF checkpoints
Chat models need vLLM to stop on <|eot_id|> (128009), but the tokenizer's eos_token is <|end_of_text|> (128001) for pretraining. Add explicit hf_generation_eos_token_ids config field that writes a generation_config.json alongside saved checkpoints with the validated stop token IDs. - New helper module levanter/utils/hf_export.py with build_generation_config() - save_pretrained() and save_hf_checkpoint_callback() accept generation_config - Config field threaded through SimpleDPOConfig, SimpleSFTConfig, SimpleTrainConfig, TrainDpoConfig, TrainLmConfig, and defaults.py - LLAMA3_CHAT_STOP_TOKEN_IDS constant in experiments/llama.py - 14 unit tests for validation and normalization Fixes #4153 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fee874c commit bb19795

10 files changed

Lines changed: 218 additions & 1 deletion

File tree

experiments/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,7 @@ def default_train(
467467
)
468468
),
469469
hf_save_steps=steps_per_export_hf,
470+
hf_generation_eos_token_ids=train_config.hf_generation_eos_token_ids,
470471
data_seed=train_config.data_seed,
471472
eval_harness_steps=train_config.steps_per_task_eval or 10000,
472473
eval_harness=harness_config,
@@ -557,6 +558,7 @@ def default_sft(
557558
beta2=sft_config.beta2,
558559
pad_tokenizer_to_match_model=sft_config.pad_tokenizer_to_match_model,
559560
per_device_parallelism=sft_config.per_device_parallelism,
561+
hf_generation_eos_token_ids=sft_config.hf_generation_eos_token_ids,
560562
)
561563

562564
if sft_config.reinit_tokens:
@@ -672,6 +674,7 @@ def default_dpo(
672674
validation_split_fraction=dpo_config.validation_split_fraction,
673675
hf_save_steps=steps_per_export_hf,
674676
hf_save_dtype=dpo_config.hf_save_dtype,
677+
hf_generation_eos_token_ids=dpo_config.hf_generation_eos_token_ids,
675678
data_seed=dpo_config.seed,
676679
)
677680

experiments/llama.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
llama3_tokenizer_vocab_size = 128_256
1919
llama3_instruct_tokenizer = "meta-llama/Meta-Llama-3.1-8B-Instruct"
2020

21+
# Llama 3 chat stop token IDs for generation_config.json.
22+
# The chat template ends every turn (user, assistant, system) with <|eot_id|> (128009),
23+
# but the tokenizer's eos_token is <|end_of_text|> (128001), which is the pre-training
24+
# document boundary. Both must be listed as stop tokens so vLLM stops on either.
25+
# Determined by running: tokenizer.apply_chat_template([...], tokenize=True)
26+
# and observing the last token of the assistant turn is 128009.
27+
LLAMA3_CHAT_STOP_TOKEN_IDS = [128001, 128009]
28+
2129
# Llama3 instruct trainable chat template
2230
# Slight modification of https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/tokenizer_config.json
2331
# to add {% generation %} so we can create the assistant_mask

experiments/simple_dpo_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class SimpleDPOConfig:
4343
steps_per_checkpoint: int = 1000
4444
steps_per_hf_export: int = 500
4545
hf_save_dtype: str | None = None
46+
hf_generation_eos_token_ids: list[int] | None = None
47+
"""EOS token IDs to write to generation_config.json. None means no generation config.
48+
For chat models, include the turn-boundary token (e.g. [128001, 128009])."""
4649

4750
seed: int = 0
4851
initialize_from_hf: bool | None = None

experiments/simple_sft_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class SimpleSFTConfig:
123123
steps_per_hf_export: int = 500
124124
"""How often to save HuggingFace checkpoints."""
125125

126+
hf_generation_eos_token_ids: list[int] | None = None
127+
"""EOS token IDs to write to generation_config.json. None means no generation config.
128+
For chat models, include the turn-boundary token (e.g. [128001, 128009])."""
129+
126130
# Mixture-specific parameters
127131
mixture_block_size: int = 2048
128132
"""Block size for dataset mixing (only used with mixture training)."""

experiments/simple_train_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class SimpleTrainConfig:
5050
"""how often to run task evaluations"""
5151
steps_per_hf_export: int | None = None
5252
"""None means match steps_per_export, -1 disables"""
53+
hf_generation_eos_token_ids: list[int] | None = None
54+
"""EOS token IDs to write to generation_config.json. None means no generation config."""
5355
per_device_parallelism: int = -1
5456
"""How many examples to process in parallel on each device. -1 (default) means
5557
train_batch_size/num_devices (no gradient accumulation). Set to a positive value

lib/levanter/src/levanter/compat/hf_checkpoints.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from levanter.utils.cloud_utils import temp_dir_before_upload
5757
from levanter.utils.hf_utils import HfTokenizer
5858
from levanter.utils.jax_utils import best_effort_sharding, sync_global_devices, use_cpu_device
59+
from levanter.utils.hf_export import GenerationConfigDict
5960
from levanter.utils.json_utils import ConfigJSONEncoder
6061
from levanter.utils.logging import silence_transformer_nag
6162
from levanter.utils.py_utils import dataclass_with_default_init
@@ -881,6 +882,7 @@ def save_pretrained(
881882
max_shard_size: int = DEFAULT_MAX_SHARD_SIZE,
882883
save_feature_extractor: bool = False,
883884
dtype: Optional[jnp.dtype] = None,
885+
generation_config: Optional[GenerationConfigDict] = None,
884886
**hf_upload_kwargs,
885887
):
886888
"""
@@ -1055,6 +1057,13 @@ def _list_relative_files(directory: str) -> set[str]:
10551057
with open(os.path.join(local_path, "config.json"), "w") as f:
10561058
json.dump(dict_config, f, cls=ConfigJSONEncoder)
10571059

1060+
if generation_config is not None:
1061+
logger.info(
1062+
"Writing generation_config.json with eos_token_id=%s", generation_config.get("eos_token_id")
1063+
)
1064+
with open(os.path.join(local_path, "generation_config.json"), "w") as f:
1065+
json.dump(generation_config, f)
1066+
10581067
if index is not None:
10591068
with open(os.path.join(local_path, SAFE_TENSORS_INDEX_NAME), "w") as f:
10601069
json.dump(index, f)
@@ -1149,6 +1158,7 @@ def save_hf_checkpoint_callback(
11491158
converter: HFCheckpointConverter,
11501159
upload_to_hf: Union[bool, str, RepoRef] = False,
11511160
save_dtype: Optional[jnp.dtype] = None,
1161+
generation_config: Optional[GenerationConfigDict] = None,
11521162
**hf_upload_kwargs,
11531163
):
11541164
"""
@@ -1176,6 +1186,7 @@ def cb(step: StepInfo):
11761186
os.path.join(base_path, f"step-{step.step}"),
11771187
upload_to_hf=upload_to_hf,
11781188
dtype=save_dtype,
1189+
generation_config=generation_config,
11791190
**my_upload_kwargs,
11801191
)
11811192

lib/levanter/src/levanter/main/train_dpo.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from levanter.metrics import Metric, ReductionType
3333
from levanter.optim import AdamConfig, OptimizerConfig
3434
from levanter.trainer import Trainer, TrainerConfig
35+
from levanter.utils.hf_export import build_generation_config
3536
from levanter.utils.jax_utils import parameter_count, use_cpu_device
3637
from levanter.utils.tree_utils import inference_mode
3738

@@ -231,6 +232,7 @@ class TrainDpoConfig:
231232
hf_upload: Optional[str] = None
232233
hf_save_steps: int = 10000
233234
hf_save_dtype: Optional[str] = None
235+
hf_generation_eos_token_ids: Optional[list[int]] = None
234236

235237
data_seed: Optional[int] = None
236238
initialize_from_checkpoint_path: Optional[str] = None
@@ -242,6 +244,8 @@ def main(config: TrainDpoConfig):
242244

243245
tokenizer = config.data.the_tokenizer
244246

247+
_generation_config = build_generation_config(tokenizer, config.hf_generation_eos_token_ids)
248+
245249
if config.initialize_from_hf:
246250
if config.trainer.initialize_from is not None:
247251
raise ValueError("Cannot specify both initialize_from_hf and initialize_from")
@@ -472,6 +476,7 @@ def save_policy_hf_checkpoint(step):
472476
os.path.join(full_save_path, f"step-{step.step}"),
473477
upload_to_hf=upload_to_hf,
474478
dtype=save_dtype,
479+
generation_config=_generation_config,
475480
**hf_upload_kwargs,
476481
)
477482

lib/levanter/src/levanter/main/train_lm.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
3030
from levanter.optim import AdamConfig, OptimizerConfig
3131
from levanter.trainer import Trainer, TrainerConfig
32+
from levanter.utils.hf_export import build_generation_config
3233
from levanter.utils.jax_utils import parameter_count
3334

3435
logger = logging.getLogger(__name__)
@@ -60,6 +61,7 @@ class TrainLmConfig:
6061
hf_upload: Optional[str] = None
6162
hf_save_steps: int = 10000
6263
hf_save_dtype: Optional[str] = None
64+
hf_generation_eos_token_ids: Optional[list[int]] = None
6365

6466
data_seed: Optional[int] = None # if provided, will override the data seed from the trainer
6567
initialize_from_checkpoint_path: Optional[str] = None
@@ -264,9 +266,15 @@ def log_mixture_weights(step_info):
264266
except TypeError:
265267
logger.warning(f"Invalid hf_save_dtype: {config.hf_save_dtype}. Defaulting to None.")
266268

269+
_generation_config = build_generation_config(tokenizer, config.hf_generation_eos_token_ids)
270+
267271
trainer.add_hook(
268272
save_hf_checkpoint_callback(
269-
full_save_path, converter, upload_to_hf=config.hf_upload or False, save_dtype=save_dtype
273+
full_save_path,
274+
converter,
275+
upload_to_hf=config.hf_upload or False,
276+
save_dtype=save_dtype,
277+
generation_config=_generation_config,
270278
),
271279
every=config.hf_save_steps,
272280
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright The Levanter Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Helpers for exporting HuggingFace-compatible checkpoints."""
5+
6+
import logging
7+
8+
from transformers import PreTrainedTokenizerBase
9+
10+
logger = logging.getLogger(__name__)
11+
12+
GenerationConfigDict = dict[str, int | list[int]]
13+
14+
15+
def build_generation_config(
    tokenizer: PreTrainedTokenizerBase,
    eos_token_ids: list[int] | None,
) -> GenerationConfigDict | None:
    """Build a validated generation_config dict from explicit EOS token IDs.

    The returned dict is suitable for writing as ``generation_config.json``
    alongside an HF checkpoint. It tells inference tools like vLLM which
    tokens should stop generation (e.g. both ``<|end_of_text|>`` and
    ``<|eot_id|>`` for chat models).

    Normalization guarantees:
    - Output ``eos_token_id`` is always sorted and deduplicated.
    - The tokenizer's own ``eos_token_id`` is auto-added if not already present.

    Args:
        tokenizer: The tokenizer that will be saved with the checkpoint.
        eos_token_ids: Explicit list of EOS token IDs, or ``None`` to skip.

    Returns:
        A config dict ready for JSON serialization, or ``None`` if
        *eos_token_ids* is ``None``.

    Raises:
        ValueError: If the list is empty, contains non-ints (including bools),
            or contains IDs outside the tokenizer's vocabulary range.
    """
    if eos_token_ids is None:
        return None

    if not eos_token_ids:
        raise ValueError("hf_generation_eos_token_ids must be non-empty when set")

    vocab_size = len(tokenizer)
    for tid in eos_token_ids:
        # bool is a subclass of int, so a YAML `true` would otherwise pass
        # validation and silently become stop-token ID 1. Reject it explicitly
        # along with all other non-int values.
        if isinstance(tid, bool) or not isinstance(tid, int):
            raise ValueError(f"hf_generation_eos_token_ids contains non-int: {tid!r}")
        if not (0 <= tid < vocab_size):
            raise ValueError(f"Token ID {tid} out of range [0, {vocab_size})")

    ids = set(eos_token_ids)

    tok_eos = tokenizer.eos_token_id
    if tok_eos is None:
        logger.warning("Tokenizer has no eos_token_id; generation config will use only the provided IDs")
    elif tok_eos not in ids:
        # The pretraining EOS must remain a stop token even when the config
        # only lists the chat turn-boundary token.
        logger.info("Auto-adding tokenizer eos_token_id=%d to generation config", tok_eos)
        ids.add(tok_eos)

    # Sorted output keeps the written JSON deterministic across runs.
    gen_config: GenerationConfigDict = {"eos_token_id": sorted(ids)}
    if tokenizer.bos_token_id is not None:
        gen_config["bos_token_id"] = tokenizer.bos_token_id
    return gen_config
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright The Levanter Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for levanter.utils.hf_export — generation config validation and normalization."""
5+
6+
import pytest
7+
8+
from levanter.utils.hf_export import build_generation_config
9+
10+
11+
class _FakeTokenizer:
12+
"""Minimal tokenizer stub for testing build_generation_config."""
13+
14+
def __init__(self, vocab_size: int = 200, eos_token_id: int | None = 2, bos_token_id: int | None = 1):
15+
self._vocab_size = vocab_size
16+
self.eos_token_id = eos_token_id
17+
self.bos_token_id = bos_token_id
18+
19+
def __len__(self):
20+
return self._vocab_size
21+
22+
def convert_ids_to_tokens(self, tid: int) -> str | None:
23+
if 0 <= tid < self._vocab_size:
24+
return f"<tok_{tid}>"
25+
return None
26+
27+
28+
class TestBuildGenerationConfig:
    """Validation and normalization behavior of build_generation_config."""

    def test_none_returns_none(self):
        # None means "do not write a generation_config.json" at all.
        assert build_generation_config(_FakeTokenizer(), None) is None

    def test_valid_ids(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2, bos_token_id=1), [50])
        assert cfg is not None
        assert cfg["eos_token_id"] == [2, 50]
        assert cfg["bos_token_id"] == 1

    def test_deduplication(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2), [50, 50, 2])
        assert cfg is not None
        assert cfg["eos_token_id"] == [2, 50]

    def test_sorted_output(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2), [100, 50, 75])
        assert cfg is not None
        assert cfg["eos_token_id"] == [2, 50, 75, 100]

    def test_deterministic(self):
        # Same ID multiset in any order must produce an identical config.
        stub = _FakeTokenizer(vocab_size=200, eos_token_id=2)
        assert build_generation_config(stub, [128, 64, 128]) == build_generation_config(stub, [64, 128, 64])

    def test_auto_adds_tokenizer_eos(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2), [50])
        assert cfg is not None
        assert 2 in cfg["eos_token_id"]

    def test_eos_already_included(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2), [2, 50])
        assert cfg is not None
        assert cfg["eos_token_id"] == [2, 50]

    def test_tokenizer_eos_none(self):
        # A tokenizer with no eos_token_id contributes nothing extra.
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=None), [50])
        assert cfg is not None
        assert cfg["eos_token_id"] == [50]

    def test_bos_included_when_present(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2, bos_token_id=1), [50])
        assert cfg is not None
        assert cfg["bos_token_id"] == 1

    def test_bos_omitted_when_none(self):
        cfg = build_generation_config(_FakeTokenizer(vocab_size=200, eos_token_id=2, bos_token_id=None), [50])
        assert cfg is not None
        assert "bos_token_id" not in cfg

    def test_empty_list_raises(self):
        with pytest.raises(ValueError, match="non-empty"):
            build_generation_config(_FakeTokenizer(), [])

    def test_non_int_raises(self):
        with pytest.raises(ValueError, match="non-int"):
            build_generation_config(_FakeTokenizer(), [1, "two"])  # type: ignore[list-item]

    def test_out_of_range_raises(self):
        with pytest.raises(ValueError, match="out of range"):
            build_generation_config(_FakeTokenizer(vocab_size=100), [999])

    def test_negative_id_raises(self):
        with pytest.raises(ValueError, match="out of range"):
            build_generation_config(_FakeTokenizer(vocab_size=100), [-1])

0 commit comments

Comments
 (0)