|
# Copyright 2025 The Marin Authors
# SPDX-License-Identifier: Apache-2.0
| 6 | + |
| 7 | +""" |
| 8 | +Save per-token log probabilities for a language model on a dataset. |
| 9 | +
|
| 10 | +This module computes per-token logprobs using Levanter on TPU and saves them |
| 11 | +to gzipped JSONL files. Optionally saves top-k logprobs at each position. |
| 12 | +""" |
| 13 | + |
| 14 | +import json |
| 15 | +import logging |
| 16 | +import os |
| 17 | +from contextlib import nullcontext |
| 18 | +from dataclasses import dataclass, field, replace |
| 19 | + |
| 20 | +import equinox as eqx |
| 21 | +import fsspec |
| 22 | +import jax |
| 23 | + |
| 24 | +import jmp |
| 25 | +import numpy as np |
| 26 | +from jax.experimental import multihost_utils |
| 27 | + |
| 28 | +import haliax as hax |
| 29 | +from haliax import Axis |
| 30 | +from haliax.partitioning import round_axis_for_partitioning |
| 31 | + |
| 32 | +import levanter |
| 33 | +from levanter.checkpoint import load_checkpoint |
| 34 | +from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef |
| 35 | +from levanter.data import DataLoader |
| 36 | +from levanter.data.text import DatasetComponent, LmDataConfig, LMMixtureDatasetConfig |
| 37 | +from levanter.distributed import RayConfig |
| 38 | +from levanter.models.llama import LlamaConfig |
| 39 | +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel |
| 40 | +from levanter.models.loss import next_token_loss |
| 41 | +from levanter.tracker import NoopConfig |
| 42 | +from levanter.trainer import TrainerConfig |
| 43 | +from levanter.utils.jax_utils import use_cpu_device |
| 44 | +from levanter.utils.tree_utils import inference_mode |
| 45 | + |
| 46 | +from fray.v2 import current_client |
| 47 | +from fray.v2.types import Entrypoint, JobRequest, ResourceConfig, TpuConfig, create_environment |
| 48 | + |
| 49 | +from marin.execution.executor import ExecutorStep, InputName, this_output_path |
| 50 | +from marin.utilities.executor_utils import ckpt_path_to_step_name |
| 51 | + |
| 52 | +logger = logging.getLogger(__name__) |
| 53 | + |
| 54 | + |
@dataclass
class SaveLogprobsConfig:
    """Configuration for saving per-token logprobs. Also serves as the Levanter init config."""

    # Levanter trainer config: supplies the device mesh, batch sizes, axis mappings,
    # and the mixed-precision policy (compute in bf16 by default).
    trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(mp=jmp.get_policy("c=bf16")))
    # Dataset config; its validation sets are the corpora whose logprobs are saved.
    data: LmDataConfig = field(default_factory=LmDataConfig)
    # Model architecture config (defaults to Llama).
    model: LmConfig = field(default_factory=LlamaConfig)
    # Path to the model weights. Interpreted as a Levanter checkpoint unless
    # checkpoint_is_hf is True, in which case it is parsed as an HF repo ref.
    checkpoint_path: str | None = None
    checkpoint_is_hf: bool = False
    # Sequences are evaluated at this length (the model's Pos axis is resized to it).
    max_eval_length: int = 4096
    # Output root; per-dataset files land at <output_path>/<name>/outputs.jsonl.gz.
    output_path: str = ""
    # If set, also record the top-k token ids and their logprobs at each position.
    top_k: int | None = None
| 67 | + |
| 68 | + |
@dataclass(frozen=True)
class SaveLogprobsOnPodConfig:
    """Wrapper config for running save_logprobs on a TPU pod via fray."""

    # Inner job config forwarded to save_logprobs on the pod.
    save_logprobs_config: SaveLogprobsConfig
    # fray resource request (device type/count) used to schedule the job.
    resources: ResourceConfig
| 75 | + |
| 76 | + |
def _force_pack_data(data: LmDataConfig) -> LmDataConfig:
    """Return a copy of ``data`` with every DatasetComponent forced to ``pack=True``.

    The copy also has ``block_cross_document_attention`` enabled. Non-component
    entries (anything that is not a ``DatasetComponent``) are passed through
    unchanged.
    """
    forced_components = {}
    for comp_name, comp in data.components.items():
        if isinstance(comp, DatasetComponent):
            forced_components[comp_name] = replace(comp, pack=True)
        else:
            forced_components[comp_name] = comp
    return replace(data, components=forced_components, block_cross_document_attention=True)
