158 changes: 158 additions & 0 deletions experiments/evals/run_reasoning_tts.py
@@ -0,0 +1,158 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import argparse
import json
import logging
from contextlib import ExitStack

from rigging.log_setup import configure_logging

from marin.inference.chat_completions import OpenAIChatCompletionProvider
from marin.inference.model_config import ModelConfig
from marin.inference.vllm_server import VllmEnvironment
from marin.test_time_scaling import (
    DEFAULT_REASONING_SELECTORS,
    CandidateGenerationConfig,
    SelectorName,
    TestTimeScalingConfig,
    build_run_summary,
    generate_candidates,
    load_prompt_manifest,
    replay_selectors,
    write_candidate_records,
    write_prompt_manifest,
    write_run_summary,
    write_selection_records,
)

logger = logging.getLogger(__name__)


def _parse_selector_names(raw_selectors: list[str] | None) -> tuple[SelectorName, ...]:
    if not raw_selectors:
        return DEFAULT_REASONING_SELECTORS
    return tuple(SelectorName(raw_selector) for raw_selector in raw_selectors)


def _parse_engine_kwargs(raw_engine_kwargs: str) -> dict:
    try:
        engine_kwargs = json.loads(raw_engine_kwargs)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON for --engine-kwargs-json: {exc}") from exc
    if not isinstance(engine_kwargs, dict):
        raise ValueError("--engine-kwargs-json must decode to a JSON object")
    return engine_kwargs


def _build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run sample-only reasoning TTS against a prompt manifest.")
    parser.add_argument("--manifest", required=True, help="Prompt manifest directory or prompts.jsonl path")
    parser.add_argument("--output-dir", required=True, help="Directory where TTS artifacts should be written")
    parser.add_argument("--model", required=True, help="Model name or request model id")
    parser.add_argument("--model-path", help="Optional checkpoint path when launching a local vLLM server")
    parser.add_argument("--server-url", help="Existing OpenAI-compatible /v1 server URL")
    parser.add_argument("--vllm-mode", choices=["native", "docker"], help="Mode to use when launching vLLM")
    parser.add_argument(
        "--engine-kwargs-json",
        default="{}",
        help="JSON object of engine kwargs forwarded when launching vLLM",
    )
    parser.add_argument(
        "--selector",
        dest="selectors",
        action="append",
        choices=[selector.value for selector in SelectorName],
        help="Selector to evaluate. Can be repeated. Defaults to the built-in set.",
    )
    parser.add_argument("--num-candidates", type=int, default=4, help="Number of candidates to sample per prompt")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=1.0, help="Top-p nucleus sampling threshold")
    parser.add_argument("--max-gen-toks", type=int, default=2048, help="Maximum generated tokens per candidate")
    parser.add_argument("--seed", type=int, help="Base request seed")
    parser.add_argument("--request-timeout", type=float, default=600.0, help="HTTP timeout for each request")
    parser.add_argument("--startup-timeout", type=int, default=3600, help="Timeout when launching a local vLLM server")
    parser.add_argument("--api-key", default="marin-tts", help="API key for OpenAI-compatible servers")
    parser.add_argument(
        "--extra-vllm-arg",
        action="append",
        default=[],
        help="Extra CLI argument passed through to `vllm serve`. Can be repeated.",
    )
    return parser


def main() -> None:
    parser = _build_arg_parser()
    args = parser.parse_args()

    configure_logging(level=logging.INFO)

    manifest = load_prompt_manifest(args.manifest)
    run_config = TestTimeScalingConfig(
        generation=CandidateGenerationConfig(
            num_candidates=args.num_candidates,
            temperature=args.temperature,
            top_p=args.top_p,
            max_gen_toks=args.max_gen_toks,
            seed=args.seed,
        ),
        selectors=_parse_selector_names(args.selectors),
    )

    write_prompt_manifest(args.output_dir, manifest)

    with ExitStack() as stack:
        if args.server_url:
            server_url = args.server_url
            model_id = args.model
        else:
            model = ModelConfig(
                name=args.model,
                path=args.model_path,
                engine_kwargs=_parse_engine_kwargs(args.engine_kwargs_json),
            )
            environment = stack.enter_context(
                VllmEnvironment(
                    model,
                    mode=args.vllm_mode,
                    timeout_seconds=args.startup_timeout,
                    extra_args=args.extra_vllm_arg,
                )
            )
            server_url = environment.server_url
            model_id = environment.model_id or args.model

        provider = OpenAIChatCompletionProvider(
            server_url=server_url,
            model=model_id,
            api_key=args.api_key,
            timeout=args.request_timeout,
        )
        candidates = generate_candidates(manifest, provider, run_config.generation)

    write_candidate_records(args.output_dir, candidates)
    selections = replay_selectors(candidates, run_config.selectors)
    write_selection_records(args.output_dir, selections)
    summary = build_run_summary(manifest, run_config, candidates, selections)
    write_run_summary(args.output_dir, summary)

    logger.info(
        "Completed reasoning TTS run for %s with %d prompts and %d candidates",
        manifest.task_name,
        len(manifest.records),
        len(candidates),
    )
    for selector_summary in summary.selector_summaries:
        logger.info(
            "selector=%s accuracy=%s oracle_gap_rate=%s",
            selector_summary.selector_name,
            selector_summary.accuracy,
            selector_summary.oracle_gap_rate,
        )


if __name__ == "__main__":
    main()
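
Usage sketch for the new script (not part of the diff): the manifest path, output directory, checkpoint path, and model id below are placeholders, and --selector is omitted so the built-in default selector set applies.

python experiments/evals/run_reasoning_tts.py \
    --manifest /tmp/manifests/prompts.jsonl \
    --output-dir /tmp/tts-artifacts \
    --model allenai/olmo-7b \
    --model-path gs://my-bucket/checkpoints/olmo-7b \
    --vllm-mode docker \
    --num-candidates 8 \
    --seed 1234

Passing --server-url instead of --model-path skips the local vLLM launch and sends requests to an existing OpenAI-compatible endpoint.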
@@ -35,9 +35,10 @@
from rigging.filesystem import filesystem as marin_filesystem

from marin.evaluation.evaluation_config import WANDB_PROJECT, EvalTaskConfig
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
from marin.inference.vllm_server import resolve_model_name_or_path
from marin.evaluation.utils import is_remote_path, upload_to_gcs
+from marin.inference.model_config import ModelConfig
logger = logging.getLogger(__name__)

32 changes: 1 addition & 31 deletions lib/marin/src/marin/evaluation/evaluators/evaluator.py
@@ -2,39 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any

from marin.evaluation.evaluation_config import EvalTaskConfig
-
-
-@dataclass
-class ModelConfig:
-    name: str
-    """The name of the model e.g., allenai/olmo-7b"""
-
-    path: str | None
-    """
-    The path to the model checkpoint. Can be a local path or a path on GCS.
-    """
-
-    engine_kwargs: dict[str, Any]
-    """
-    Additional keyword arguments to pass to the vLLM engine.
-    """
-
-    generation_params: dict | None = None
-    """
-    Additional keyword arguments passed to the SamplingParams for the vLLM engine
-    """
-
-    apply_chat_template: bool = False
-    """
-    Whether or not this model was trained with a Chat Template in the tokenizer
-    """
-
-    base_eval_run_name: str | None = None
-    """Custom base name for wandb runs."""
+from marin.inference.model_config import ModelConfig


class Evaluator(ABC):
@@ -27,8 +27,9 @@
from rigging.filesystem import open_url

from marin.evaluation.evaluation_config import EvalTaskConfig
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
from marin.evaluation.utils import download_from_gcs, is_remote_path, upload_to_gcs
+from marin.inference.model_config import ModelConfig
from marin.inference.vllm_server import VllmEnvironment
from marin.utils import fsspec_exists, fsspec_glob

@@ -15,7 +15,8 @@
from levanter.trainer import TrainerConfig

from marin.evaluation.evaluation_config import EvalTaskConfig, convert_to_levanter_task_config
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
+from marin.inference.model_config import ModelConfig

logger = logging.getLogger(__name__)

@@ -12,8 +12,9 @@
from rigging.filesystem import open_url, url_to_fs

from marin.evaluation.evaluation_config import EvalTaskConfig
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
from marin.evaluation.utils import is_remote_path, upload_to_gcs
+from marin.inference.model_config import ModelConfig
from marin.inference.vllm_server import VllmEnvironment

logger = logging.getLogger(__name__)
@@ -7,7 +7,8 @@
from typing import ClassVar

from marin.evaluation.evaluation_config import EvalTaskConfig
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
+from marin.inference.model_config import ModelConfig
from marin.inference.vllm_server import resolve_model_name_or_path


3 changes: 2 additions & 1 deletion lib/marin/src/marin/evaluation/run.py
@@ -17,13 +17,14 @@
import draccus

from marin.evaluation.evaluation_config import EvaluationConfig
-from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
+from marin.evaluation.evaluators.evaluator import Evaluator
from marin.evaluation.evaluators.evalchemy_evaluator import EvalchemyEvaluator
from marin.evaluation.evaluators.harbor_evaluator import HarborEvaluator
from marin.evaluation.evaluators.levanter_lm_eval_evaluator import LevanterLmEvalEvaluator
from marin.evaluation.evaluators.lm_evaluation_harness_evaluator import LMEvaluationHarnessEvaluator
from marin.evaluation.evaluators.simple_evaluator import SimpleEvaluator
from marin.evaluation.utils import discover_hf_checkpoints
+from marin.inference.model_config import ModelConfig
from marin.utils import fsspec_exists

logger = logging.getLogger(__name__)
74 changes: 74 additions & 0 deletions lib/marin/src/marin/inference/chat_completions.py
@@ -0,0 +1,74 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol

from openai import OpenAI
from openai.types.chat import ChatCompletion


@dataclass(frozen=True)
class ChatCompletionRequest:
    """OpenAI-compatible chat completion request parameters."""

    messages: tuple[dict[str, str], ...]
    num_completions: int
    temperature: float
    top_p: float = 1.0
    max_tokens: int | None = None
    seed: int | None = None
    logprobs: bool = False

    def __post_init__(self) -> None:
        if self.num_completions <= 0:
            raise ValueError("num_completions must be positive")
        if self.temperature < 0:
            raise ValueError("temperature must be non-negative")
        if not 0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in the interval (0, 1]")
        if self.max_tokens is not None and self.max_tokens <= 0:
            raise ValueError("max_tokens must be positive when set")


class CompletionProvider(Protocol):
    """Protocol for chat completion backends used by inference clients."""

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Return an OpenAI-compatible chat completion response."""


class OpenAIChatCompletionProvider:
    """Minimal synchronous OpenAI-compatible completion provider."""

    def __init__(
        self,
        *,
        server_url: str,
        model: str,
        api_key: str = "marin-tts",
        timeout: float | None = None,
        extra_request_kwargs: dict[str, Any] | None = None,
    ) -> None:
        self._client = OpenAI(base_url=server_url, api_key=api_key, timeout=timeout)
        self._model = model
        self._extra_request_kwargs = dict(extra_request_kwargs or {})

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        request_kwargs: dict[str, Any] = {
            "model": self._model,
            "messages": list(request.messages),
            "n": request.num_completions,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "logprobs": request.logprobs,
            **self._extra_request_kwargs,
        }
        if request.max_tokens is not None:
            request_kwargs["max_tokens"] = request.max_tokens
        if request.seed is not None:
            request_kwargs["seed"] = request.seed

        return self._client.chat.completions.create(**request_kwargs)
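
A minimal usage sketch for the provider above (not part of the diff), assuming an OpenAI-compatible server is already listening at the placeholder URL; the model id and prompt are likewise illustrative.

# Hypothetical usage; the URL, model id, and prompt are placeholders.
from marin.inference.chat_completions import ChatCompletionRequest, OpenAIChatCompletionProvider

provider = OpenAIChatCompletionProvider(
    server_url="http://localhost:8000/v1",  # placeholder /v1 endpoint
    model="allenai/olmo-7b",                # placeholder model id
    timeout=600.0,
)
request = ChatCompletionRequest(
    messages=({"role": "user", "content": "What is 17 * 24?"},),
    num_completions=4,  # n candidates sampled in a single request
    temperature=0.7,
    max_tokens=256,
    seed=1234,
)
completion = provider.complete_messages(request)
for choice in completion.choices:
    print(choice.message.content)

Because ChatCompletionRequest is a frozen dataclass with __post_init__ validation, invalid sampling parameters (for example num_completions=0) fail at construction time rather than mid-request.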
19 changes: 19 additions & 0 deletions lib/marin/src/marin/inference/model_config.py
@@ -0,0 +1,19 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class ModelConfig:
    """Configuration for launching or querying an inference model."""

    name: str
    path: str | None
    engine_kwargs: dict[str, Any]
    generation_params: dict | None = None
    apply_chat_template: bool = False
    base_eval_run_name: str | None = None
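
For illustration (not from the PR), a hypothetical ModelConfig for a GCS checkpoint; the path and engine kwargs are placeholders, with max_model_len being a standard vLLM engine argument rather than a value used anywhere in this change.

from marin.inference.model_config import ModelConfig

config = ModelConfig(
    name="allenai/olmo-7b",                     # model name, e.g. a HF repo id
    path="gs://my-bucket/checkpoints/olmo-7b",  # local or GCS checkpoint path, or None
    engine_kwargs={"max_model_len": 4096},      # forwarded to the vLLM engine
)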
2 changes: 1 addition & 1 deletion lib/marin/src/marin/inference/vllm_server.py
@@ -21,7 +21,7 @@
import requests
from rigging.filesystem import marin_prefix

-from marin.evaluation.evaluators.evaluator import ModelConfig
+from marin.inference.model_config import ModelConfig

logger = logging.getLogger(__name__)
DEFAULT_VLLM_TPU_DOCKER_IMAGE: str = "vllm/vllm-tpu:nightly-20260104-4a1e25b-0d4044e"
2 changes: 1 addition & 1 deletion lib/marin/src/marin/inference/vllm_smoke_test.py
@@ -13,7 +13,7 @@
from fray.v2 import current_client
from fray.v2.types import Entrypoint, JobRequest, ResourceConfig, create_environment

-from marin.evaluation.evaluators.evaluator import ModelConfig
+from marin.inference.model_config import ModelConfig
from marin.inference.vllm_server import VLLM_NATIVE_PIP_PACKAGES, VllmEnvironment, resolve_vllm_mode
from marin.utils import remove_tpu_lockfile_on_exit
