
Commit 002f61b

Migrate logprob evals to Fray v2 + some flybys (#4398)
Fixes #4397, plus some flyby changes: 1) add a tracker callback that saves eval results to a file (instead of only to wandb); 2) add a function that saves top-k logprobs per document (useful for some experiments I've been running). I'm fine with moving (2) to another PR, or leaving it out of main entirely, if that would be better.
Parent: 47a567f

3 files changed: 339 additions & 7 deletions

levanter/tracker/json_file.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Copyright The Levanter Authors
# SPDX-License-Identifier: Apache-2.0


import dataclasses
import json
import logging
import os
from typing import Any, Mapping, Optional

import jax

from levanter.tracker.json_logger import _flatten, _to_jsonable
from levanter.tracker.tracker import NoopTracker, Tracker, TrackerConfig

logger = logging.getLogger(__name__)


class JsonFileTracker(Tracker):
    """Tracker that accumulates metrics and saves them to a JSON file on finish()."""

    name: str = "json_file"

    def __init__(self, output_path: str):
        self.output_path = output_path
        self._last_metrics: dict[str, Any] = {}
        self._summary_metrics: dict[str, Any] = {}

    def log_hyperparameters(self, hparams: dict[str, Any]):
        pass

    def log(self, metrics: Mapping[str, Any], *, step: Optional[int], commit: Optional[bool] = None):
        if step is not None:
            self._last_metrics.update(_flatten(metrics))

    def log_summary(self, metrics: Mapping[str, Any]):
        self._summary_metrics.update(_flatten(metrics))

    def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
        pass

    def finish(self):
        import fsspec

        # step metrics take precedence over summary metrics on key collisions
        summary = {**self._summary_metrics, **self._last_metrics}
        output_file = os.path.join(self.output_path, "eval_results.json")
        with fsspec.open(output_file, "wt") as f:
            json.dump(_to_jsonable(summary), f, indent=2)
        logger.info(f"Saved eval results to {output_file}")


@TrackerConfig.register_subclass("json_file")
@dataclasses.dataclass
class JsonFileTrackerConfig(TrackerConfig):
    output_path: str = ""

    def init(self, run_id: Optional[str]) -> Tracker:
        if jax.process_index() != 0:
            return NoopTracker()
        return JsonFileTracker(self.output_path)
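
For context, a minimal usage sketch of the tracker above (the output path and metric names here are hypothetical, not part of the commit): per-step metrics overwrite earlier values key by key, summary metrics accumulate separately, and finish() merges the two into a single eval_results.json, with the latest step metrics winning on any key collision.

from levanter.tracker.json_file import JsonFileTracker

tracker = JsonFileTracker(output_path="gs://my-bucket/evals/run-1")  # hypothetical path
tracker.log({"eval/loss": 2.31}, step=0)
tracker.log({"eval/loss": 2.17}, step=1)  # replaces the step-0 value for this key
tracker.log({"eval/loss": 9.99}, step=None)  # dropped: log() ignores step=None
tracker.log_summary({"eval/num_docs": 1024})
tracker.finish()  # writes gs://my-bucket/evals/run-1/eval_results.json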

lib/marin/src/marin/evaluation/log_probs.py

Lines changed: 13 additions & 7 deletions
@@ -9,14 +9,15 @@
 import os
 from dataclasses import dataclass

-from fray.v1.cluster import Entrypoint, EnvironmentConfig, JobRequest, ResourceConfig, current_cluster
-from fray.v1.cluster.base import TpuConfig
+from fray.v2 import current_client
+from fray.v2.types import Entrypoint, JobRequest, ResourceConfig, TpuConfig, create_environment
 from levanter.compat.hf_checkpoints import RepoRef
 from levanter.data.text import LMMixtureDatasetConfig
 from levanter.distributed import RayConfig
 from levanter.main.eval_lm import EvalLmConfig as LevanterEvalLmConfig
 from levanter.main.eval_lm import main as eval_lm_main
 from levanter.models.lm_model import LmConfig
+from levanter.tracker.json_file import JsonFileTrackerConfig
 from levanter.tracker.wandb import WandbConfig
 from levanter.trainer import TrainerConfig

@@ -128,7 +129,10 @@ def evaluate_lm_log_probs(config: EvalLmConfig) -> None:
         model=config.model,
         data=config.datasets,
         trainer=TrainerConfig(
-            tracker=WandbConfig(project="marin", tags=wandb_tags, name=name),
+            tracker=(
+                WandbConfig(project="marin", tags=wandb_tags, name=name),
+                JsonFileTrackerConfig(output_path=config.output_path),
+            ),
             ray=RayConfig(auto_start_cluster=False),
             per_device_eval_parallelism=config.per_device_batch_size,
             max_eval_batches=max_eval_batches,
@@ -138,12 +142,14 @@ def evaluate_lm_log_probs(config: EvalLmConfig) -> None:

     assert isinstance(config.resource_config.device, TpuConfig), "evaluate_lm_log_probs requires TPU resource config"

-    cluster = current_cluster()
+    extras = ["tpu"]
+
+    client = current_client()
     job_request = JobRequest(
         name=f"eval-lm-{name}",
         resources=config.resource_config,
         entrypoint=Entrypoint.from_callable(do_eval_lm, args=[levanter_config]),
-        environment=EnvironmentConfig.create(),
+        environment=create_environment(extras=extras),
     )
-    job_id = cluster.launch(job_request)
-    cluster.wait(job_id, raise_on_failure=True)
+    job = client.submit(job_request)
+    job.wait(raise_on_failure=True)
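
Condensed, the Fray v1 → v2 migration pattern applied in this diff is: current_cluster() becomes current_client(), EnvironmentConfig.create() becomes create_environment(extras=...), and the launch/wait pair becomes a submitted job handle. A sketch using only the names imported above (my_entrypoint, my_config, and my_resources are placeholders):

from fray.v2 import current_client
from fray.v2.types import Entrypoint, JobRequest, create_environment

client = current_client()  # v1: cluster = current_cluster()
job = client.submit(
    JobRequest(
        name="my-job",
        resources=my_resources,  # placeholder ResourceConfig
        entrypoint=Entrypoint.from_callable(my_entrypoint, args=[my_config]),
        environment=create_environment(extras=["tpu"]),  # v1: EnvironmentConfig.create()
    )
)
job.wait(raise_on_failure=True)  # v1: job_id = cluster.launch(...); cluster.wait(job_id, ...)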
Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
# Copyright 2025 The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""
Save per-token log probabilities for a language model on a dataset.

This module computes per-token logprobs using Levanter on TPU and saves them
to gzipped JSONL files. Optionally saves top-k logprobs at each position.
"""

import json
import logging
import os
from contextlib import nullcontext
from dataclasses import dataclass, field, replace

import equinox as eqx
import fsspec
import jax
import jmp
import numpy as np
from jax.experimental import multihost_utils

import haliax as hax
from haliax import Axis
from haliax.partitioning import round_axis_for_partitioning

import levanter
from levanter.checkpoint import load_checkpoint
from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef
from levanter.data import DataLoader
from levanter.data.text import DatasetComponent, LmDataConfig, LMMixtureDatasetConfig
from levanter.distributed import RayConfig
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
from levanter.models.loss import next_token_loss
from levanter.tracker import NoopConfig
from levanter.trainer import TrainerConfig
from levanter.utils.jax_utils import use_cpu_device
from levanter.utils.tree_utils import inference_mode

from fray.v2 import current_client
from fray.v2.types import Entrypoint, JobRequest, ResourceConfig, TpuConfig, create_environment

from marin.execution.executor import ExecutorStep, InputName, this_output_path
from marin.utilities.executor_utils import ckpt_path_to_step_name

logger = logging.getLogger(__name__)


@dataclass
class SaveLogprobsConfig:
    """Configuration for saving per-token logprobs. Also serves as the Levanter init config."""

    trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(mp=jmp.get_policy("c=bf16")))
    data: LmDataConfig = field(default_factory=LmDataConfig)
    model: LmConfig = field(default_factory=LlamaConfig)
    checkpoint_path: str | None = None
    checkpoint_is_hf: bool = False
    max_eval_length: int = 4096
    output_path: str = ""
    top_k: int | None = None


@dataclass(frozen=True)
class SaveLogprobsOnPodConfig:
    """Wrapper config for running save_logprobs on a TPU pod via fray."""

    save_logprobs_config: SaveLogprobsConfig
    resources: ResourceConfig


def _force_pack_data(data: LmDataConfig) -> LmDataConfig:
    # Pack whole documents into each sequence, separated by segment ids, so that
    # per-document records can be recovered from the segment ids below.
    packed_components = {
        name: replace(component, pack=True) if isinstance(component, DatasetComponent) else component
        for name, component in data.components.items()
    }
    packed_data = replace(data, components=packed_components, block_cross_document_attention=True)
    return packed_data


def save_logprobs(config: SaveLogprobsConfig) -> None:
    """Compute and save per-token logprobs."""
    levanter.initialize(config)
    tokenizer = config.data.the_tokenizer

    hf_checkpoint = RepoRef.from_string(config.checkpoint_path) if config.checkpoint_is_hf else None

    EvalBatch = config.trainer.EvalBatch
    Pos = config.model.max_Pos.resize(config.max_eval_length)

    packed_data = _force_pack_data(config.data)
    validation_sets = packed_data.validation_sets(Pos)

    compute_axis_mapping = config.trainer.compute_axis_mapping
    parameter_axis_mapping = config.trainer.parameter_axis_mapping

    with config.trainer.use_device_mesh(), hax.axis_mapping(parameter_axis_mapping):
        key = jax.random.PRNGKey(0)

        vocab_size = len(tokenizer)
        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), compute_axis_mapping)
        if vocab_size != Vocab.size:
            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

        mp: jmp.Policy = config.trainer.mp

        @hax.named_jit
        def compute_forward(model: LmHeadModel, example: LmExample):
            """Shared forward pass: returns per-token losses and full-vocab logprobs."""
            model = inference_mode(model, True)
            model = mp.cast_to_compute(model)
            activations = model.activations(example.tokens, example.attn_mask, key=key)
            logits = hax.dot(activations, model.get_lm_head(), axis=model.Embed)
            loss = next_token_loss(
                model.Pos,
                model.Vocab,
                logits=logits,
                true_ids=example.tokens,
                loss_weight=example.loss_weight,
                reduction=None,
            )
            logprobs = hax.nn.log_softmax(logits, axis=model.Vocab)

            return loss.rearrange((EvalBatch, Pos)), logprobs.rearrange((EvalBatch, Pos, model.Vocab))

        @hax.named_jit
        def compute_top(logprobs: hax.NamedArray, k: int):
            top_k_values, top_k_indices = hax.top_k(logprobs, model.Vocab, k=k, new_axis="top_k")
            TopK = top_k_values.resolve_axis("top_k")
            return top_k_values.rearrange((EvalBatch, Pos, TopK)), top_k_indices.rearrange((EvalBatch, Pos, TopK))

        # Load model
        if config.checkpoint_path is not None and not config.checkpoint_is_hf:
            with use_cpu_device():
                model = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
                model = load_checkpoint(model, config.checkpoint_path, subpath="model")
            model = hax.shard_with_axis_mapping(model, parameter_axis_mapping)
        elif hf_checkpoint is not None:
            model_config = config.model
            if not hasattr(model_config, "hf_checkpoint_converter"):
                raise ValueError("Model config does not have an HF checkpoint converter. Can't load HF checkpoint.")
            converter: HFCheckpointConverter = model_config.hf_checkpoint_converter()
            converter = converter.replaced(reference_checkpoint=hf_checkpoint, tokenizer=tokenizer)
            model = converter.load_pretrained(model_config.model_type, ref=hf_checkpoint, dtype=mp.compute_dtype)
        else:
            raise AssertionError("Should not get here")

        for name, dataset in validation_sets.items():
            loader = DataLoader(
                dataset,
                config.trainer.eval_batch_size,
                mesh=config.trainer.device_mesh,
                axis_resources=compute_axis_mapping,
            )

            output_file = os.path.join(config.output_path, name, "outputs.jsonl.gz")
            # Only process 0 writes; the other hosts still participate in the allgathers below.
            cm = fsspec.open(output_file, "wt", compression="gzip") if jax.process_index() == 0 else nullcontext()
            with cm as f:
                for batch in loader:
                    with hax.axis_mapping(compute_axis_mapping):
                        out = compute_forward(model, batch)
                        b_loss, b_logprobs = out

                        if config.top_k is not None:
                            b_topk_vals, b_topk_ids = compute_top(b_logprobs, config.top_k)
                            b_topk_vals, b_topk_ids = multihost_utils.process_allgather(
                                (b_topk_vals, b_topk_ids), tiled=True
                            )

                        b_tokens, b_seg_ids = batch.tokens.rearrange((EvalBatch, Pos)), batch.attn_mask.segment_ids[
                            0
                        ].rearrange((EvalBatch, Pos))
                        b_loss, b_tokens, b_seg_ids = multihost_utils.process_allgather(
                            (b_loss, b_tokens, b_seg_ids), tiled=True
                        )

                    if jax.process_index() == 0:
                        b_loss = np.array(b_loss.array)
                        b_tokens = np.array(b_tokens.array)
                        b_seg_ids = np.array(b_seg_ids.array)

                        if config.top_k is not None:
                            b_topk_ids = np.array(b_topk_ids.array)
                            b_topk_vals = np.array(b_topk_vals.array)

                        for i in range(len(b_tokens)):
                            if np.all(b_tokens[i] == 0):
                                continue

                            unique_ids = np.unique(b_seg_ids[i])
                            unique_ids = unique_ids[unique_ids >= 0]  # exclude padding (-1)

                            # one record per packed document, recovered via its segment id
                            for seg_id in unique_ids:
                                mask = b_seg_ids[i] == seg_id
                                record = {
                                    "token_ids": b_tokens[i][mask].tolist(),
                                    "losses": b_loss[i][mask].tolist(),
                                }
                                if config.top_k is not None:
                                    record["top_k_token_ids"] = b_topk_ids[i][mask].tolist()
                                    record["top_k_logprobs"] = b_topk_vals[i][mask].tolist()
                                f.write(json.dumps(record) + "\n")

            if jax.process_index() == 0:
                logger.info(f"Saved logprobs to {output_file}")

    levanter.tracker.current_tracker().finish()


def run_save_logprobs_on_pod(config: SaveLogprobsOnPodConfig) -> None:
    """Submit save_logprobs as a fray job on a TPU pod and wait for completion."""
    client = current_client()

    extras = []
    if isinstance(config.resources.device, TpuConfig):
        extras.append("tpu")

    job_request = JobRequest(
        name="save_logprobs",
        entrypoint=Entrypoint.from_callable(save_logprobs, args=[config.save_logprobs_config]),
        resources=config.resources,
        environment=create_environment(extras=extras),
    )
    job = client.submit(job_request)
    job.wait(raise_on_failure=True)


def default_save_logprobs(
    checkpoint: str | InputName,
    model: LmConfig,
    data: LMMixtureDatasetConfig,
    resource_config: ResourceConfig,
    checkpoint_is_hf: bool,
    per_device_batch_size: int = 4,
    top_k: int | None = None,
    name: str | None = None,
) -> ExecutorStep:
    """Creates an ExecutorStep that saves per-token logprobs to disk."""
    if not name:
        name = ckpt_path_to_step_name(checkpoint)

    return ExecutorStep(
        name=f"analysis/logprobs/{name}",
        fn=run_save_logprobs_on_pod,
        config=SaveLogprobsOnPodConfig(
            save_logprobs_config=SaveLogprobsConfig(
                checkpoint_path=checkpoint,  # type: ignore
                checkpoint_is_hf=checkpoint_is_hf,
                model=model,
                data=data,
                trainer=TrainerConfig(
                    tracker=NoopConfig(),
                    ray=RayConfig(auto_start_cluster=False),
                    per_device_eval_parallelism=per_device_batch_size,
                    mp=jmp.get_policy("c=bf16"),
                ),
                output_path=this_output_path(),
                top_k=top_k,
            ),
            resources=resource_config,
        ),
    )
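
Each line of outputs.jsonl.gz is one document (one segment id): token_ids and losses are aligned per-token lists, and losses holds the per-token negative log-likelihoods from next_token_loss, so a document's total logprob is the negated sum. A small reader sketch (the bucket path is hypothetical):

import json

import fsspec
import numpy as np

# hypothetical location; matches os.path.join(output_path, name, "outputs.jsonl.gz") above
with fsspec.open("gs://my-bucket/logprobs/my-dataset/outputs.jsonl.gz", "rt", compression="gzip") as f:
    for line in f:
        record = json.loads(line)
        doc_logprob = -float(np.sum(record["losses"]))  # sum of per-token NLLs, negated
        print(len(record["token_ids"]), doc_logprob)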
