4 changes: 2 additions & 2 deletions lib/levanter/src/levanter/eval_harness.py
@@ -42,7 +42,7 @@
 from jax.sharding import PartitionSpec

 import levanter.tracker
-from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
+from levanter.compat.hf_checkpoints import HFCheckpointConverter
 from levanter.data.packing import (
     PromptCompletion,
     greedy_pack_prompt_completions,
@@ -56,7 +56,7 @@
 from levanter.models.gpt2 import Gpt2Config
 from levanter.models.loss import fused_cross_entropy_loss_and_logsumexp_penalty
 from levanter.utils.background_iterable import BackgroundIterator
-from levanter.tokenizers import MarinTokenizer
+from levanter.tokenizers import MarinTokenizer, load_tokenizer
P1: Preserve mutable tokenizer for pad token fallback

Importing load_tokenizer from levanter.tokenizers changes EvalHarnessMainConfig.the_tokenizer to return an HfMarinTokenizer, a frozen dataclass with no pad_token_id setter. In this same module, both loglikelihood and generate_until fall back to self.tokenizer.pad_token_id = self.tokenizer.eos_token_id when the pad token is missing, so models whose tokenizer has no pad token (common for Llama-family checkpoints) will now raise at runtime instead of evaluating. The previous loader from compat.hf_checkpoints returned a mutable HF tokenizer, so this regression is introduced by the import swap.
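Below is a minimal sketch of the failure mode; FrozenTokenizer is a hypothetical stand-in for HfMarinTokenizer (the frozen-dataclass behavior and field names are assumptions drawn from the description above), and the assignment mirrors the fallback in loglikelihood()/generate_until():

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class FrozenTokenizer:
    # Hypothetical stand-in for HfMarinTokenizer; field names are assumptions.
    eos_token_id: int = 2
    pad_token_id: int | None = None

tok = FrozenTokenizer()
try:
    # The eval paths fall back like this when no pad token is configured.
    tok.pad_token_id = tok.eos_token_id
except dataclasses.FrozenInstanceError as err:
    print(f"pad-token fallback raises: {err!r}")
```

If the import swap stays, one option is to resolve the pad token without mutation, e.g. dataclasses.replace(tok, pad_token_id=tok.eos_token_id), which works on frozen dataclasses.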

P1: Keep tokenizer callable in generation paths

Switching load_tokenizer to levanter.tokenizers.load_tokenizer now returns HfMarinTokenizer, but generate_until() still relies on HF-style call semantics via tok_encode(), which invokes self.tokenizer(...) in eval_harness.py. HfMarinTokenizer does not implement __call__, so any lm-eval task that reaches generate_until will now fail at runtime with a TypeError instead of generating outputs. The previous loader returned a callable HF tokenizer, so this regression, too, is introduced by the import swap.
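A minimal sketch of the call-path failure; MarinStyleTokenizer is a hypothetical stand-in that exposes encode() but, as noted above, no __call__:

```python
class MarinStyleTokenizer:
    # Hypothetical stand-in: an explicit encode() exists, but HF-style
    # invocation like self.tokenizer(text) has nothing to dispatch to.
    def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]

tok = MarinStyleTokenizer()
tok.encode("hello")  # works: the explicit method is fine
try:
    # Mirrors tok_encode(), which calls self.tokenizer(...) directly.
    tok("hello")
except TypeError as err:
    print(f"generate_until path raises: {err!r}")
```

Restoring the HF loader on this path, or giving the Marin tokenizer a thin __call__ shim, would both avoid the TypeError; either way generate_until needs a callable tokenizer.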


 from levanter.utils.py_utils import set_global_rng_seeds

 try:
13 changes: 13 additions & 0 deletions lib/levanter/tests/test_eval_harness.py
@@ -183,3 +183,16 @@ def test_task_config():
     q = config.to_task_dict()

     assert len(q) == 3
+
+
+def test_eval_harness_config_loads_marin_tokenizer():
+    """Verify EvalHarnessMainConfig.the_tokenizer returns a MarinTokenizer."""
+    from levanter.eval_harness import EvalHarnessMainConfig, LmEvalHarnessConfig
+    from levanter.tokenizers import MarinTokenizer
+
+    config = EvalHarnessMainConfig(
+        eval_harness=LmEvalHarnessConfig(task_spec=["hellaswag"]),
+        tokenizer="stanford-crfm/marin-tokenizer",
+        checkpoint_path="/nonexistent",
+    )
+    assert isinstance(config.the_tokenizer, MarinTokenizer)