202 changes: 202 additions & 0 deletions .agents/projects/model-perplexity-gap-finder.md
@@ -0,0 +1,202 @@
# Model Perplexity Gap Finder

## Problem

Levanter's current analysis path compares models only on data that has been
tokenized with a single shared tokenizer. The existing compare-viz entrypoint in
[`lib/levanter/src/levanter/main/viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L34)
loads one tokenizer from `config.data.the_tokenizer` and uses one `LmConfig` for
both models
([`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L54),
[`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L123)).
That is fine for "same tokenizer, two checkpoints", but it cannot answer
"where is Marin worse than Llama 3.1?" once the models use different tokenizers.

Levanter already has the right aggregation idea for corpus slices: tagged eval
datasets with hierarchical rollups and per-tag `bpb`
([`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L199),
[`eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L387)).
Marin already defaults validation to Paloma plus uncheatable eval, but only in a
tokenizer-specific cached form
([`experiments/defaults.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/experiments/defaults.py#L297)).

For this feature we want a different path:

- take raw text corpora in the usual InputName-driven Marin style
- tokenize on the fly for each model separately
- compare models on a tokenizer-independent unit
- report both dataset-level gaps and byte-pattern / word-part gaps

No backward compatibility work is needed. Existing cached tokenization, `eval_lm`,
and `viz_logprobs` behavior should stay unchanged.

## Goals

- Compare two Levanter-loadable LMs, where each model may have its own tokenizer,
its own `LmConfig`, and either an HF or native Levanter checkpoint.
- Score raw text documents directly and normalize results as bits per byte
  (see the helper sketch after this list).
- Attribute loss deltas onto byte spans so reports can surface tokenization-free
"word part" effects such as whitespace runs, punctuation clusters, or short
literal byte spans.
- Reuse Marin's existing raw-dataset conventions and default to raw Paloma plus
raw uncheatable eval.
- Produce a persisted report that is readable without W&B.
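
Bits per byte is the tokenizer-independent unit everything below normalizes to.
A minimal helper sketch, assuming per-document losses are summed in nats:

```python
import math


def bits_per_byte(total_loss_nats: float, num_utf8_bytes: int) -> float:
    """Normalize summed next-token loss to bits per byte.

    Dividing by ln(2) converts nats to bits; dividing by the UTF-8 byte
    count gives a unit that is comparable across tokenizers.
    """
    return total_loss_nats / (math.log(2) * num_utf8_bytes)
```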

Non-goals:

- replacing `LmDataConfig` or the cache-based training/eval path
- supporting non-text dataset formats in v1
- unsupervised topic discovery or clustering
- exact token-to-token alignment across two tokenizers

## Proposed Solution

### Core approach

Introduce a new raw-text analysis path in Levanter that scores both models on the
same raw UTF-8 documents, but tokenizes each document independently per model.
Each model's per-token next-token loss is projected back onto the original
document bytes through tokenizer offset mappings. Once both models live on the
same byte axis, every report becomes an aggregation over byte-attributed losses.

This keeps the core invariant simple:

1. raw document bytes are the shared evaluation unit
2. model A and model B may tokenize differently
3. both models' losses are attributed onto those same bytes
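
A minimal sketch of that invariant. The `score_a` / `score_b` callables are
hypothetical stand-ins for each model's scoring path and are assumed to return
per-token losses in nats plus byte spans for their own tokenization:

```python
import numpy as np


def gap_on_byte_axis(doc: str, score_a, score_b) -> np.ndarray:
    """Score one document with both models and diff the losses per byte.

    Each scorer tokenizes `doc` its own way; the spans may differ between
    models, but both land on the same UTF-8 byte axis.
    """
    gap = np.zeros(len(doc.encode("utf-8")))
    for score, sign in ((score_a, 1.0), (score_b, -1.0)):
        per_token_nats, byte_spans = score(doc)
        for nats, (start, end) in zip(per_token_nats, byte_spans):
            if end > start:  # special tokens cover no bytes
                gap[start:end] += sign * nats / (end - start)
    return gap  # positive entries: model A is worse on those bytes
```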

### Config shape

Levanter gets a dedicated entrypoint and config rather than extending
`VizLmConfig`.

```python
@dataclass
class GapFinderModelConfig:
    checkpoint_path: str
    model: LmConfig | None = None
    checkpoint_is_hf: bool = False
    tokenizer: str | None = None
    tokenizer_backend: TokenizerBackend = TokenizerBackend.HF


@dataclass
class GapFinderConfig:
    model_a: GapFinderModelConfig
    model_b: GapFinderModelConfig
    datasets: dict[str, DatasetComponent]
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    output_path: str = "gap-finder"
    max_eval_length: int = 4096
    max_docs_per_dataset: int | None = 256
```

Marin gets a thin wrapper config that accepts raw datasets, converts them into
`DatasetComponent` values with `UrlDatasetSourceConfig` /
`HfDatasetSourceConfig`, then submits the Levanter job on Iris.
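
A hedged sketch of what the wrapper might produce. The `DatasetComponent` and
`UrlDatasetSourceConfig` constructor fields are assumptions, and the checkpoint
names come from the experiment in the implementation outline:

```python
config = GapFinderConfig(
    model_a=GapFinderModelConfig(
        checkpoint_path="marin-community/marin-8b-base",
        checkpoint_is_hf=True,
    ),
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Meta-Llama-3.1-8B",
        checkpoint_is_hf=True,
    ),
    datasets={
        # illustrative path; real wiring comes from the raw validation helpers
        "paloma/c4_en": DatasetComponent(
            source=UrlDatasetSourceConfig(
                validation_urls=["gs://<bucket>/raw/paloma/c4_en/*.jsonl.gz"],
            ),
        ),
    },
    output_path="gs://<bucket>/analysis/gap-finder/marin-vs-llama31",
)
```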

### Raw scoring loop

The raw path should not go through `LmDataConfig.validation_sets()` because that
method is cache- and tokenizer-oriented
([`lib/levanter/src/levanter/data/text/datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L817),
[`datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L826)).
Instead, the new entrypoint should iterate raw shards via
`DatasetComponent.source.get_shard_source("validation")`, read `text` from
`TextLmDatasetFormat`, tokenize batches on the host, and feed padded arrays into
the model.
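
A host-side iteration sketch under those constraints. `shard_names` and
`open_shard` follow Levanter's `ShardedDataSource` interface; chunking of long
documents is simplified to truncation here:

```python
def iter_raw_docs(component, tokenizer, max_eval_length: int, max_docs: int | None = None):
    """Yield (text, input_ids, char_offsets) for raw validation documents."""
    source = component.source.get_shard_source("validation")
    seen = 0
    for shard_name in source.shard_names:
        for row in source.open_shard(shard_name):
            text = row["text"]
            enc = tokenizer(
                text,
                return_offsets_mapping=True,  # required for byte attribution
                truncation=True,  # real path: chunk up to max_eval_length instead
                max_length=max_eval_length,
            )
            yield text, enc["input_ids"], enc["offset_mapping"]
            seen += 1
            if max_docs is not None and seen >= max_docs:
                return
```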

The forward pass should reuse the standard next-token loss path rather than
custom logits math:

```python
@hax.named_jit(axis_resources=compute_axis_mapping)
def compute_token_losses(model: LmHeadModel, batch: LmExample):
    model = inference_mode(model, True)
    model = mp.cast_to_compute(model)
    per_pos_loss = model.compute_next_token_loss(
        batch,
        reduction=None,
        reduction_axis=(),
    ).array
    target_ids = jnp.roll(batch.tokens.array, -1, axis=-1)
    return per_pos_loss, batch.loss_weight.array, target_ids
```

### Byte attribution

For each raw document:

1. tokenize with offsets using the model's HF tokenizer
2. add BOS/EOS manually when the tokenizer would not insert them itself
3. score padded chunks up to `max_eval_length`
4. shift losses onto target-token spans, mirroring Levanter eval's target-id
handling
5. spread each token's loss uniformly across its covered byte span

Uniform byte spreading is the simplest stable attribution rule. It preserves
document-level `bpb`, avoids token-to-token alignment, and lets us aggregate by
arbitrary byte-derived patterns later.
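
A sketch of the spreading step, applied after the loss shift onto target tokens
in step 4. Note that HF fast tokenizers report *character* offsets, so they
must be converted to byte offsets first; the names here are illustrative:

```python
import numpy as np


def spread_losses_to_bytes(doc: str, token_nats, char_offsets) -> np.ndarray:
    """Spread each target token's loss uniformly over the bytes it covers."""
    # byte position of each character boundary in the UTF-8 encoding
    char_to_byte = np.cumsum([0] + [len(c.encode("utf-8")) for c in doc])
    out = np.zeros(char_to_byte[-1])
    for nats, (c0, c1) in zip(token_nats, char_offsets):
        b0, b1 = char_to_byte[c0], char_to_byte[c1]
        if b1 > b0:  # special tokens have empty spans and are skipped here
            out[b0:b1] += nats / (b1 - b0)
    # out.sum() / (ln 2 * len(out)) recovers the document bpb, up to any
    # loss on skipped special tokens
    return out
```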

### Report structure

The report should include:

- dataset / subcorpus gap table (`model_a_bpb`, `model_b_bpb`, `gap_bpb`)
- hierarchical rollups for names like `paloma/...`
- top documents by positive and negative delta
- pattern-bucket gap table (see the classifier sketch after this list), with
  buckets such as:
- `whitespace/single_space`
- `whitespace/multi_space`
- `whitespace/newline`
- `whitespace/tab_or_cr`
- `text/url`
- `text/number`
- `text/punctuation`
- `text/non_ascii`
- `text/word`
- top literal byte spans / short substrings with the largest deltas
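
A regex-based sketch of the bucket classifier named above; the exact patterns
are assumptions to be tuned:

```python
import re

# First match wins, so more specific buckets precede broader ones.
_BUCKETS: list[tuple[str, re.Pattern[str]]] = [
    ("whitespace/multi_space", re.compile(r" {2,}")),
    ("whitespace/newline", re.compile(r"\n+")),
    ("whitespace/tab_or_cr", re.compile(r"[\t\r]+")),
    ("whitespace/single_space", re.compile(r" ")),
    ("text/url", re.compile(r"https?://\S+")),
    ("text/number", re.compile(r"[0-9][0-9.,]*")),
    ("text/non_ascii", re.compile(r"[^\x00-\x7f]+")),
    ("text/punctuation", re.compile(r"[^\w\s]+", re.ASCII)),
]


def bucket_for(span: str) -> str:
    """Assign a short byte span (decoded as text) to one interpretable bucket."""
    for name, pattern in _BUCKETS:
        if pattern.fullmatch(span):
            return name
    return "text/word"
```

Per-bucket gaps then come from summing byte-attributed deltas over the spans
that land in each bucket.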

Persist both JSON and HTML so downstream scripts can consume the data while
humans can inspect a single rendered report.

## Implementation Outline

1. Add a Levanter raw-text gap finder entrypoint, config types, model-loading
helpers, and HTML/JSON report writer.
2. Add host-side raw text iteration, tokenizer-with-offset handling, and
byte-attributed loss aggregation for text datasets.
3. Add a Marin wrapper plus helpers for raw evaluation components and default raw
Paloma/uncheatable dataset wiring.
4. Add focused tests for byte attribution, bucket aggregation, and a tiny
end-to-end Levanter run.
5. Add an experiment script that compares `marin-community/marin-8b-base` and
`meta-llama/Meta-Llama-3.1-8B` on Iris v5p-8 in `us-central1`.

## Notes

- V1 should explicitly support `TextLmDatasetFormat` only. Chat/template-aware
data can be added later once there is a clear raw-byte contract.
- Existing tagged eval code in
[`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L538)
is still the right model for hierarchical corpus aggregation; the new path just
computes those aggregates from raw byte-attributed records instead of from a
shared-tokenizer dataset.
- The existing `byte_length_of_token()` helper
([`lib/levanter/src/levanter/utils/hf_utils.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/utils/hf_utils.py#L23))
remains useful for sanity checks, but offset-based byte attribution is the
source of truth for mixed-tokenizer comparison.
- `save_logprobs.py`
([`lib/marin/src/marin/evaluation/save_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/marin/src/marin/evaluation/save_logprobs.py#L85))
is a useful reference for how to gather per-token outputs on TPU, but the gap
finder should not serialize full token streams for both models by default.
- The default raw validation helper should mirror the current tokenized helper's
dataset coverage so the new tool can be dropped into existing analysis flows.

## Future Work

- support `ChatLmDatasetFormat` and template-rendered raw comparisons
- add optional W&B artifact logging for the HTML report and summary JSON
- richer byte-pattern discovery beyond the fixed interpretable buckets
- support approximate context-preserving chunk transitions for very long
documents instead of dropping the first-token loss in each chunk
23 changes: 23 additions & 0 deletions docs/tutorials/run-lm-evals.md
@@ -190,3 +190,26 @@ For deeper dives, see:
- `docs/explanations/evaluation.md`
- `experiments/evals/task_configs.py`
- `experiments/evals/evals.py`

## Raw Perplexity Gap Datasets

The raw perplexity-gap workflow uses `default_raw_validation_sets()` from `experiments/defaults.py`. That bundle now includes:

- Paloma
- Uncheatable Eval
- Curated capability-family slices for:
- `chat/wildchat`
- `agent_traces/openhands_swe_rebench`
- `reasoning_icl/gsm8k_main`
- `reasoning_icl/global_mgsm_en`

These capability datasets are first normalized into reusable OpenAI-chat JSONL artifacts under each step's `oai/` output. Consumers that want Levanter chat tokenization can use `capability_chat_validation_components()`, which wraps those rows in `ChatLmDatasetFormat` with `MARIN_CHAT_TEMPLATE`. The raw gap finder still consumes plain `text`, so the same step also writes a derived `raw_text/` projection using Marin's chat-token surface. OpenHands traces keep the full system/user/tool conversation in the OAI artifact, while the raw-text projection scores only assistant-generated trace targets and final patches.

The curated default uses modest, reproducible slices for the larger structured corpora rather than mirroring whole Hugging Face datasets into GCS. That keeps cost and executor output size bounded while still giving useful coverage for base-model PPL comparisons.

If you want the gated chat sources as well, use `extended_raw_validation_sets()` instead of `default_raw_validation_sets()`. That currently adds:

- `chat/lima_train`
- `chat/lmsys_chat_1m`

Those opt-in datasets stay out of the default bundle because access and licensing are more restrictive.
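
A minimal usage sketch, assuming both helpers are importable as described:

```python
from experiments.defaults import default_raw_validation_sets, extended_raw_validation_sets

datasets = default_raw_validation_sets()  # Paloma + Uncheatable Eval + curated capability slices
# datasets = extended_raw_validation_sets()  # additionally pulls in the gated chat sources

for name in sorted(datasets):
    print(name)  # e.g. "uncheatable_eval/...", "chat/wildchat"
```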
65 changes: 40 additions & 25 deletions experiments/chat_templates/llama3pt1_chat_template.py
@@ -59,8 +59,8 @@
{%- set first_user_message = (messages[0]['content'] or '')|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{%- set first_user_message = "" %}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
@@ -74,40 +74,55 @@
{%- endif %}

{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls is not none)) %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or (message.tool_calls is defined and message.tool_calls)) %}
{%- if message.role == 'assistant' %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% generation %}{{- (message['content'] or '') | trim + '<|eot_id|>' }}{% endgeneration %}
{%- else %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ (message['content'] or '') | trim + '<|eot_id|>' }}
{%- endif %}
{%- elif 'tool_calls' in message and message.tool_calls is not none %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- elif message.tool_calls is defined and message.tool_calls %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{% generation %}
{%- if message.content %}
{{- message.content | trim }}
{{- "\n" }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{% generation %}{{- "<|python_tag|>" + tool_call.name + ".call(" }}
{%- for arg_name, arg_val in tool_call.arguments | items %}
{{- arg_name + '="' + arg_val + '"' }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- for raw_tool_call in message.tool_calls %}
{%- if raw_tool_call.function is defined and raw_tool_call.function is not none %}
{%- set tool_call = raw_tool_call.function %}
{%- else %}
{%- set tool_call = raw_tool_call %}
{%- endif %}
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
{%- for arg_name, arg_val in tool_call.arguments | items %}
{{- arg_name + '="' + arg_val + '"' }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}{% endgeneration %}
{%- else %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{% generation %}{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.arguments | tojson }}
{{- "}" }}{% endgeneration %}
{%- endif %}
{{- ")" }}
{%- else %}
{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- "}" }}
{%- endif %}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- if builtin_tools is defined %}
{#- This means we're in ipython mode #}
{% generation %}{{- "<|eom_id|>" }}{% endgeneration %}
{{- "<|eom_id|>" }}
{%- else %}
{% generation %}{{- "<|eot_id|>" }}{% endgeneration %}
{{- "<|eot_id|>" }}
{%- endif %}
{% endgeneration %}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is none %}
22 changes: 21 additions & 1 deletion experiments/defaults.py
@@ -39,7 +39,7 @@

from experiments.evals.task_configs import CORE_TASKS
from marin.evaluation.evaluation_config import convert_to_levanter_task_config
from experiments.paloma import paloma_tokenized
from experiments.paloma import paloma_raw_validation_sets, paloma_tokenized
from experiments.simple_dpo_config import SimpleDPOConfig
from experiments.simple_sft_config import SimpleSFTConfig
from experiments.simple_train_config import SimpleTrainConfig
@@ -304,6 +304,26 @@ def default_validation_sets(tokenizer: str, base_path: str = "tokenized/") -> di
    return validation_sets


@lru_cache
def default_raw_validation_sets() -> dict[str, Any]:
    from experiments.evals.exp1600_uncheatable_evals import uncheatable_eval_raw_validation_sets
    from experiments.evals.raw_capability_eval_sets import capability_raw_validation_sets

    validation_sets = dict(paloma_raw_validation_sets())
    validation_sets.update(uncheatable_eval_raw_validation_sets())
    validation_sets.update(capability_raw_validation_sets())
    return validation_sets


@lru_cache
def extended_raw_validation_sets() -> dict[str, Any]:
    from experiments.evals.raw_capability_eval_sets import opt_in_capability_raw_validation_sets

    validation_sets = dict(default_raw_validation_sets())
    validation_sets.update(opt_in_capability_raw_validation_sets())
    return validation_sets


def simulated_epoching_train(
    name: str,
    tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig,
9 changes: 9 additions & 0 deletions experiments/evals/exp1600_uncheatable_evals.py
@@ -77,6 +77,15 @@ def uncheatable_eval_tokenized(
    return uncheatable_eval_steps


def uncheatable_eval_raw_validation_sets(*, uncheatable_eval_raw: ExecutorStep = uncheatable_eval):
    from marin.evaluation.perplexity_gap import raw_text_dataset

    return {
        os.path.join("uncheatable_eval", dataset): raw_text_dataset(uncheatable_eval_raw.cd(path_part))
        for dataset, path_part in ((dataset, ALL_UNCHEATABLE_EVAL_DATASETS[dataset]) for dataset in ACTIVE_DATASETS)
    }


@dataclass
class ModelConfig:
    model_name: str