Skip to content

Commit 66d9500

Browse files
authored
[levanter] Add MarinTokenizer.as_hf_tokenizer(), fix kitoken find-links (#4451)
with_tokenizer_padded_to_match_model() crashed when given a MarinTokenizer because it called add_tokens(), an HF-only API. Add as_hf_tokenizer() to the MarinTokenizer Protocol so any backend can produce an HF tokenizer on demand. Update padding, save_pretrained, and LoRA export to use it. Also add find-links to levanter's pyproject.toml so uv can resolve kitoken>=0.10.2 without the workspace root.
1 parent: 82f335f · commit: 66d9500

5 files changed

Lines changed: 38 additions & 11 deletions

File tree

lib/levanter/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ dev = [
150150

151151

152152
[tool.uv]
153+
find-links = [
154+
"https://github.com/marin-community/kitoken/releases/expanded_assets/kitoken-0.10.2-a3012f4",
155+
]
153156
conflicts = [
154157
[
155158
{ extra = "gpu" },

lib/levanter/src/levanter/compat/hf_checkpoints.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -515,14 +515,15 @@ def with_tokenizer_padded_to_match_model(
515515
f"(adding {num_to_add} dummy tokens) to match model vocab size."
516516
)
517517

518-
# Add dummy tokens to the tokenizer
518+
# Add dummy tokens to the tokenizer. MarinTokenizer is read-only,
519+
# so we convert to an HF tokenizer which supports add_tokens.
519520
dummy_tokens = [f"<|padding_{i}|>" for i in range(num_to_add)]
520-
self.tokenizer.add_tokens(dummy_tokens)
521+
tokenizer = self.tokenizer
522+
if isinstance(tokenizer, MarinTokenizer):
523+
tokenizer = tokenizer.as_hf_tokenizer()
524+
tokenizer.add_tokens(dummy_tokens)
521525

522-
# Return a new converter with the modified tokenizer
523-
# Note: We modify self.tokenizer in place, but since the Vocab property is cached,
524-
# we need to return a new converter to get a fresh Vocab
525-
return dataclasses.replace(self, tokenizer=self.tokenizer) # type: ignore
526+
return dataclasses.replace(self, tokenizer=tokenizer) # type: ignore
526527

527528
def with_config_overrides(self, config_overrides: dict, merge: bool = True) -> "HFCheckpointConverter":
528529
if self.config_overrides is not None and merge:
@@ -1098,9 +1099,7 @@ def _list_relative_files(directory: str) -> set[str]:
10981099
logger.info("Saving tokenizer")
10991100
tokenizer = self.tokenizer
11001101
if isinstance(tokenizer, MarinTokenizer):
1101-
# MarinTokenizer doesn't have save_pretrained; load the
1102-
# underlying HF tokenizer so we can serialize it.
1103-
tokenizer = load_tokenizer(tokenizer.name_or_path)
1102+
tokenizer = tokenizer.as_hf_tokenizer()
11041103
tokenizer.save_pretrained(local_path)
11051104

11061105
if save_feature_extractor and self.feature_extractor is not None:

lib/levanter/src/levanter/data/passthrough_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,6 @@ def apply_chat_template_with_masks(
9595
**kwargs,
9696
) -> dict[str, list[list[int]]]:
9797
raise ValueError("PassthroughTokenizer does not support chat templates")
98+
99+
def as_hf_tokenizer(self):
100+
raise ValueError("PassthroughTokenizer cannot be converted to an HF tokenizer")

lib/levanter/src/levanter/lora.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777
silence_transformer_nag()
7878
from transformers import PreTrainedTokenizerBase # noqa: E402
7979

80+
from levanter.tokenizers import MarinTokenizer # noqa: E402
81+
8082

8183
logger = logging.getLogger(__name__)
8284

@@ -320,7 +322,7 @@ def save_peft_pretrained(
320322
config: LoraConfig,
321323
base_model_name_or_path,
322324
path: str,
323-
tokenizer: Optional[PreTrainedTokenizerBase] = None,
325+
tokenizer: Optional[PreTrainedTokenizerBase | MarinTokenizer] = None,
324326
*,
325327
prefix: Optional[str] = DEFAULT_DICT_PREFIX,
326328
upload_to: Optional[Union[bool, str, RepoRef]] = None,
@@ -349,6 +351,8 @@ def save_peft_pretrained(
349351
json.dump(hf_config, f)
350352

351353
if tokenizer is not None:
354+
if isinstance(tokenizer, MarinTokenizer):
355+
tokenizer = tokenizer.as_hf_tokenizer()
352356
tokenizer.save_pretrained(local_path)
353357

354358
if upload_to is True:
@@ -363,7 +367,7 @@ def save_peft_checkpoint_callback(
363367
base_path,
364368
config: LoraConfig,
365369
base_model_name_or_path,
366-
tokenizer: Optional[PreTrainedTokenizerBase] = None,
370+
tokenizer: Optional[PreTrainedTokenizerBase | MarinTokenizer] = None,
367371
upload_to_hf: Optional[Union[bool, str, RepoRef]] = False,
368372
**hf_upload_kwargs,
369373
):

lib/levanter/src/levanter/tokenizers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ def apply_chat_template_with_masks(
8989
**kwargs,
9090
) -> dict[str, list[list[int]]]: ...
9191

92+
def as_hf_tokenizer(self) -> Any:
93+
"""Return a HuggingFace PreTrainedTokenizerFast for this tokenizer.
94+
95+
Useful for operations that require the HF API (save_pretrained,
96+
add_tokens, generation config, etc.).
97+
"""
98+
...
99+
92100

93101
# Sentinel used to mark generation (assistant) boundaries in rendered templates.
94102
_GENERATION_SENTINEL_START = "__MARIN_GEN_START_7f3a9c__"
@@ -315,6 +323,11 @@ def apply_chat_template_with_masks(
315323
) -> dict[str, list[list[int]]]:
316324
return _apply_chat_template_with_masks(self, conversations, chat_template=chat_template, **kwargs)
317325

326+
def as_hf_tokenizer(self) -> Any:
327+
from transformers import AutoTokenizer
328+
329+
return AutoTokenizer.from_pretrained(self._name_or_path, trust_remote_code=True)
330+
318331

319332
@dataclasses.dataclass(frozen=True)
320333
class KitokenMarinTokenizer:
@@ -449,6 +462,11 @@ def apply_chat_template_with_masks(
449462
) -> dict[str, list[list[int]]]:
450463
return _apply_chat_template_with_masks(self, conversations, chat_template=chat_template, **kwargs)
451464

465+
def as_hf_tokenizer(self) -> Any:
466+
from transformers import AutoTokenizer
467+
468+
return AutoTokenizer.from_pretrained(self._name_or_path, trust_remote_code=True)
469+
452470

453471
class TokenizerBackend(StrEnum):
454472
HF = "hf"

Comments (0)