| 84 | + |
| 85 | + |
def save_logprobs(config: SaveLogprobsConfig) -> None:
    """Compute and save per-token logprobs.

    For every validation set in ``config.data``, runs the model forward over the
    packed dataset and writes one gzipped-JSONL record per packed segment to
    ``<output_path>/<dataset name>/outputs.jsonl.gz``. Each record holds the
    segment's token ids and per-token losses; when ``config.top_k`` is set, the
    top-k candidate token ids and their logprobs are included as well.

    Only JAX process 0 opens files and writes output; all processes participate
    in the forward passes and the all-gathers.
    """
    levanter.initialize(config)
    tokenizer = config.data.the_tokenizer

    # When the checkpoint is a HuggingFace one, interpret the path as a repo ref.
    hf_checkpoint = RepoRef.from_string(config.checkpoint_path) if config.checkpoint_is_hf else None

    EvalBatch = config.trainer.EvalBatch
    # Evaluate at a fixed (possibly shorter) sequence length, not the model's max.
    Pos = config.model.max_Pos.resize(config.max_eval_length)

    # Force sequence packing so each batch row may hold several documents,
    # which are separated again below via segment ids.
    packed_data = _force_pack_data(config.data)
    validation_sets = packed_data.validation_sets(Pos)

    compute_axis_mapping = config.trainer.compute_axis_mapping
    parameter_axis_mapping = config.trainer.parameter_axis_mapping

    with config.trainer.use_device_mesh(), hax.axis_mapping(parameter_axis_mapping):
        key = jax.random.PRNGKey(0)

        # Round the vocab axis up so it partitions evenly across the mesh.
        vocab_size = len(tokenizer)
        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), compute_axis_mapping)
        if vocab_size != Vocab.size:
            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

        mp: jmp.Policy = config.trainer.mp

        @hax.named_jit
        def compute_forward(model: LmHeadModel, example: LmExample):
            """Forward pass: returns per-token losses and per-token logprobs over the vocab."""
            model = inference_mode(model, True)
            model = mp.cast_to_compute(model)
            activations = model.activations(example.tokens, example.attn_mask, key=key)
            logits = hax.dot(activations, model.get_lm_head(), axis=model.Embed)
            loss = next_token_loss(
                model.Pos,
                model.Vocab,
                logits=logits,
                true_ids=example.tokens,
                loss_weight=example.loss_weight,
                reduction=None,  # keep per-position values; no mean/sum
            )
            logprobs = hax.nn.log_softmax(logits, axis=model.Vocab)

            return loss.rearrange((EvalBatch, Pos)), logprobs.rearrange((EvalBatch, Pos, model.Vocab))

        # NOTE: closes over `model`, which is assigned further down; this is only
        # called after the model has been loaded.
        @hax.named_jit
        def compute_top(logprobs: hax.NamedArray, k: int):
            """Select the k highest logprobs (and their token ids) at each position."""
            top_k_values, top_k_indices = hax.top_k(logprobs, model.Vocab, k=k, new_axis="top_k")
            TopK = top_k_values.resolve_axis("top_k")
            return top_k_values.rearrange((EvalBatch, Pos, TopK)), top_k_indices.rearrange((EvalBatch, Pos, TopK))

        # Load model: either a native Levanter checkpoint or a HuggingFace one.
        if config.checkpoint_path is not None and not config.checkpoint_is_hf:
            with use_cpu_device():
                # Build an abstract (shape-only) model, then fill it from the checkpoint.
                model = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
                model = load_checkpoint(model, config.checkpoint_path, subpath="model")
            model = hax.shard_with_axis_mapping(model, parameter_axis_mapping)
        elif hf_checkpoint is not None:
            model_config = config.model
            if not hasattr(model_config, "hf_checkpoint_converter"):
                raise ValueError("Model config does not have an HF checkpoint converter. Can't load HF checkpoint.")
            converter: HFCheckpointConverter = model_config.hf_checkpoint_converter()
            converter = converter.replaced(reference_checkpoint=hf_checkpoint, tokenizer=tokenizer)
            model = converter.load_pretrained(model_config.model_type, ref=hf_checkpoint, dtype=mp.compute_dtype)
        else:
            # Unreachable: one of the two branches above must apply given the config.
            raise AssertionError("Should not get here")

        for name, dataset in validation_sets.items():
            loader = DataLoader(
                dataset,
                config.trainer.eval_batch_size,
                mesh=config.trainer.device_mesh,
                axis_resources=compute_axis_mapping,
            )

            output_file = os.path.join(config.output_path, name, "outputs.jsonl.gz")
            # Only process 0 writes; other processes get a no-op context (f is None there).
            cm = fsspec.open(output_file, "wt", compression="gzip") if jax.process_index() == 0 else nullcontext()
            with cm as f:
                for batch in loader:
                    with hax.axis_mapping(compute_axis_mapping):
                        out = compute_forward(model, batch)
                        b_loss, b_logprobs = out

                        if config.top_k is not None:
                            b_topk_vals, b_topk_ids = compute_top(b_logprobs, config.top_k)
                            # Gather the sharded top-k results onto every host.
                            b_topk_vals, b_topk_ids = multihost_utils.process_allgather(
                                (b_topk_vals, b_topk_ids), tiled=True
                            )

                        b_tokens, b_seg_ids = batch.tokens.rearrange((EvalBatch, Pos)), batch.attn_mask.segment_ids[
                            0
                        ].rearrange((EvalBatch, Pos))
                        b_loss, b_tokens, b_seg_ids = multihost_utils.process_allgather(
                            (b_loss, b_tokens, b_seg_ids), tiled=True
                        )

                    if jax.process_index() == 0:
                        # Pull the gathered device arrays to host numpy for serialization.
                        b_loss = np.array(b_loss.array)
                        b_tokens = np.array(b_tokens.array)
                        b_seg_ids = np.array(b_seg_ids.array)

                        if config.top_k is not None:
                            b_topk_ids = np.array(b_topk_ids.array)
                            b_topk_vals = np.array(b_topk_vals.array)

                        for i in range(len(b_tokens)):
                            # Skip all-zero rows (presumably fully-padded batch rows — confirm
                            # against the loader's padding convention).
                            if np.all(b_tokens[i] == 0):
                                continue

                            unique_ids = np.unique(b_seg_ids[i])
                            unique_ids = unique_ids[unique_ids >= 0]  # exclude padding (-1)

                            # Emit one JSONL record per packed document segment.
                            for seg_id in unique_ids:
                                mask = b_seg_ids[i] == seg_id
                                record = {
                                    "token_ids": b_tokens[i][mask].tolist(),
                                    "losses": b_loss[i][mask].tolist(),
                                }
                                if config.top_k is not None:
                                    record["top_k_token_ids"] = b_topk_ids[i][mask].tolist()
                                    record["top_k_logprobs"] = b_topk_vals[i][mask].tolist()
                                f.write(json.dumps(record) + "\n")

            if jax.process_index() == 0:
                logger.info(f"Saved logprobs to {output_file}")

    levanter.tracker.current_tracker().finish()
