diff --git a/experiments/evals/run_reasoning_tts.py b/experiments/evals/run_reasoning_tts.py new file mode 100644 index 0000000000..2192002b01 --- /dev/null +++ b/experiments/evals/run_reasoning_tts.py @@ -0,0 +1,158 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import logging +from contextlib import ExitStack + +from rigging.log_setup import configure_logging + +from marin.inference.chat_completions import OpenAIChatCompletionProvider +from marin.inference.model_config import ModelConfig +from marin.inference.vllm_server import VllmEnvironment +from marin.test_time_scaling import ( + DEFAULT_REASONING_SELECTORS, + CandidateGenerationConfig, + SelectorName, + TestTimeScalingConfig, + build_run_summary, + generate_candidates, + load_prompt_manifest, + replay_selectors, + write_candidate_records, + write_prompt_manifest, + write_run_summary, + write_selection_records, +) + +logger = logging.getLogger(__name__) + + +def _parse_selector_names(raw_selectors: list[str] | None) -> tuple[SelectorName, ...]: + if not raw_selectors: + return DEFAULT_REASONING_SELECTORS + return tuple(SelectorName(raw_selector) for raw_selector in raw_selectors) + + +def _parse_engine_kwargs(raw_engine_kwargs: str) -> dict: + try: + engine_kwargs = json.loads(raw_engine_kwargs) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid JSON for --engine-kwargs-json: {exc}") from exc + if not isinstance(engine_kwargs, dict): + raise ValueError("--engine-kwargs-json must decode to a JSON object") + return engine_kwargs + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run sample-only reasoning TTS against a prompt manifest.") + parser.add_argument("--manifest", required=True, help="Prompt manifest directory or prompts.jsonl path") + parser.add_argument("--output-dir", required=True, help="Directory where TTS artifacts should be written") + parser.add_argument("--model", required=True, help="Model name or request model id") + parser.add_argument("--model-path", help="Optional checkpoint path when launching a local vLLM server") + parser.add_argument("--server-url", help="Existing OpenAI-compatible /v1 server URL") + parser.add_argument("--vllm-mode", choices=["native", "docker"], help="Mode to use when launching vLLM") + parser.add_argument( + "--engine-kwargs-json", + default="{}", + help="JSON object of engine kwargs forwarded when launching vLLM", + ) + parser.add_argument( + "--selector", + dest="selectors", + action="append", + choices=[selector.value for selector in SelectorName], + help="Selector to evaluate. Can be repeated. Defaults to the built-in set.", + ) + parser.add_argument("--num-candidates", type=int, default=4, help="Number of candidates to sample per prompt") + parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature") + parser.add_argument("--top-p", type=float, default=1.0, help="Top-p nucleus sampling threshold") + parser.add_argument("--max-gen-toks", type=int, default=2048, help="Maximum generated tokens per candidate") + parser.add_argument("--seed", type=int, help="Base request seed") + parser.add_argument("--request-timeout", type=float, default=600.0, help="HTTP timeout for each request") + parser.add_argument("--startup-timeout", type=int, default=3600, help="Timeout when launching a local vLLM server") + parser.add_argument("--api-key", default="marin-tts", help="API key for OpenAI-compatible servers") + parser.add_argument( + "--extra-vllm-arg", + action="append", + default=[], + help="Extra CLI argument passed through to `vllm serve`. Can be repeated.", + ) + return parser + + +def main() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + + configure_logging(level=logging.INFO) + + manifest = load_prompt_manifest(args.manifest) + run_config = TestTimeScalingConfig( + generation=CandidateGenerationConfig( + num_candidates=args.num_candidates, + temperature=args.temperature, + top_p=args.top_p, + max_gen_toks=args.max_gen_toks, + seed=args.seed, + ), + selectors=_parse_selector_names(args.selectors), + ) + + write_prompt_manifest(args.output_dir, manifest) + + with ExitStack() as stack: + if args.server_url: + server_url = args.server_url + model_id = args.model + else: + model = ModelConfig( + name=args.model, + path=args.model_path, + engine_kwargs=_parse_engine_kwargs(args.engine_kwargs_json), + ) + environment = stack.enter_context( + VllmEnvironment( + model, + mode=args.vllm_mode, + timeout_seconds=args.startup_timeout, + extra_args=args.extra_vllm_arg, + ) + ) + server_url = environment.server_url + model_id = environment.model_id or args.model + + provider = OpenAIChatCompletionProvider( + server_url=server_url, + model=model_id, + api_key=args.api_key, + timeout=args.request_timeout, + ) + candidates = generate_candidates(manifest, provider, run_config.generation) + + write_candidate_records(args.output_dir, candidates) + selections = replay_selectors(candidates, run_config.selectors) + write_selection_records(args.output_dir, selections) + summary = build_run_summary(manifest, run_config, candidates, selections) + write_run_summary(args.output_dir, summary) + + logger.info( + "Completed reasoning TTS run for %s with %d prompts and %d candidates", + manifest.task_name, + len(manifest.records), + len(candidates), + ) + for selector_summary in summary.selector_summaries: + logger.info( + "selector=%s accuracy=%s oracle_gap_rate=%s", + selector_summary.selector_name, + selector_summary.accuracy, + selector_summary.oracle_gap_rate, + ) + + +if __name__ == "__main__": + main() diff --git a/lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py b/lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py index e66a8c9762..85f8dbd84e 100644 --- a/lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py @@ -35,9 +35,10 @@ from rigging.filesystem import filesystem as marin_filesystem from marin.evaluation.evaluation_config import WANDB_PROJECT, EvalTaskConfig -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator from marin.inference.vllm_server import resolve_model_name_or_path from marin.evaluation.utils import is_remote_path, upload_to_gcs +from marin.inference.model_config import ModelConfig logger = logging.getLogger(__name__) diff --git a/lib/marin/src/marin/evaluation/evaluators/evaluator.py b/lib/marin/src/marin/evaluation/evaluators/evaluator.py index 9bf07cfa0c..cef215b956 100644 --- a/lib/marin/src/marin/evaluation/evaluators/evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/evaluator.py @@ -2,39 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any from marin.evaluation.evaluation_config import EvalTaskConfig - - -@dataclass -class ModelConfig: - name: str - """The name of the model e.g., allenai/olmo-7b""" - - path: str | None - """ - The path to the model checkpoint. Can be a local path or a path on GCS. - """ - - engine_kwargs: dict[str, Any] - """ - Additional keyword arguments to pass to the vLLM engine. - """ - - generation_params: dict | None = None - """ - Additional keyword arguments passed to the SamplingParams for the vLLM engine - """ - - apply_chat_template: bool = False - """ - Whether or not this model was trained with a Chat Template in the tokenizer - """ - - base_eval_run_name: str | None = None - """Custom base name for wandb runs.""" +from marin.inference.model_config import ModelConfig class Evaluator(ABC): diff --git a/lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py b/lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py index 9ec89b66f6..0809297b07 100644 --- a/lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py @@ -27,8 +27,9 @@ from rigging.filesystem import open_url from marin.evaluation.evaluation_config import EvalTaskConfig -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator from marin.evaluation.utils import download_from_gcs, is_remote_path, upload_to_gcs +from marin.inference.model_config import ModelConfig from marin.inference.vllm_server import VllmEnvironment from marin.utils import fsspec_exists, fsspec_glob diff --git a/lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py b/lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py index 8e89e9c69b..4045cbbecb 100644 --- a/lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py @@ -15,7 +15,8 @@ from levanter.trainer import TrainerConfig from marin.evaluation.evaluation_config import EvalTaskConfig, convert_to_levanter_task_config -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator +from marin.inference.model_config import ModelConfig logger = logging.getLogger(__name__) diff --git a/lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py b/lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py index 220b4a6f40..fea09d1896 100644 --- a/lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py @@ -12,8 +12,9 @@ from rigging.filesystem import open_url, url_to_fs from marin.evaluation.evaluation_config import EvalTaskConfig -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator from marin.evaluation.utils import is_remote_path, upload_to_gcs +from marin.inference.model_config import ModelConfig from marin.inference.vllm_server import VllmEnvironment logger = logging.getLogger(__name__) diff --git a/lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py b/lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py index 1cab5d8e22..95133d315c 100644 --- a/lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py +++ b/lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py @@ -7,7 +7,8 @@ from typing import ClassVar from marin.evaluation.evaluation_config import EvalTaskConfig -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator +from marin.inference.model_config import ModelConfig from marin.inference.vllm_server import resolve_model_name_or_path diff --git a/lib/marin/src/marin/evaluation/run.py b/lib/marin/src/marin/evaluation/run.py index c436feceea..0d8a79b4c9 100644 --- a/lib/marin/src/marin/evaluation/run.py +++ b/lib/marin/src/marin/evaluation/run.py @@ -17,13 +17,14 @@ import draccus from marin.evaluation.evaluation_config import EvaluationConfig -from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig +from marin.evaluation.evaluators.evaluator import Evaluator from marin.evaluation.evaluators.evalchemy_evaluator import EvalchemyEvaluator from marin.evaluation.evaluators.harbor_evaluator import HarborEvaluator from marin.evaluation.evaluators.levanter_lm_eval_evaluator import LevanterLmEvalEvaluator from marin.evaluation.evaluators.lm_evaluation_harness_evaluator import LMEvaluationHarnessEvaluator from marin.evaluation.evaluators.simple_evaluator import SimpleEvaluator from marin.evaluation.utils import discover_hf_checkpoints +from marin.inference.model_config import ModelConfig from marin.utils import fsspec_exists logger = logging.getLogger(__name__) diff --git a/lib/marin/src/marin/inference/chat_completions.py b/lib/marin/src/marin/inference/chat_completions.py new file mode 100644 index 0000000000..66597a8394 --- /dev/null +++ b/lib/marin/src/marin/inference/chat_completions.py @@ -0,0 +1,74 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Protocol + +from openai import OpenAI +from openai.types.chat import ChatCompletion + + +@dataclass(frozen=True) +class ChatCompletionRequest: + """OpenAI-compatible chat completion request parameters.""" + + messages: tuple[dict[str, str], ...] + num_completions: int + temperature: float + top_p: float = 1.0 + max_tokens: int | None = None + seed: int | None = None + logprobs: bool = False + + def __post_init__(self) -> None: + if self.num_completions <= 0: + raise ValueError("num_completions must be positive") + if self.temperature < 0: + raise ValueError("temperature must be non-negative") + if not 0 < self.top_p <= 1.0: + raise ValueError("top_p must be in the interval (0, 1]") + if self.max_tokens is not None and self.max_tokens <= 0: + raise ValueError("max_tokens must be positive when set") + + +class CompletionProvider(Protocol): + """Protocol for chat completion backends used by inference clients.""" + + def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion: + """Return an OpenAI-compatible chat completion response.""" + + +class OpenAIChatCompletionProvider: + """Minimal synchronous OpenAI-compatible completion provider.""" + + def __init__( + self, + *, + server_url: str, + model: str, + api_key: str = "marin-tts", + timeout: float | None = None, + extra_request_kwargs: dict[str, Any] | None = None, + ) -> None: + self._client = OpenAI(base_url=server_url, api_key=api_key, timeout=timeout) + self._model = model + self._extra_request_kwargs = dict(extra_request_kwargs or {}) + + def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion: + request_kwargs: dict[str, Any] = { + "model": self._model, + "messages": list(request.messages), + "n": request.num_completions, + "temperature": request.temperature, + "top_p": request.top_p, + "logprobs": request.logprobs, + **self._extra_request_kwargs, + } + if request.max_tokens is not None: + request_kwargs["max_tokens"] = request.max_tokens + if request.seed is not None: + request_kwargs["seed"] = request.seed + + return self._client.chat.completions.create(**request_kwargs) diff --git a/lib/marin/src/marin/inference/model_config.py b/lib/marin/src/marin/inference/model_config.py new file mode 100644 index 0000000000..448a2910b9 --- /dev/null +++ b/lib/marin/src/marin/inference/model_config.py @@ -0,0 +1,19 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ModelConfig: + """Configuration for launching or querying an inference model.""" + + name: str + path: str | None + engine_kwargs: dict[str, Any] + generation_params: dict | None = None + apply_chat_template: bool = False + base_eval_run_name: str | None = None diff --git a/lib/marin/src/marin/inference/vllm_server.py b/lib/marin/src/marin/inference/vllm_server.py index cf06b3d97f..b720200b0f 100644 --- a/lib/marin/src/marin/inference/vllm_server.py +++ b/lib/marin/src/marin/inference/vllm_server.py @@ -21,7 +21,7 @@ import requests from rigging.filesystem import marin_prefix -from marin.evaluation.evaluators.evaluator import ModelConfig +from marin.inference.model_config import ModelConfig logger = logging.getLogger(__name__) DEFAULT_VLLM_TPU_DOCKER_IMAGE: str = "vllm/vllm-tpu:nightly-20260104-4a1e25b-0d4044e" diff --git a/lib/marin/src/marin/inference/vllm_smoke_test.py b/lib/marin/src/marin/inference/vllm_smoke_test.py index 3a56818818..3787f7ef27 100644 --- a/lib/marin/src/marin/inference/vllm_smoke_test.py +++ b/lib/marin/src/marin/inference/vllm_smoke_test.py @@ -13,7 +13,7 @@ from fray.v2 import current_client from fray.v2.types import Entrypoint, JobRequest, ResourceConfig, create_environment -from marin.evaluation.evaluators.evaluator import ModelConfig +from marin.inference.model_config import ModelConfig from marin.inference.vllm_server import VLLM_NATIVE_PIP_PACKAGES, VllmEnvironment, resolve_vllm_mode from marin.utils import remove_tpu_lockfile_on_exit diff --git a/lib/marin/src/marin/test_time_scaling/__init__.py b/lib/marin/src/marin/test_time_scaling/__init__.py new file mode 100644 index 0000000000..60c8233b06 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/__init__.py @@ -0,0 +1,35 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from marin.test_time_scaling.analysis import build_run_summary, group_candidates_by_prompt, replay_selectors +from marin.test_time_scaling.config import ( + DEFAULT_REASONING_SELECTORS, + CandidateGenerationConfig, + ScoringMode, + SelectorName, + TestTimeScalingConfig, +) +from marin.test_time_scaling.generate import generate_candidates +from marin.test_time_scaling.manifests import ( + MANIFEST_FILENAME, + PROMPTS_FILENAME, + PromptManifest, + PromptManifestRecord, + PromptMessage, + load_prompt_manifest, + write_prompt_manifest, +) +from marin.test_time_scaling.results import ( + CANDIDATES_FILENAME, + SELECTED_FILENAME, + SUMMARY_FILENAME, + CandidateRecord, + RunSummary, + SelectionRecord, + SelectorSummary, + read_candidate_records, + read_selection_records, + write_candidate_records, + write_run_summary, + write_selection_records, +) diff --git a/lib/marin/src/marin/test_time_scaling/analysis.py b/lib/marin/src/marin/test_time_scaling/analysis.py new file mode 100644 index 0000000000..092f94f629 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/analysis.py @@ -0,0 +1,121 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import defaultdict + +from marin.test_time_scaling.budgets import budget_totals_from_candidates +from marin.test_time_scaling.config import SelectorName, TestTimeScalingConfig +from marin.test_time_scaling.manifests import PromptManifest +from marin.test_time_scaling.results import CandidateRecord, RunSummary, SelectionRecord, SelectorSummary +from marin.test_time_scaling.selectors import select_candidate + + +def group_candidates_by_prompt(candidates: list[CandidateRecord]) -> dict[str, list[CandidateRecord]]: + """Group generated candidates by prompt id.""" + + grouped_candidates: dict[str, list[CandidateRecord]] = defaultdict(list) + for candidate in candidates: + grouped_candidates[candidate.prompt_id].append(candidate) + for prompt_id in grouped_candidates: + grouped_candidates[prompt_id].sort(key=lambda candidate: candidate.sample_index) + return dict(grouped_candidates) + + +def replay_selectors( + candidates: list[CandidateRecord], + selectors: tuple[SelectorName, ...], +) -> list[SelectionRecord]: + """Replay selectors against a saved candidate pool.""" + + selections: list[SelectionRecord] = [] + for prompt_id, prompt_candidates in group_candidates_by_prompt(candidates).items(): + oracle_values = [candidate.is_correct for candidate in prompt_candidates if candidate.is_correct is not None] + oracle_correct = any(oracle_values) if oracle_values else None + for selector_name in selectors: + chosen_candidate = select_candidate(prompt_candidates, selector_name) + correctness = chosen_candidate.is_correct + oracle_gap = None + if oracle_correct is not None and correctness is not None: + oracle_gap = oracle_correct and not correctness + selections.append( + SelectionRecord( + prompt_id=prompt_id, + selector_name=selector_name, + chosen_candidate_id=chosen_candidate.candidate_id, + chosen_sample_index=chosen_candidate.sample_index, + final_selected_answer=chosen_candidate.extracted_answer or chosen_candidate.raw_text, + correctness=correctness, + oracle_correct=oracle_correct, + oracle_gap=oracle_gap, + ) + ) + return selections + + +def build_run_summary( + manifest: PromptManifest, + run_config: TestTimeScalingConfig, + candidates: list[CandidateRecord], + selections: list[SelectionRecord], +) -> RunSummary: + """Build the top-level run summary from saved candidates and selections.""" + + budget_totals = budget_totals_from_candidates(candidates) + parse_values = [candidate.parse_valid for candidate in candidates if candidate.parse_valid is not None] + parse_valid_rate = sum(parse_values) / len(parse_values) if parse_values else None + + duplicate_count = 0 + for prompt_candidates in group_candidates_by_prompt(candidates).values(): + unique_texts = {candidate.raw_text.strip() for candidate in prompt_candidates} + duplicate_count += max(0, len(prompt_candidates) - len(unique_texts)) + duplicate_rate = duplicate_count / len(candidates) if candidates else None + + oracle_by_prompt: dict[str, bool] = {} + for selection in selections: + if selection.oracle_correct is None or selection.prompt_id in oracle_by_prompt: + continue + oracle_by_prompt[selection.prompt_id] = selection.oracle_correct + oracle_values = list(oracle_by_prompt.values()) + oracle_accuracy = sum(oracle_values) / len(oracle_values) if oracle_values else None + + selector_summaries: list[SelectorSummary] = [] + for selector_name in run_config.selectors: + selector_rows = [selection for selection in selections if selection.selector_name == selector_name] + scored_values = [selection.correctness for selection in selector_rows if selection.correctness is not None] + gap_values = [selection.oracle_gap for selection in selector_rows if selection.oracle_gap is not None] + accuracy = sum(scored_values) / len(scored_values) if scored_values else None + oracle_gap_rate = sum(gap_values) / len(gap_values) if gap_values else None + selector_summaries.append( + SelectorSummary( + selector_name=selector_name, + num_prompts=len(selector_rows), + num_scored_prompts=len(scored_values), + accuracy=accuracy, + oracle_gap_rate=oracle_gap_rate, + ) + ) + + return RunSummary( + manifest_id=manifest.manifest_id, + task_name=manifest.task_name, + num_prompts=len(manifest.records), + total_candidates=len(candidates), + oracle_accuracy=oracle_accuracy, + parse_valid_rate=parse_valid_rate, + duplicate_rate=duplicate_rate, + total_prompt_tokens=budget_totals.total_prompt_tokens, + total_completion_tokens=budget_totals.total_completion_tokens, + total_request_latency_seconds=budget_totals.total_request_latency_seconds, + selector_summaries=tuple(selector_summaries), + metadata={ + "generation": { + "num_candidates": run_config.generation.num_candidates, + "temperature": run_config.generation.temperature, + "top_p": run_config.generation.top_p, + "max_gen_toks": run_config.generation.max_gen_toks, + "seed": run_config.generation.seed, + } + }, + ) diff --git a/lib/marin/src/marin/test_time_scaling/budgets.py b/lib/marin/src/marin/test_time_scaling/budgets.py new file mode 100644 index 0000000000..1dc352730b --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/budgets.py @@ -0,0 +1,44 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from marin.test_time_scaling.results import CandidateRecord + + +@dataclass(frozen=True) +class BudgetTotals: + """Aggregate budget usage derived from saved candidate records.""" + + num_prompts: int + total_candidates: int + total_prompt_tokens: int + total_completion_tokens: int + total_request_latency_seconds: float + + +def budget_totals_from_candidates(candidates: list[CandidateRecord]) -> BudgetTotals: + """Compute aggregate token and latency budgets from generated candidates.""" + + prompt_ids = {candidate.prompt_id for candidate in candidates} + total_prompt_tokens = 0 + total_completion_tokens = sum(candidate.completion_tokens or 0 for candidate in candidates) + total_request_latency_seconds = 0.0 + + seen_prompts: set[str] = set() + for candidate in candidates: + if candidate.prompt_id in seen_prompts: + continue + seen_prompts.add(candidate.prompt_id) + total_prompt_tokens += candidate.prompt_tokens or 0 + total_request_latency_seconds += candidate.request_latency_seconds + + return BudgetTotals( + num_prompts=len(prompt_ids), + total_candidates=len(candidates), + total_prompt_tokens=total_prompt_tokens, + total_completion_tokens=total_completion_tokens, + total_request_latency_seconds=total_request_latency_seconds, + ) diff --git a/lib/marin/src/marin/test_time_scaling/config.py b/lib/marin/src/marin/test_time_scaling/config.py new file mode 100644 index 0000000000..18989e2335 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/config.py @@ -0,0 +1,62 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum + + +class ScoringMode(StrEnum): + """Supported prompt scoring modes.""" + + UNSCORED = "unscored" + MATH_BOXED = "math_boxed" + + +class SelectorName(StrEnum): + """Built-in selector names for replayable sample-only experiments.""" + + FIRST_SAMPLE = "first_sample" + MAJORITY_VOTE = "majority_vote" + NORMALIZED_LOGPROB = "normalized_logprob" + + +DEFAULT_REASONING_SELECTORS: tuple[SelectorName, ...] = ( + SelectorName.FIRST_SAMPLE, + SelectorName.MAJORITY_VOTE, + SelectorName.NORMALIZED_LOGPROB, +) + + +@dataclass(frozen=True) +class CandidateGenerationConfig: + """Sampling configuration for candidate-pool generation.""" + + num_candidates: int + temperature: float + top_p: float = 1.0 + max_gen_toks: int | None = None + seed: int | None = None + + def __post_init__(self) -> None: + if self.num_candidates <= 0: + raise ValueError("num_candidates must be positive") + if self.temperature < 0: + raise ValueError("temperature must be non-negative") + if not 0 < self.top_p <= 1.0: + raise ValueError("top_p must be in the interval (0, 1]") + if self.max_gen_toks is not None and self.max_gen_toks <= 0: + raise ValueError("max_gen_toks must be positive when set") + + +@dataclass(frozen=True) +class TestTimeScalingConfig: + """Top-level sample-only run configuration for PR 1.""" + + generation: CandidateGenerationConfig + selectors: tuple[SelectorName, ...] = DEFAULT_REASONING_SELECTORS + + def __post_init__(self) -> None: + if not self.selectors: + raise ValueError("selectors must be non-empty") diff --git a/lib/marin/src/marin/test_time_scaling/generate.py b/lib/marin/src/marin/test_time_scaling/generate.py new file mode 100644 index 0000000000..5d2b7074f5 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/generate.py @@ -0,0 +1,96 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import time + +from openai.types.chat.chat_completion import Choice + +from marin.inference.chat_completions import ChatCompletionRequest, CompletionProvider +from marin.test_time_scaling.config import CandidateGenerationConfig +from marin.test_time_scaling.manifests import PromptManifest, PromptManifestRecord +from marin.test_time_scaling.results import CandidateRecord +from marin.test_time_scaling.scorers import score_candidate_text + + +def _choice_logprob_stats(choice: Choice) -> tuple[float | None, float | None, int | None]: + if not choice.logprobs or not choice.logprobs.content: + return None, None, None + + values = [token.logprob for token in choice.logprobs.content if token.logprob is not None] + if not values: + return None, None, len(choice.logprobs.content) + + logprob_sum = float(sum(values)) + completion_tokens = len(choice.logprobs.content) + normalized_logprob = logprob_sum / completion_tokens if completion_tokens else None + return logprob_sum, normalized_logprob, completion_tokens + + +def _candidate_from_choice( + *, + prompt: PromptManifestRecord, + choice: Choice, + sample_index: int, + request_latency_seconds: float, + prompt_tokens: int | None, + generation_seed: int | None, +) -> CandidateRecord: + text = choice.message.content or "" + score = score_candidate_text(text, prompt.expected_answer, prompt.scoring_mode) + logprob_sum, normalized_logprob, completion_tokens = _choice_logprob_stats(choice) + return CandidateRecord( + prompt_id=prompt.prompt_id, + candidate_id=f"{prompt.prompt_id}-{sample_index}", + sample_index=sample_index, + raw_text=text, + extracted_answer=score.extracted_answer, + is_correct=score.is_correct, + parse_valid=score.parse_valid, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason=choice.finish_reason, + request_latency_seconds=request_latency_seconds, + generation_seed=generation_seed, + logprob_sum=logprob_sum, + normalized_logprob=normalized_logprob, + ) + + +def generate_candidates( + manifest: PromptManifest, + provider: CompletionProvider, + generation_config: CandidateGenerationConfig, +) -> list[CandidateRecord]: + """Generate and score sample-only candidates for a prompt manifest.""" + + candidates: list[CandidateRecord] = [] + for prompt_index, prompt in enumerate(manifest.records): + request_seed = generation_config.seed + prompt_index if generation_config.seed is not None else None + started_at = time.perf_counter() + completion = provider.complete_messages( + ChatCompletionRequest( + messages=tuple(message.to_openai_dict() for message in prompt.messages), + num_completions=generation_config.num_candidates, + temperature=generation_config.temperature, + top_p=generation_config.top_p, + max_tokens=generation_config.max_gen_toks, + seed=request_seed, + logprobs=True, + ) + ) + request_latency_seconds = time.perf_counter() - started_at + prompt_tokens = completion.usage.prompt_tokens if completion.usage is not None else None + for sample_index, choice in enumerate(completion.choices): + candidates.append( + _candidate_from_choice( + prompt=prompt, + choice=choice, + sample_index=sample_index, + request_latency_seconds=request_latency_seconds, + prompt_tokens=prompt_tokens, + generation_seed=request_seed, + ) + ) + return candidates diff --git a/lib/marin/src/marin/test_time_scaling/manifests.py b/lib/marin/src/marin/test_time_scaling/manifests.py new file mode 100644 index 0000000000..96db061034 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/manifests.py @@ -0,0 +1,153 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any + +import fsspec +from fsspec.core import url_to_fs + +from marin.test_time_scaling.config import ScoringMode + +MANIFEST_FILENAME = "manifest.json" +PROMPTS_FILENAME = "prompts.jsonl" +MANIFEST_FORMAT_VERSION = 1 + + +def _artifact_path(base_path: str, filename: str) -> str: + if not base_path: + return filename + return f"{base_path.rstrip('/')}/{filename}" + + +def _ensure_parent_dir(path: str) -> None: + fs, fs_path = url_to_fs(path) + parent = fs_path.rsplit("/", 1)[0] if "/" in fs_path else "" + if parent: + fs.mkdirs(parent, exist_ok=True) + + +@dataclass(frozen=True) +class PromptMessage: + """Single chat message for a prompt.""" + + role: str + content: str + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> PromptMessage: + return cls(role=str(data["role"]), content=str(data["content"])) + + def to_openai_dict(self) -> dict[str, str]: + return {"role": self.role, "content": self.content} + + def to_dict(self) -> dict[str, str]: + return self.to_openai_dict() + + +@dataclass(frozen=True) +class PromptManifestRecord: + """Prompt record written to and read from `prompts.jsonl`.""" + + prompt_id: str + messages: tuple[PromptMessage, ...] + expected_answer: str | None = None + scoring_mode: ScoringMode = ScoringMode.UNSCORED + metadata: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> PromptManifestRecord: + return cls( + prompt_id=str(data["prompt_id"]), + messages=tuple(PromptMessage.from_dict(message) for message in data["messages"]), + expected_answer=data.get("expected_answer"), + scoring_mode=ScoringMode(data.get("scoring_mode", ScoringMode.UNSCORED)), + metadata=dict(data.get("metadata", {})), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "prompt_id": self.prompt_id, + "messages": [message.to_dict() for message in self.messages], + "expected_answer": self.expected_answer, + "scoring_mode": self.scoring_mode.value, + "metadata": self.metadata, + } + + +@dataclass(frozen=True) +class PromptManifest: + """Prompt manifest metadata plus ordered prompt records.""" + + manifest_id: str + task_name: str + records: tuple[PromptManifestRecord, ...] + metadata: dict[str, Any] = field(default_factory=dict) + + def header_dict(self) -> dict[str, Any]: + return { + "format_version": MANIFEST_FORMAT_VERSION, + "manifest_id": self.manifest_id, + "task_name": self.task_name, + "num_prompts": len(self.records), + "metadata": self.metadata, + } + + +def write_prompt_manifest(output_dir: str, manifest: PromptManifest) -> None: + """Write `manifest.json` and `prompts.jsonl` for a prompt manifest.""" + + manifest_path = _artifact_path(output_dir, MANIFEST_FILENAME) + prompts_path = _artifact_path(output_dir, PROMPTS_FILENAME) + _ensure_parent_dir(manifest_path) + + with fsspec.open(manifest_path, "w", encoding="utf-8") as handle: + json.dump(manifest.header_dict(), handle, indent=2, sort_keys=True) + + _ensure_parent_dir(prompts_path) + with fsspec.open(prompts_path, "w", encoding="utf-8") as handle: + for record in manifest.records: + handle.write(json.dumps(record.to_dict(), sort_keys=True) + "\n") + + +def load_prompt_manifest(path: str) -> PromptManifest: + """Load a prompt manifest from a directory, `manifest.json`, or `prompts.jsonl`.""" + + fs, fs_path = url_to_fs(path) + if fs.isdir(fs_path): + manifest_path = _artifact_path(path, MANIFEST_FILENAME) + prompts_path = _artifact_path(path, PROMPTS_FILENAME) + with fsspec.open(manifest_path, "r", encoding="utf-8") as handle: + header = json.load(handle) + elif fs_path.endswith(MANIFEST_FILENAME): + manifest_path = path + prompts_path = _artifact_path(path.rpartition("/")[0], PROMPTS_FILENAME) + with fsspec.open(manifest_path, "r", encoding="utf-8") as handle: + header = json.load(handle) + elif fs_path.endswith(".jsonl"): + prompts_path = path + header = { + "manifest_id": fs_path.rsplit("/", 1)[-1].removesuffix(".jsonl"), + "task_name": fs_path.rsplit("/", 1)[-1].removesuffix(".jsonl"), + "metadata": {}, + } + else: + raise ValueError(f"Unsupported manifest path: {path}") + + records: list[PromptManifestRecord] = [] + with fsspec.open(prompts_path, "r", encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if not stripped: + continue + records.append(PromptManifestRecord.from_dict(json.loads(stripped))) + + return PromptManifest( + manifest_id=str(header["manifest_id"]), + task_name=str(header.get("task_name", header["manifest_id"])), + records=tuple(records), + metadata=dict(header.get("metadata", {})), + ) diff --git a/lib/marin/src/marin/test_time_scaling/results.py b/lib/marin/src/marin/test_time_scaling/results.py new file mode 100644 index 0000000000..6ef741bc2f --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/results.py @@ -0,0 +1,179 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field, is_dataclass +from enum import Enum +from typing import Any + +import fsspec +from fsspec.core import url_to_fs + +from marin.test_time_scaling.config import SelectorName + +CANDIDATES_FILENAME = "candidates.jsonl" +SELECTED_FILENAME = "selected.jsonl" +SUMMARY_FILENAME = "summary.json" + + +def _artifact_path(base_path: str, filename: str) -> str: + return f"{base_path.rstrip('/')}/{filename}" + + +def _ensure_parent_dir(path: str) -> None: + fs, fs_path = url_to_fs(path) + parent = fs_path.rsplit("/", 1)[0] if "/" in fs_path else "" + if parent: + fs.mkdirs(parent, exist_ok=True) + + +def _to_jsonable(value: Any) -> Any: + if is_dataclass(value): + return {key: _to_jsonable(item) for key, item in asdict(value).items()} + if isinstance(value, Enum): + return value.value + if isinstance(value, dict): + return {key: _to_jsonable(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_to_jsonable(item) for item in value] + return value + + +@dataclass(frozen=True) +class CandidateRecord: + """Single generated candidate saved to `candidates.jsonl`.""" + + prompt_id: str + candidate_id: str + sample_index: int + raw_text: str + extracted_answer: str | None + is_correct: bool | None + parse_valid: bool | None + prompt_tokens: int | None + completion_tokens: int | None + finish_reason: str | None + request_latency_seconds: float + generation_seed: int | None + logprob_sum: float | None + normalized_logprob: float | None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CandidateRecord: + return cls( + prompt_id=str(data["prompt_id"]), + candidate_id=str(data["candidate_id"]), + sample_index=int(data["sample_index"]), + raw_text=str(data["raw_text"]), + extracted_answer=data.get("extracted_answer"), + is_correct=data.get("is_correct"), + parse_valid=data.get("parse_valid"), + prompt_tokens=data.get("prompt_tokens"), + completion_tokens=data.get("completion_tokens"), + finish_reason=data.get("finish_reason"), + request_latency_seconds=float(data["request_latency_seconds"]), + generation_seed=data.get("generation_seed"), + logprob_sum=data.get("logprob_sum"), + normalized_logprob=data.get("normalized_logprob"), + ) + + +@dataclass(frozen=True) +class SelectionRecord: + """Single selector decision saved to `selected.jsonl`.""" + + prompt_id: str + selector_name: SelectorName + chosen_candidate_id: str + chosen_sample_index: int + final_selected_answer: str + correctness: bool | None + oracle_correct: bool | None + oracle_gap: bool | None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SelectionRecord: + return cls( + prompt_id=str(data["prompt_id"]), + selector_name=SelectorName(data["selector_name"]), + chosen_candidate_id=str(data["chosen_candidate_id"]), + chosen_sample_index=int(data["chosen_sample_index"]), + final_selected_answer=str(data["final_selected_answer"]), + correctness=data.get("correctness"), + oracle_correct=data.get("oracle_correct"), + oracle_gap=data.get("oracle_gap"), + ) + + +@dataclass(frozen=True) +class SelectorSummary: + """Aggregate metrics for a single selector.""" + + selector_name: SelectorName + num_prompts: int + num_scored_prompts: int + accuracy: float | None + oracle_gap_rate: float | None + + +@dataclass(frozen=True) +class RunSummary: + """Top-level run summary written to `summary.json`.""" + + manifest_id: str + task_name: str + num_prompts: int + total_candidates: int + oracle_accuracy: float | None + parse_valid_rate: float | None + duplicate_rate: float | None + total_prompt_tokens: int + total_completion_tokens: int + total_request_latency_seconds: float + selector_summaries: tuple[SelectorSummary, ...] + metadata: dict[str, Any] = field(default_factory=dict) + + +def write_candidate_records(output_dir: str, candidates: list[CandidateRecord]) -> None: + path = _artifact_path(output_dir, CANDIDATES_FILENAME) + _ensure_parent_dir(path) + with fsspec.open(path, "w", encoding="utf-8") as handle: + for candidate in candidates: + handle.write(json.dumps(_to_jsonable(candidate), sort_keys=True) + "\n") + + +def read_candidate_records(path: str) -> list[CandidateRecord]: + candidates: list[CandidateRecord] = [] + with fsspec.open(path, "r", encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if stripped: + candidates.append(CandidateRecord.from_dict(json.loads(stripped))) + return candidates + + +def write_selection_records(output_dir: str, selections: list[SelectionRecord]) -> None: + path = _artifact_path(output_dir, SELECTED_FILENAME) + _ensure_parent_dir(path) + with fsspec.open(path, "w", encoding="utf-8") as handle: + for selection in selections: + handle.write(json.dumps(_to_jsonable(selection), sort_keys=True) + "\n") + + +def read_selection_records(path: str) -> list[SelectionRecord]: + selections: list[SelectionRecord] = [] + with fsspec.open(path, "r", encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if stripped: + selections.append(SelectionRecord.from_dict(json.loads(stripped))) + return selections + + +def write_run_summary(output_dir: str, summary: RunSummary) -> None: + path = _artifact_path(output_dir, SUMMARY_FILENAME) + _ensure_parent_dir(path) + with fsspec.open(path, "w", encoding="utf-8") as handle: + json.dump(_to_jsonable(summary), handle, indent=2, sort_keys=True) diff --git a/lib/marin/src/marin/test_time_scaling/scorers.py b/lib/marin/src/marin/test_time_scaling/scorers.py new file mode 100644 index 0000000000..ce5e2b2a6e --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/scorers.py @@ -0,0 +1,58 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +import re + +from marin.rl.environments.tinker_environments.math_grading import extract_boxed, grade_answer, normalize_answer +from marin.rl.math_utils import last_boxed_only_string +from marin.test_time_scaling.config import ScoringMode + + +@dataclass(frozen=True) +class CandidateScore: + """Structured score fields added to each generated candidate.""" + + extracted_answer: str | None + parse_valid: bool | None + is_correct: bool | None + + +_SIMPLE_FRAC_PATTERN = re.compile(r"^\\(?:dfrac|tfrac|frac)\{([^{}]+)\}\{([^{}]+)\}$") + + +def _normalize_extracted_answer(answer: str) -> str: + normalized = normalize_answer(answer) + if normalized is None: + return answer + + match = _SIMPLE_FRAC_PATTERN.fullmatch(normalized) + if match is None: + return normalized + + numerator, denominator = match.groups() + return f"{numerator}/{denominator}" + + +def score_candidate_text(text: str, expected_answer: str | None, scoring_mode: ScoringMode) -> CandidateScore: + """Score a generated candidate against the prompt's configured scoring mode.""" + + if scoring_mode == ScoringMode.UNSCORED: + return CandidateScore(extracted_answer=None, parse_valid=None, is_correct=None) + + if scoring_mode != ScoringMode.MATH_BOXED: + raise ValueError(f"Unsupported scoring mode: {scoring_mode}") + + if "\\boxed" not in text: + return CandidateScore(extracted_answer=None, parse_valid=False, is_correct=False if expected_answer else None) + + boxed = last_boxed_only_string(text) + if boxed is None: + return CandidateScore(extracted_answer=None, parse_valid=False, is_correct=False if expected_answer else None) + + extracted_answer_raw = extract_boxed(boxed) + extracted_answer = _normalize_extracted_answer(extracted_answer_raw) + is_correct = grade_answer(extracted_answer_raw, expected_answer) if expected_answer is not None else None + return CandidateScore(extracted_answer=extracted_answer, parse_valid=True, is_correct=is_correct) diff --git a/lib/marin/src/marin/test_time_scaling/selectors.py b/lib/marin/src/marin/test_time_scaling/selectors.py new file mode 100644 index 0000000000..3ce7237df8 --- /dev/null +++ b/lib/marin/src/marin/test_time_scaling/selectors.py @@ -0,0 +1,78 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import defaultdict + +from marin.test_time_scaling.config import SelectorName +from marin.test_time_scaling.results import CandidateRecord + + +def _require_candidates(candidates: list[CandidateRecord]) -> list[CandidateRecord]: + if not candidates: + raise ValueError("selector received no candidates") + return sorted(candidates, key=lambda candidate: candidate.sample_index) + + +def select_first_sample(candidates: list[CandidateRecord]) -> CandidateRecord: + """Return the earliest generated candidate.""" + + ordered_candidates = _require_candidates(candidates) + return ordered_candidates[0] + + +def select_majority_vote(candidates: list[CandidateRecord]) -> CandidateRecord: + """Select by extracted-answer majority vote with deterministic tie-breaking.""" + + ordered_candidates = _require_candidates(candidates) + answer_groups: dict[str, list[CandidateRecord]] = defaultdict(list) + for candidate in ordered_candidates: + if candidate.extracted_answer is None: + continue + answer_groups[candidate.extracted_answer].append(candidate) + + if not answer_groups: + return ordered_candidates[0] + + def group_key(group: list[CandidateRecord]) -> tuple[int, float, int]: + score_values = [candidate.normalized_logprob for candidate in group if candidate.normalized_logprob is not None] + mean_score = sum(score_values) / len(score_values) if score_values else float("-inf") + earliest_sample_index = min(candidate.sample_index for candidate in group) + return (len(group), mean_score, -earliest_sample_index) + + winning_group = max(answer_groups.values(), key=group_key) + return max( + winning_group, + key=lambda candidate: ( + candidate.normalized_logprob if candidate.normalized_logprob is not None else float("-inf"), + -candidate.sample_index, + ), + ) + + +def select_normalized_logprob(candidates: list[CandidateRecord]) -> CandidateRecord: + """Select the candidate with the best mean token logprob.""" + + ordered_candidates = _require_candidates(candidates) + candidates_with_logprobs = [ + candidate for candidate in ordered_candidates if candidate.normalized_logprob is not None + ] + if not candidates_with_logprobs: + return ordered_candidates[0] + return max( + candidates_with_logprobs, + key=lambda candidate: (candidate.normalized_logprob, -candidate.sample_index), + ) + + +def select_candidate(candidates: list[CandidateRecord], selector_name: SelectorName) -> CandidateRecord: + """Dispatch a built-in selector by name.""" + + if selector_name == SelectorName.FIRST_SAMPLE: + return select_first_sample(candidates) + if selector_name == SelectorName.MAJORITY_VOTE: + return select_majority_vote(candidates) + if selector_name == SelectorName.NORMALIZED_LOGPROB: + return select_normalized_logprob(candidates) + raise ValueError(f"Unsupported selector: {selector_name}") diff --git a/tests/evals/test_lm_eval.py b/tests/evals/test_lm_eval.py index 600350f6ba..10c479c85e 100644 --- a/tests/evals/test_lm_eval.py +++ b/tests/evals/test_lm_eval.py @@ -6,8 +6,8 @@ import pytest from fray.cluster import ResourceConfig from marin.evaluation.evaluation_config import EvaluationConfig -from marin.evaluation.evaluators.evaluator import ModelConfig from marin.evaluation.run import evaluate +from marin.inference.model_config import ModelConfig from experiments.evals.task_configs import EvalTaskConfig diff --git a/tests/test_time_scaling/conftest.py b/tests/test_time_scaling/conftest.py new file mode 100644 index 0000000000..9368c5efa4 --- /dev/null +++ b/tests/test_time_scaling/conftest.py @@ -0,0 +1,23 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import shutil +from pathlib import Path + +import pytest +from levanter.tokenizers import load_tokenizer + +_GPT2_TOKENIZER_CONFIG = ( + Path(__file__).resolve().parents[2] / "lib" / "levanter" / "tests" / "gpt2_tokenizer_config.json" +) + + +@pytest.fixture(scope="session") +def gpt2_tokenizer(tmp_path_factory): + """Local GPT-2 tokenizer fixture without network access.""" + tmpdir = tmp_path_factory.mktemp("tts_gpt2_tok") + shutil.copy(_GPT2_TOKENIZER_CONFIG, tmpdir / "tokenizer.json") + shutil.copy(_GPT2_TOKENIZER_CONFIG, tmpdir / "tokenizer_config.json") + (tmpdir / "config.json").write_text(json.dumps({"model_type": "gpt2", "vocab_size": 5027})) + return load_tokenizer(str(tmpdir)) diff --git a/tests/test_time_scaling/test_analysis.py b/tests/test_time_scaling/test_analysis.py new file mode 100644 index 0000000000..3a0a3e72ab --- /dev/null +++ b/tests/test_time_scaling/test_analysis.py @@ -0,0 +1,119 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from marin.test_time_scaling import ( + CandidateGenerationConfig, + CandidateRecord, + PromptManifest, + PromptManifestRecord, + PromptMessage, + ScoringMode, + SelectorName, + TestTimeScalingConfig as TtsRunConfig, + build_run_summary, + replay_selectors, +) + + +def _candidate( + *, + prompt_id: str, + sample_index: int, + raw_text: str, + extracted_answer: str | None, + is_correct: bool, + prompt_tokens: int, + completion_tokens: int, + request_latency_seconds: float, +) -> CandidateRecord: + return CandidateRecord( + prompt_id=prompt_id, + candidate_id=f"{prompt_id}-{sample_index}", + sample_index=sample_index, + raw_text=raw_text, + extracted_answer=extracted_answer, + is_correct=is_correct, + parse_valid=extracted_answer is not None, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason="stop", + request_latency_seconds=request_latency_seconds, + generation_seed=7, + logprob_sum=-0.5 * completion_tokens, + normalized_logprob=-0.5, + ) + + +def test_build_run_summary_dedupes_prompt_budget_per_request(): + manifest = PromptManifest( + manifest_id="math-slice", + task_name="math-demo", + records=( + PromptManifestRecord( + prompt_id="p0", + messages=(PromptMessage(role="user", content="What is 2 + 2?"),), + expected_answer="\\boxed{4}", + scoring_mode=ScoringMode.MATH_BOXED, + ), + PromptManifestRecord( + prompt_id="p1", + messages=(PromptMessage(role="user", content="What is 3 + 4?"),), + expected_answer="\\boxed{7}", + scoring_mode=ScoringMode.MATH_BOXED, + ), + ), + ) + candidates = [ + _candidate( + prompt_id="p0", + sample_index=0, + raw_text="\\boxed{4}", + extracted_answer="4", + is_correct=True, + prompt_tokens=11, + completion_tokens=5, + request_latency_seconds=1.5, + ), + _candidate( + prompt_id="p0", + sample_index=1, + raw_text="\\boxed{5}", + extracted_answer="5", + is_correct=False, + prompt_tokens=11, + completion_tokens=6, + request_latency_seconds=1.5, + ), + _candidate( + prompt_id="p1", + sample_index=0, + raw_text="\\boxed{7}", + extracted_answer="7", + is_correct=True, + prompt_tokens=13, + completion_tokens=4, + request_latency_seconds=2.0, + ), + _candidate( + prompt_id="p1", + sample_index=1, + raw_text="\\boxed{7}", + extracted_answer="7", + is_correct=True, + prompt_tokens=13, + completion_tokens=7, + request_latency_seconds=2.0, + ), + ] + run_config = TtsRunConfig( + generation=CandidateGenerationConfig(num_candidates=2, temperature=0.7, seed=5), + selectors=(SelectorName.FIRST_SAMPLE, SelectorName.MAJORITY_VOTE), + ) + + selections = replay_selectors(candidates, run_config.selectors) + summary = build_run_summary(manifest, run_config, candidates, selections) + + assert summary.total_prompt_tokens == 24 + assert summary.total_completion_tokens == 22 + assert summary.total_request_latency_seconds == 3.5 + assert summary.oracle_accuracy == 1.0 diff --git a/tests/test_time_scaling/test_manifests.py b/tests/test_time_scaling/test_manifests.py new file mode 100644 index 0000000000..406f943ec8 --- /dev/null +++ b/tests/test_time_scaling/test_manifests.py @@ -0,0 +1,64 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from marin.test_time_scaling import ( + PromptManifest, + PromptManifestRecord, + PromptMessage, + ScoringMode, + load_prompt_manifest, + write_prompt_manifest, +) + + +def test_prompt_manifest_round_trip(tmp_path): + manifest = PromptManifest( + manifest_id="demo-manifest", + task_name="demo-math", + records=( + PromptManifestRecord( + prompt_id="p0", + messages=(PromptMessage(role="user", content="What is 2 + 2? Put the answer in \\boxed{}."),), + expected_answer="\\boxed{4}", + scoring_mode=ScoringMode.MATH_BOXED, + metadata={"split": "test"}, + ), + ), + metadata={"suite": "unit"}, + ) + + output_dir = tmp_path / "manifest" + write_prompt_manifest(str(output_dir), manifest) + loaded_manifest = load_prompt_manifest(str(output_dir)) + + assert loaded_manifest.manifest_id == manifest.manifest_id + assert loaded_manifest.task_name == manifest.task_name + assert loaded_manifest.metadata == manifest.metadata + assert len(loaded_manifest.records) == 1 + assert loaded_manifest.records[0].prompt_id == "p0" + assert loaded_manifest.records[0].messages[0].role == "user" + assert loaded_manifest.records[0].expected_answer == "\\boxed{4}" + assert loaded_manifest.records[0].scoring_mode == ScoringMode.MATH_BOXED + + +def test_load_prompt_manifest_from_relative_manifest_file(tmp_path, monkeypatch): + manifest = PromptManifest( + manifest_id="demo-manifest", + task_name="demo-math", + records=( + PromptManifestRecord( + prompt_id="p0", + messages=(PromptMessage(role="user", content="What is 2 + 2? Put the answer in \\boxed{}."),), + expected_answer="\\boxed{4}", + scoring_mode=ScoringMode.MATH_BOXED, + ), + ), + ) + + write_prompt_manifest(str(tmp_path), manifest) + monkeypatch.chdir(tmp_path) + + loaded_manifest = load_prompt_manifest("manifest.json") + + assert loaded_manifest.manifest_id == manifest.manifest_id + assert loaded_manifest.records[0].expected_answer == "\\boxed{4}" diff --git a/tests/test_time_scaling/test_reasoning_tts.py b/tests/test_time_scaling/test_reasoning_tts.py new file mode 100644 index 0000000000..88bad28760 --- /dev/null +++ b/tests/test_time_scaling/test_reasoning_tts.py @@ -0,0 +1,165 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice, ChoiceLogprobs +from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob +from openai.types.completion_usage import CompletionUsage + +from marin.inference.chat_completions import ChatCompletionRequest +from marin.test_time_scaling import ( + CandidateGenerationConfig, + PromptManifest, + PromptManifestRecord, + PromptMessage, + ScoringMode, + SelectorName, + TestTimeScalingConfig as TtsRunConfig, + build_run_summary, + generate_candidates, + load_prompt_manifest, + read_candidate_records, + read_selection_records, + replay_selectors, + write_candidate_records, + write_prompt_manifest, + write_run_summary, + write_selection_records, +) + + +def _create_choice(tokenizer, response_text: str, logprob_values: list[float]) -> Choice: + token_ids = tokenizer.encode(response_text, add_special_tokens=False) + if len(logprob_values) != len(token_ids): + if len(logprob_values) == 1: + logprob_values = logprob_values * len(token_ids) + else: + logprob_values = [logprob_values[0]] * len(token_ids) + logprobs = [] + for token_id, logprob in zip(token_ids, logprob_values, strict=True): + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.append( + ChatCompletionTokenLogprob( + token=token, + logprob=logprob, + bytes=list(token.encode("utf-8")), + top_logprobs=[], + ) + ) + + return Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=response_text), + logprobs=ChoiceLogprobs(content=logprobs), + ) + + +def _create_completion(tokenizer, responses: list[tuple[str, list[float]]]) -> ChatCompletion: + choices = [] + completion_tokens = 0 + for index, (text, logprobs) in enumerate(responses): + choice = _create_choice(tokenizer, text, logprobs) + choice.index = index + choices.append(choice) + completion_tokens += len(choice.logprobs.content) if choice.logprobs and choice.logprobs.content else 0 + + return ChatCompletion( + id="chatcmpl-test", + choices=choices, + created=1234567890, + model="test-model", + object="chat.completion", + usage=CompletionUsage( + prompt_tokens=12, + completion_tokens=completion_tokens, + total_tokens=12 + completion_tokens, + ), + ) + + +class FakeCompletionProvider: + def __init__(self, completions: list[ChatCompletion]): + self._completions = completions + self._request_index = 0 + + def complete_messages(self, request: ChatCompletionRequest): + assert request.num_completions == 3 + assert request.messages[0]["role"] == "user" + completion = self._completions[self._request_index] + self._request_index += 1 + return completion + + +def test_end_to_end_reasoning_tts_math_vertical_slice(tmp_path, gpt2_tokenizer): + manifest = PromptManifest( + manifest_id="math-slice", + task_name="math-demo", + records=( + PromptManifestRecord( + prompt_id="p0", + messages=(PromptMessage(role="user", content="What is 2 + 2? Put the answer in \\boxed{}."),), + expected_answer="\\boxed{4}", + scoring_mode=ScoringMode.MATH_BOXED, + ), + PromptManifestRecord( + prompt_id="p1", + messages=(PromptMessage(role="user", content="What is 3 + 4? Put the answer in \\boxed{}."),), + expected_answer="\\boxed{7}", + scoring_mode=ScoringMode.MATH_BOXED, + ), + ), + ) + provider = FakeCompletionProvider( + [ + _create_completion( + gpt2_tokenizer, + [ + ("The answer is \\boxed{5}", [-0.05, -0.05, -0.05, -0.05, -0.05]), + ("Working carefully gives \\boxed{4}", [-0.2, -0.2, -0.2, -0.2, -0.2]), + ("Checking again gives \\boxed{4}", [-0.3, -0.3, -0.3, -0.3, -0.3]), + ], + ), + _create_completion( + gpt2_tokenizer, + [ + ("By inspection we get \\boxed{7}", [-0.4, -0.4, -0.4, -0.4, -0.4]), + ("Another derivation gives \\boxed{7}", [-0.6, -0.6, -0.6, -0.6, -0.6]), + ("One more pass confirms \\boxed{7}", [-0.5, -0.5, -0.5, -0.5, -0.5]), + ], + ), + ] + ) + run_config = TtsRunConfig( + generation=CandidateGenerationConfig(num_candidates=3, temperature=0.7, top_p=1.0, max_gen_toks=128, seed=11), + selectors=( + SelectorName.FIRST_SAMPLE, + SelectorName.MAJORITY_VOTE, + SelectorName.NORMALIZED_LOGPROB, + ), + ) + + output_dir = tmp_path / "artifacts" + write_prompt_manifest(str(output_dir), manifest) + candidates = generate_candidates(manifest, provider, run_config.generation) + write_candidate_records(str(output_dir), candidates) + saved_candidates = read_candidate_records(str(output_dir / "candidates.jsonl")) + selections = replay_selectors(saved_candidates, run_config.selectors) + write_selection_records(str(output_dir), selections) + summary = build_run_summary(manifest, run_config, saved_candidates, selections) + write_run_summary(str(output_dir), summary) + + loaded_manifest = load_prompt_manifest(str(output_dir)) + loaded_selections = read_selection_records(str(output_dir / "selected.jsonl")) + + assert loaded_manifest.manifest_id == "math-slice" + assert len(saved_candidates) == 6 + assert len(loaded_selections) == 6 + assert summary.oracle_accuracy == 1.0 + assert summary.total_candidates == 6 + + selector_summaries = {selector.selector_name: selector for selector in summary.selector_summaries} + assert selector_summaries[SelectorName.FIRST_SAMPLE].accuracy == 0.5 + assert selector_summaries[SelectorName.MAJORITY_VOTE].accuracy == 1.0 diff --git a/tests/test_time_scaling/test_scorers.py b/tests/test_time_scaling/test_scorers.py new file mode 100644 index 0000000000..d712c8f6b1 --- /dev/null +++ b/tests/test_time_scaling/test_scorers.py @@ -0,0 +1,29 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from marin.test_time_scaling.config import ScoringMode +from marin.test_time_scaling.scorers import score_candidate_text + + +def test_math_boxed_scoring_extracts_normalized_answer(): + score = score_candidate_text( + "Working carefully gives \\boxed{\\frac{1}{2}}", + "\\boxed{\\frac{1}{2}}", + ScoringMode.MATH_BOXED, + ) + + assert score.parse_valid is True + assert score.extracted_answer == "1/2" + assert score.is_correct is True + + +def test_math_boxed_scoring_marks_missing_boxed_answer_invalid(): + score = score_candidate_text( + "The answer is 4.", + "\\boxed{4}", + ScoringMode.MATH_BOXED, + ) + + assert score.parse_valid is False + assert score.extracted_answer is None + assert score.is_correct is False diff --git a/tests/test_time_scaling/test_selectors.py b/tests/test_time_scaling/test_selectors.py new file mode 100644 index 0000000000..a5fd7ce585 --- /dev/null +++ b/tests/test_time_scaling/test_selectors.py @@ -0,0 +1,79 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from marin.test_time_scaling import CandidateRecord, SelectorName, replay_selectors + + +def _candidate( + *, + prompt_id: str, + sample_index: int, + text: str, + extracted_answer: str | None, + is_correct: bool, + normalized_logprob: float | None, +) -> CandidateRecord: + return CandidateRecord( + prompt_id=prompt_id, + candidate_id=f"{prompt_id}-{sample_index}", + sample_index=sample_index, + raw_text=text, + extracted_answer=extracted_answer, + is_correct=is_correct, + parse_valid=extracted_answer is not None, + prompt_tokens=10, + completion_tokens=5, + finish_reason="stop", + request_latency_seconds=0.1, + generation_seed=42, + logprob_sum=(normalized_logprob * 5) if normalized_logprob is not None else None, + normalized_logprob=normalized_logprob, + ) + + +def test_replay_selectors_uses_same_candidate_pool(): + candidates = [ + _candidate( + prompt_id="p0", + sample_index=0, + text="\\boxed{5}", + extracted_answer="5", + is_correct=False, + normalized_logprob=-0.05, + ), + _candidate( + prompt_id="p0", + sample_index=1, + text="\\boxed{4}", + extracted_answer="4", + is_correct=True, + normalized_logprob=-0.20, + ), + _candidate( + prompt_id="p0", + sample_index=2, + text="\\boxed{4}", + extracted_answer="4", + is_correct=True, + normalized_logprob=-0.30, + ), + ] + + selections = replay_selectors( + candidates, + ( + SelectorName.FIRST_SAMPLE, + SelectorName.MAJORITY_VOTE, + SelectorName.NORMALIZED_LOGPROB, + ), + ) + by_selector = {selection.selector_name: selection for selection in selections} + + assert by_selector[SelectorName.FIRST_SAMPLE].chosen_candidate_id == "p0-0" + assert by_selector[SelectorName.FIRST_SAMPLE].oracle_gap is True + + assert by_selector[SelectorName.MAJORITY_VOTE].chosen_candidate_id in {"p0-1", "p0-2"} + assert by_selector[SelectorName.MAJORITY_VOTE].correctness is True + + assert by_selector[SelectorName.NORMALIZED_LOGPROB].chosen_candidate_id == "p0-0" + assert by_selector[SelectorName.NORMALIZED_LOGPROB].correctness is False diff --git a/tests/vllm/test_llm_inference.py b/tests/vllm/test_llm_inference.py index 7017c2f53e..20c0f5dbe1 100644 --- a/tests/vllm/test_llm_inference.py +++ b/tests/vllm/test_llm_inference.py @@ -6,7 +6,7 @@ import pytest -from marin.evaluation.evaluators.evaluator import ModelConfig +from marin.inference.model_config import ModelConfig from marin.inference.vllm_server import resolve_model_name_or_path try: