
Commit cca542d

Merge branch 'main' into will/exp1337-seeds
2 parents 0a1c62b + 09234be commit cca542d

619 files changed

Lines changed: 55513 additions & 19249 deletions

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Model Perplexity Gap Finder

## Problem

Levanter's current analysis path compares models only after they have been
tokenized with a single shared tokenizer. The existing compare-viz entrypoint in
[`lib/levanter/src/levanter/main/viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L34)
loads one tokenizer from `config.data.the_tokenizer` and uses one `LmConfig` for
both models
([`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L54),
[`viz_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/main/viz_logprobs.py#L123)).
That is fine for "same tokenizer, two checkpoints", but it cannot answer
"where is Marin worse than Llama 3.1?" once the models use different tokenizers.

Levanter already has the right aggregation idea for corpus slices: tagged eval
datasets with hierarchical rollups and per-tag `bpb`
([`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L199),
[`eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L387)).
Marin already defaults validation to Paloma plus uncheatable eval, but only in a
tokenizer-specific cached form
([`experiments/defaults.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/experiments/defaults.py#L297)).

For this feature we want a different path:

- take raw text corpora in the usual InputName-driven Marin style
- tokenize on the fly for each model separately
- compare models on a tokenizer-independent unit
- report both dataset-level gaps and byte-pattern / word-part gaps

No backward compatibility work is needed. Existing cached tokenization, `eval_lm`,
and `viz_logprobs` behavior should stay unchanged.

## Goals

- Compare two Levanter-loadable LMs, where each model may have its own tokenizer,
  its own `LmConfig`, and either an HF or native Levanter checkpoint.
- Score raw text documents directly and normalize results as bits per byte.
- Attribute loss deltas onto byte spans so reports can surface tokenization-free
  "word part" effects such as whitespace runs, punctuation clusters, or short
  literal byte spans.
- Reuse Marin's existing raw-dataset conventions and default to raw Paloma plus
  raw uncheatable eval.
- Produce a persisted report that is readable without W&B.

Non-goals:

- replacing `LmDataConfig` or the cache-based training/eval path
- supporting non-text dataset formats in v1
- unsupervised topic discovery or clustering
- exact token-to-token alignment across two tokenizers

## Proposed Solution

### Core approach

Introduce a new raw-text analysis path in Levanter that scores both models on the
same raw UTF-8 documents, but tokenizes each document independently per model.
Each model's per-token next-token loss is projected back onto the original
document bytes through tokenizer offset mappings. Once both models live on the
same byte axis, every report becomes an aggregation over byte-attributed losses.

This keeps the core invariant simple:

1. raw document bytes are the shared evaluation unit
2. model A and model B may tokenize differently
3. both models' losses are attributed onto those same bytes
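
As a minimal illustration of the invariant, two HF fast tokenizers can encode the
same document with `return_offsets_mapping=True`, and every token's `(start, end)`
span indexes into the one shared document. The two checkpoints below are small
public stand-ins for the real tokenizers, not the models under comparison:

```python
from transformers import AutoTokenizer

doc = "Marin vs. Llama: compare both models on the same raw bytes.\n"

# Small public tokenizers as stand-ins for the two models' real tokenizers.
tok_a = AutoTokenizer.from_pretrained("gpt2")
tok_b = AutoTokenizer.from_pretrained("bert-base-uncased")

for name, tok in [("model_a", tok_a), ("model_b", tok_b)]:
    enc = tok(doc, return_offsets_mapping=True, add_special_tokens=False)
    # Token boundaries differ between the two tokenizers, but every offset pair
    # points into the same document, which is what makes byte attribution work.
    print(name, list(zip(enc.tokens(), enc["offset_mapping"])))
```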

### Config shape

Levanter gets a dedicated entrypoint and config rather than extending
`VizLmConfig`.

```python
@dataclass
class GapFinderModelConfig:
    checkpoint_path: str
    model: LmConfig | None = None
    checkpoint_is_hf: bool = False
    tokenizer: str | None = None
    tokenizer_backend: TokenizerBackend = TokenizerBackend.HF


@dataclass
class GapFinderConfig:
    model_a: GapFinderModelConfig
    model_b: GapFinderModelConfig
    datasets: dict[str, DatasetComponent]
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    output_path: str = "gap-finder"
    max_eval_length: int = 4096
    max_docs_per_dataset: int | None = 256
```

Marin gets a thin wrapper config that accepts raw datasets, converts them into
`DatasetComponent` values with `UrlDatasetSourceConfig` /
`HfDatasetSourceConfig`, then submits the Levanter job on Iris.
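
A rough sketch of what that wrapper might look like on the Marin side; the class
and field names here are placeholders, not an existing Marin API:

```python
from dataclasses import dataclass, field


@dataclass
class RawGapFinderConfig:
    """Hypothetical Marin wrapper config; illustrative names only."""

    model_a_checkpoint: str  # HF repo id or Levanter checkpoint path
    model_b_checkpoint: str
    # Raw corpora keyed by report name, e.g. {"paloma/c4_en": "gs://.../*.jsonl.gz"};
    # converted into DatasetComponent values before the Levanter job is submitted.
    raw_datasets: dict[str, str] = field(default_factory=dict)
    output_path: str = "analysis/gap-finder"
    max_docs_per_dataset: int | None = 256
```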

### Raw scoring loop

The raw path should not go through `LmDataConfig.validation_sets()` because that
method is cache- and tokenizer-oriented
([`lib/levanter/src/levanter/data/text/datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L817),
[`datasets.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/data/text/datasets.py#L826)).
Instead, the new entrypoint should iterate raw shards via
`DatasetComponent.source.get_shard_source("validation")`, read `text` from
`TextLmDatasetFormat`, tokenize batches on the host, and feed padded arrays into
the model.
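
A host-side sketch of that loop follows. It assumes documents have already been
read from `get_shard_source("validation")` into an iterable of dicts with a
`text` field; the chunking, padding, and loss mask details are illustrative
rather than final Levanter code:

```python
import numpy as np


def iter_scoring_batches(raw_docs, tokenizer, max_eval_length: int, pad_id: int):
    """Tokenize raw documents on the host and yield padded chunks for the model.

    `raw_docs` stands in for documents read from the raw shard source; each item
    is a dict with a "text" field. Illustrative sketch only.
    """
    for doc in raw_docs:
        enc = tokenizer(doc["text"], return_offsets_mapping=True, add_special_tokens=False)
        ids, offsets = enc["input_ids"], enc["offset_mapping"]
        # Chunk long documents up to max_eval_length; pad the last chunk.
        for start in range(0, len(ids), max_eval_length):
            chunk = ids[start : start + max_eval_length]
            padded = np.full((max_eval_length,), pad_id, dtype=np.int32)
            padded[: len(chunk)] = chunk
            loss_mask = np.zeros((max_eval_length,), dtype=np.float32)
            loss_mask[: max(len(chunk) - 1, 0)] = 1.0  # no target for the final position
            yield doc, offsets[start : start + max_eval_length], padded, loss_mask
```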

The forward pass should reuse the standard next-token loss path rather than
custom logits math:

```python
@hax.named_jit(axis_resources=compute_axis_mapping)
def compute_token_losses(model: LmHeadModel, batch: LmExample):
    model = inference_mode(model, True)
    model = mp.cast_to_compute(model)
    # Per-position loss, unreduced, so it can be re-attributed onto bytes later.
    per_pos_loss = model.compute_next_token_loss(
        batch,
        reduction=None,
        reduction_axis=(),
    ).array
    # Targets are the inputs shifted left by one position.
    target_ids = jnp.roll(batch.tokens.array, -1, axis=-1)
    return per_pos_loss, batch.loss_weight.array, target_ids
```

### Byte attribution

For each raw document:

1. tokenize with offsets using the model's HF tokenizer
2. add BOS/EOS manually when the tokenizer would not insert them itself
3. score padded chunks up to `max_eval_length`
4. shift losses onto target-token spans, mirroring Levanter eval's target-id
   handling
5. spread each token's loss uniformly across its covered byte span

Uniform byte spreading is the simplest stable attribution rule. It preserves
document-level `bpb`, avoids token-to-token alignment, and lets us aggregate by
arbitrary byte-derived patterns later.
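
A self-contained sketch of steps 4 and 5 plus the `bpb` normalization. HF offset
mappings are character offsets, so they are first converted to UTF-8 byte
offsets; `offsets` here are assumed to already be shifted onto the target tokens:

```python
import math

import numpy as np


def spread_losses_onto_bytes(text: str, offsets, token_losses) -> np.ndarray:
    """Spread each target token's loss uniformly across the bytes it covers.

    `offsets` are character offsets for the target tokens; `token_losses` are
    nats per token. Illustrative sketch only.
    """
    # Byte index of each character boundary, so char spans become byte spans.
    char_to_byte = np.cumsum([0] + [len(c.encode("utf-8")) for c in text])
    byte_losses = np.zeros(char_to_byte[-1], dtype=np.float64)
    for (char_start, char_end), loss in zip(offsets, token_losses):
        b0, b1 = char_to_byte[char_start], char_to_byte[char_end]
        if b1 > b0:
            byte_losses[b0:b1] += loss / (b1 - b0)
    return byte_losses


def bits_per_byte(byte_losses: np.ndarray) -> float:
    """Document-level bpb: total nats over all bytes, converted to bits."""
    return float(byte_losses.sum() / (len(byte_losses) * math.log(2)))
```

Because the per-token losses are only redistributed, summing the byte-attributed
losses recovers the document total, which is what keeps document-level `bpb`
unchanged under this rule.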

### Report structure

The report should include:

- dataset / subcorpus gap table (`model_a_bpb`, `model_b_bpb`, `gap_bpb`)
- hierarchical rollups for names like `paloma/...`
- top documents by positive and negative delta
- pattern-bucket gap table, with buckets such as:
  - `whitespace/single_space`
  - `whitespace/multi_space`
  - `whitespace/newline`
  - `whitespace/tab_or_cr`
  - `text/url`
  - `text/number`
  - `text/punctuation`
  - `text/non_ascii`
  - `text/word`
- top literal byte spans / short substrings with the largest deltas

Persist both JSON and HTML so downstream scripts can consume the data while
humans can inspect a single rendered report.
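
One plausible way to derive those pattern buckets is a first-match regex pass
over the document text, with the resulting character spans converted to byte
spans by the same char-to-byte mapping used for attribution. The rules below are
illustrative, not a spec:

```python
import re
import string

# Ordered (bucket, pattern) pairs; earlier rules claim their spans first.
# Bucket names mirror the report table above; the regexes are illustrative.
BUCKET_PATTERNS = [
    ("text/url", re.compile(r"https?://\S+")),
    ("whitespace/multi_space", re.compile(r"  +")),
    ("whitespace/single_space", re.compile(r" ")),
    ("whitespace/newline", re.compile(r"\n")),
    ("whitespace/tab_or_cr", re.compile(r"[\t\r]")),
    ("text/number", re.compile(r"[0-9][0-9.,]*")),
    ("text/non_ascii", re.compile(r"[^\x00-\x7f]+")),
    ("text/punctuation", re.compile("[" + re.escape(string.punctuation) + "]+")),
    ("text/word", re.compile(r"[A-Za-z_]+")),
]


def bucket_char_spans(text: str):
    """Yield (bucket, char_start, char_end); unmatched characters stay unbucketed."""
    claimed = [False] * len(text)
    for bucket, pattern in BUCKET_PATTERNS:
        for m in pattern.finditer(text):
            if not any(claimed[m.start() : m.end()]):
                claimed[m.start() : m.end()] = [True] * (m.end() - m.start())
                yield bucket, m.start(), m.end()
```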

## Implementation Outline

1. Add a Levanter raw-text gap finder entrypoint, config types, model-loading
   helpers, and HTML/JSON report writer.
2. Add host-side raw text iteration, tokenizer-with-offset handling, and
   byte-attributed loss aggregation for text datasets.
3. Add a Marin wrapper plus helpers for raw evaluation components and default raw
   Paloma/uncheatable dataset wiring.
4. Add focused tests for byte attribution, bucket aggregation, and a tiny
   end-to-end Levanter run.
5. Add an experiment script that compares `marin-community/marin-8b-base` and
   `meta-llama/Meta-Llama-3.1-8B` on Iris v5p-8 in `us-central1`.

## Notes

- V1 should explicitly support `TextLmDatasetFormat` only. Chat/template-aware
  data can be added later once there is a clear raw-byte contract.
- Existing tagged eval code in
  [`lib/levanter/src/levanter/eval.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/eval.py#L538)
  is still the right model for hierarchical corpus aggregation; the new path just
  computes those aggregates from raw byte-attributed records instead of from a
  shared-tokenizer dataset.
- The existing `byte_length_of_token()` helper
  ([`lib/levanter/src/levanter/utils/hf_utils.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/levanter/src/levanter/utils/hf_utils.py#L23))
  remains useful for sanity checks, but offset-based byte attribution is the
  source of truth for mixed-tokenizer comparison.
- `save_logprobs.py`
  ([`lib/marin/src/marin/evaluation/save_logprobs.py`](/Users/dlwh/.codex/worktrees/a2ab/marin/lib/marin/src/marin/evaluation/save_logprobs.py#L85))
  is a useful reference for how to gather per-token outputs on TPU, but the gap
  finder should not serialize full token streams for both models by default.
- The default raw validation helper should mirror the current tokenized helper's
  dataset coverage so the new tool can be dropped into existing analysis flows.

## Future Work

- support `ChatLmDatasetFormat` and template-rendered raw comparisons
- add optional W&B artifact logging for the HTML report and summary JSON
- richer byte-pattern discovery beyond the fixed interpretable buckets
- support approximate context-preserving chunk transitions for very long
  documents instead of dropping the first-token loss in each chunk

.agents/skills/agent-research/SKILL.md

Lines changed: 1 addition & 2 deletions
```diff
@@ -58,8 +58,7 @@ When using W&B:
 - iteration is quick,
 - you are tuning kernels or benchmarks,
 - full pipeline apparatus is unnecessary.
-- Use `.agents/skills/dev-tpu/SKILL.md` for the standard Iris-backed workflow.
-- Use `.agents/skills/dev-tpu-ray/SKILL.md` only when you specifically need the legacy Ray-backed workflow.
+- Use `.agents/skills/dev-tpu/SKILL.md` for the Iris-backed workflow.
 
 Rule of thumb:
 - Start with dev TPU for fast hillclimbing.
```

.agents/skills/babysit-job/SKILL.md

Lines changed: 1 addition & 4 deletions
```diff
@@ -7,10 +7,7 @@ description: Monitor/babysit a job continuously and recover on failure. Use when
 
 Monitor a job continuously and recover on failure. For **Zephyr pipelines**,
 delegate to **babysit-zephyr** instead. Otherwise, follow this skill — Iris is
-the default execution backend.
-
-**Ray is deprecated.** If the user asks to run or babysit a Ray job, tell them
-Ray is no longer supported and they should use Iris instead.
+the execution backend.
 
 ## Required Info
 
```
.agents/skills/canary-triage/SKILL.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -28,8 +28,9 @@ write a Slack summary. Diagnosis and reporting only — no code changes, no PRs.
 The cluster is still live. Collect signal now — it will be torn down after you.
 
 - Iris job state via `.venv/bin/iris --config=$IRIS_CONFIG job list --json`
-- **GPU lane:** you have kubectl at `~/.kube/coreweave-iris`, namespace `$IRIS_NAMESPACE`.
+- **GPU lane:** you have kubectl at `~/.kube/coreweave-iris`, namespace `$IRIS_NAMESPACE` (defaults to `iris-ci` — the canary shares this namespace with PR CI).
   Get pod status, controller logs, task pod logs, warning events, pod describe.
+  **Filter by `iris.job_id=<CANARY_JOB_ID with '/' replaced by '.'>`** so you only see this canary's pods, not co-tenant CI pods. Example: `kubectl -n iris-ci get pods -l iris.job_id=runner.iris-run-job-abc123`.
 - **TPU lane:** use `iris process logs` and `iris job list`.
 - Re-run `scripts/canary/validate_canary_metrics.py` if you need the validation output.
```