| 213 | + |
| 214 | + |
def run_save_logprobs_on_pod(config: SaveLogprobsOnPodConfig) -> None:
    """Submit save_logprobs as a fray job on a TPU pod and wait for completion."""
    # Request the "tpu" environment extra only when the job actually targets a TPU.
    env_extras = ["tpu"] if isinstance(config.resources.device, TpuConfig) else []

    request = JobRequest(
        name="save_logprobs",
        entrypoint=Entrypoint.from_callable(save_logprobs, args=[config.save_logprobs_config]),
        resources=config.resources,
        environment=create_environment(extras=env_extras),
    )
    # Block until the remote job finishes; failures propagate to the caller.
    current_client().submit(request).wait(raise_on_failure=True)
| 231 | + |
| 232 | + |
def default_save_logprobs(
    checkpoint: str | InputName,
    model: LmConfig,
    data: LMMixtureDatasetConfig,
    resource_config: ResourceConfig,
    checkpoint_is_hf: bool,
    per_device_batch_size: int = 4,
    top_k: int | None = None,
    name: str | None = None,
) -> ExecutorStep:
    """Creates an ExecutorStep that saves per-token logprobs to disk."""
    # Fall back to a step name derived from the checkpoint path.
    step_name = name or ckpt_path_to_step_name(checkpoint)

    trainer_config = TrainerConfig(
        tracker=NoopConfig(),
        ray=RayConfig(auto_start_cluster=False),
        per_device_eval_parallelism=per_device_batch_size,
        mp=jmp.get_policy("c=bf16"),
    )
    inner_config = SaveLogprobsConfig(
        checkpoint_path=checkpoint,  # type: ignore
        checkpoint_is_hf=checkpoint_is_hf,
        model=model,
        data=data,
        trainer=trainer_config,
        output_path=this_output_path(),
        top_k=top_k,
    )
    return ExecutorStep(
        name=f"analysis/logprobs/{step_name}",
        fn=run_save_logprobs_on_pod,
        config=SaveLogprobsOnPodConfig(save_logprobs_config=inner_config, resources=resource_config),
    )