# Model Perplexity Gap Finder

## Problem

Levanter's current analysis path compares models only after they have been
tokenized with a single shared tokenizer. The existing compare-viz entrypoint in
[`lib/levanter/src/levanter/main/viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L34)
loads one tokenizer from `config.data.the_tokenizer` and uses one `LmConfig` for
both models
([`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L54),
[`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L123)).
That is fine for "same tokenizer, two checkpoints", but it cannot answer
"where is Marin worse than Llama 3.1?" once the models use different tokenizers.

Levanter already has the right aggregation idea for corpus slices: tagged eval
datasets with hierarchical rollups and per-tag `bpb`
([`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L199),
[`eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L387)).
Marin already defaults validation to Paloma plus uncheatable eval, but only in a
tokenizer-specific cached form
([`experiments/defaults.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/experiments/defaults.py#L297)).

For this feature we want a different path:

- take raw text corpora in the usual InputName-driven Marin style
- tokenize on the fly for each model separately
- compare models on a tokenizer-independent unit
- report both dataset-level gaps and byte-pattern / word-part gaps

No backward compatibility work is needed. Existing cached tokenization, `eval_lm`,
and `viz_logprobs` behavior should stay unchanged.

## Goals

- Compare two Levanter-loadable LMs, where each model may have its own tokenizer,
  its own `LmConfig`, and either an HF or native Levanter checkpoint.
- Score raw text documents directly and normalize results as bits per byte.
- Attribute loss deltas onto byte spans so reports can surface tokenization-free
  "word part" effects such as whitespace runs, punctuation clusters, or short
  literal byte spans.
- Reuse Marin's existing raw-dataset conventions and default to raw Paloma plus
  raw uncheatable eval.
- Produce a persisted report that is readable without W&B.

Non-goals:

- replacing `LmDataConfig` or the cache-based training/eval path
- supporting non-text dataset formats in v1
- unsupervised topic discovery or clustering
- exact token-to-token alignment across two tokenizers

## Proposed Solution

### Core approach

Introduce a new raw-text analysis path in Levanter that scores both models on the
same raw UTF-8 documents, but tokenizes each document independently per model.
Each model's per-token next-token loss is projected back onto the original
document bytes through tokenizer offset mappings. Once both models live on the
same byte axis, every report becomes an aggregation over byte-attributed losses.

This keeps the core invariant simple:

1. raw document bytes are the shared evaluation unit
2. model A and model B may tokenize differently
3. both models' losses are attributed onto those same bytes
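
Concretely, the tokenizer-independent unit is bits per byte over each document's
UTF-8 encoding. A minimal sketch of the conversion, assuming summed per-token
losses arrive in nats:

```python
import math


def doc_bpb(total_loss_nats: float, text: str) -> float:
    # Convert summed next-token loss from nats to bits, then normalize by the
    # UTF-8 byte count of the raw document rather than by token count.
    return total_loss_nats / (math.log(2) * len(text.encode("utf-8")))
```
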
### Config shape

Levanter gets a dedicated entrypoint and config rather than extending
`VizLmConfig`.

```python
from dataclasses import dataclass, field

# LmConfig, TrainerConfig, and DatasetComponent are existing Levanter config
# types; TokenizerBackend selects the tokenizer implementation (HF by default).


@dataclass
class GapFinderModelConfig:
    checkpoint_path: str
    model: LmConfig | None = None
    checkpoint_is_hf: bool = False
    tokenizer: str | None = None
    tokenizer_backend: TokenizerBackend = TokenizerBackend.HF


@dataclass
class GapFinderConfig:
    model_a: GapFinderModelConfig
    model_b: GapFinderModelConfig
    datasets: dict[str, DatasetComponent]
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    output_path: str = "gap-finder"
    max_eval_length: int = 4096  # tokens per scored chunk
    max_docs_per_dataset: int | None = 256
```

Marin gets a thin wrapper config that accepts raw datasets, converts them into
`DatasetComponent` values with `UrlDatasetSourceConfig` /
`HfDatasetSourceConfig`, then submits the Levanter job on Iris.
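
As a shape sketch only: the constructor field names below (`validation_urls`,
`id`) are assumptions for illustration, not verified signatures.

```python
def to_dataset_component(source_spec) -> DatasetComponent:
    # Hypothetical helper mapping a Marin raw-dataset spec onto a Levanter
    # DatasetComponent; field names are assumed.
    if isinstance(source_spec, list):  # a list of shard URLs
        source = UrlDatasetSourceConfig(validation_urls=source_spec)
    else:  # an HF dataset id, e.g. "allenai/paloma"
        source = HfDatasetSourceConfig(id=source_spec)
    return DatasetComponent(source=source, format=TextLmDatasetFormat())
```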

### Raw scoring loop

The raw path should not go through `LmDataConfig.validation_sets()` because that
method is cache- and tokenizer-oriented
([`lib/levanter/src/levanter/data/text/datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L817),
[`datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L826)).
Instead, the new entrypoint should iterate raw shards via
`DatasetComponent.source.get_shard_source("validation")`, read `text` from
`TextLmDatasetFormat`, tokenize batches on the host, and feed padded arrays into
the model.
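
A host-side iteration sketch, assuming the shard source exposes the usual
`shard_names` / `open_shard` interface and that rows carry a `text` field per
`TextLmDatasetFormat`:

```python
def iter_raw_documents(component: DatasetComponent, max_docs: int | None = 256):
    """Yield raw UTF-8 documents from a dataset's validation shards."""
    source = component.source.get_shard_source("validation")
    seen = 0
    for shard_name in source.shard_names:
        for row in source.open_shard(shard_name):
            yield row["text"]
            seen += 1
            if max_docs is not None and seen >= max_docs:
                return
```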

The forward pass should reuse the standard next-token loss path rather than
custom logits math:

```python
@hax.named_jit(axis_resources=compute_axis_mapping)
def compute_token_losses(model: LmHeadModel, batch: LmExample):
    # Evaluation-mode forward pass under the trainer's mixed-precision policy
    # (`mp`) and compute axis mapping.
    model = inference_mode(model, True)
    model = mp.cast_to_compute(model)
    # Unreduced loss: one entry per (batch, position).
    per_pos_loss = model.compute_next_token_loss(
        batch,
        reduction=None,
        reduction_axis=(),
    ).array
    # Targets are the inputs shifted left by one, matching the next-token
    # convention; the final position carries zero loss weight.
    target_ids = jnp.roll(batch.tokens.array, -1, axis=-1)
    return per_pos_loss, batch.loss_weight.array, target_ids
```

### Byte attribution

For each raw document:

1. tokenize with offsets using the model's HF tokenizer
2. add BOS/EOS manually when the tokenizer would not insert them itself
3. score padded chunks up to `max_eval_length`
4. shift losses onto target-token spans, mirroring Levanter eval's target-id
   handling
5. spread each token's loss uniformly across its covered byte span

Uniform byte spreading is the simplest stable attribution rule. It preserves
document-level `bpb`, avoids token-to-token alignment, and lets us aggregate by
arbitrary byte-derived patterns later.
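
A minimal sketch of step 5, assuming an HF fast tokenizer (so
`return_offsets_mapping=True` is available) and losses already shifted onto
target tokens. One subtlety it handles: HF offsets are character offsets, so
they must be converted to UTF-8 byte offsets first.

```python
import numpy as np


def byte_attributed_losses(text: str, token_losses: np.ndarray, tokenizer) -> np.ndarray:
    """Spread each token's loss uniformly over the UTF-8 bytes it covers."""
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    # HF offsets index characters; build a character -> byte offset table.
    char_to_byte = np.zeros(len(text) + 1, dtype=np.int64)
    for i, ch in enumerate(text):
        char_to_byte[i + 1] = char_to_byte[i] + len(ch.encode("utf-8"))
    byte_losses = np.zeros(char_to_byte[-1], dtype=np.float64)
    for (char_start, char_end), loss in zip(enc["offset_mapping"], token_losses):
        start, end = char_to_byte[char_start], char_to_byte[char_end]
        if end > start:  # special tokens map to empty spans; skip them
            byte_losses[start:end] += loss / (end - start)
    return byte_losses
```

Because the spreading conserves each token's total loss, summing `byte_losses`
and dividing by `ln(2) * n_bytes` reproduces the document-level `bpb` exactly.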

### Report structure

The report should include:

- dataset / subcorpus gap table (`model_a_bpb`, `model_b_bpb`, `gap_bpb`)
- hierarchical rollups for names like `paloma/...`
- top documents by positive and negative delta
- pattern-bucket gap table (a classifier sketch follows this list), with
  buckets such as:
  - `whitespace/single_space`
  - `whitespace/multi_space`
  - `whitespace/newline`
  - `whitespace/tab_or_cr`
  - `text/url`
  - `text/number`
  - `text/punctuation`
  - `text/non_ascii`
  - `text/word`
- top literal byte spans / short substrings with the largest deltas
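
The bucket table can be computed directly from the per-byte gap array. A sketch
with illustrative regexes (the exact patterns are a design choice, not fixed
here); earlier buckets claim bytes first, so digits inside a URL count toward
`text/url` rather than `text/number`:

```python
import re

import numpy as np

# Ordered, first-match-wins byte patterns; the regexes are illustrative.
BUCKET_PATTERNS = [
    ("text/url", re.compile(rb"https?://\S+")),
    ("whitespace/newline", re.compile(rb"\n+")),
    ("whitespace/tab_or_cr", re.compile(rb"[\t\r]+")),
    ("whitespace/multi_space", re.compile(rb" {2,}")),
    ("whitespace/single_space", re.compile(rb" ")),
    ("text/number", re.compile(rb"[0-9]+(?:\.[0-9]+)?")),
    ("text/punctuation", re.compile(rb"[!-/:-@\[-`{-~]+")),
    ("text/non_ascii", re.compile(rb"[\x80-\xff]+")),
    ("text/word", re.compile(rb"[A-Za-z]+")),
]


def bucket_gaps(doc_bytes: bytes, gap_per_byte: np.ndarray) -> dict[str, float]:
    """Sum per-byte loss deltas into pattern buckets, first match wins."""
    claimed = np.zeros(len(doc_bytes), dtype=bool)
    totals: dict[str, float] = {}
    for name, pattern in BUCKET_PATTERNS:
        for m in pattern.finditer(doc_bytes):
            idx = np.arange(m.start(), m.end())
            idx = idx[~claimed[idx]]  # only bytes not claimed by earlier buckets
            if idx.size:
                totals[name] = totals.get(name, 0.0) + float(gap_per_byte[idx].sum())
                claimed[idx] = True
    return totals
```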

Persist both JSON and HTML so downstream scripts can consume the data while
humans can inspect a single rendered report.
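
For the JSON side, one row of the dataset gap table might look like the
following (an illustrative schema, not a fixed contract):

```python
# Illustrative record; field names beyond the three bpb columns are assumptions.
{
    "dataset": "paloma/c4_en",
    "model_a_bpb": 0.792,
    "model_b_bpb": 0.781,
    "gap_bpb": 0.011,
    "n_docs": 256,
    "n_bytes": 1843201,
}
```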

## Implementation Outline

1. Add a Levanter raw-text gap finder entrypoint, config types, model-loading
   helpers, and HTML/JSON report writer.
2. Add host-side raw text iteration, tokenizer-with-offset handling, and
   byte-attributed loss aggregation for text datasets.
3. Add a Marin wrapper plus helpers for raw evaluation components and default raw
   Paloma/uncheatable dataset wiring.
4. Add focused tests for byte attribution, bucket aggregation, and a tiny
   end-to-end Levanter run (see the test sketch after this list).
5. Add an experiment script that compares `marin-community/marin-8b-base` and
   `meta-llama/Meta-Llama-3.1-8B` on Iris v5p-8 in `us-central1`.
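
For step 4, the most important property to pin down is conservation: uniform
spreading must not create or destroy loss, since that is what keeps
document-level `bpb` exact. A sketch using the hypothetical
`byte_attributed_losses` helper from the byte-attribution section:

```python
import numpy as np
from transformers import AutoTokenizer


def test_byte_attribution_preserves_total_loss():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer works
    text = "hello  world\n"
    n_tokens = len(tokenizer(text, add_special_tokens=False)["input_ids"])
    token_losses = np.random.default_rng(0).uniform(size=n_tokens)
    byte_losses = byte_attributed_losses(text, token_losses, tokenizer)
    # Spreading is per-token mass-conserving, so the totals must match and
    # every UTF-8 byte of the document must receive an attribution slot.
    assert np.isclose(byte_losses.sum(), token_losses.sum())
    assert len(byte_losses) == len(text.encode("utf-8"))
```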

## Notes

- V1 should explicitly support `TextLmDatasetFormat` only. Chat/template-aware
  data can be added later once there is a clear raw-byte contract.
- Existing tagged eval code in
  [`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L538)
  is still the right model for hierarchical corpus aggregation; the new path just
  computes those aggregates from raw byte-attributed records instead of from a
  shared-tokenizer dataset.
- The existing `byte_length_of_token()` helper
  ([`lib/levanter/src/levanter/utils/hf_utils.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/utils/hf_utils.py#L23))
  remains useful for sanity checks, but offset-based byte attribution is the
  source of truth for mixed-tokenizer comparison.
- `save_logprobs.py`
  ([`lib/marin/src/marin/evaluation/save_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/marin/src/marin/evaluation/save_logprobs.py#L85))
  is a useful reference for how to gather per-token outputs on TPU, but the gap
  finder should not serialize full token streams for both models by default.
- The default raw validation helper should mirror the current tokenized helper's
  dataset coverage so the new tool can be dropped into existing analysis flows.

## Future Work

- support `ChatLmDatasetFormat` and template-rendered raw comparisons
- add optional W&B artifact logging for the HTML report and summary JSON
- richer byte-pattern discovery beyond the fixed interpretable buckets
- support approximate context-preserving chunk transitions for very long
  documents instead of dropping the first-token loss in each chunk