Skip to content

Commit d1d4d2b

Browse files
committed
[tts] Refactor inference and test-time scaling ownership
1 parent c0be364 commit d1d4d2b

18 files changed

Lines changed: 134 additions & 110 deletions

experiments/evals/run_reasoning_tts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
from rigging.log_setup import configure_logging
1212

13-
from marin.evaluation.evaluators.evaluator import ModelConfig
13+
from marin.inference.chat_completions import OpenAIChatCompletionProvider
14+
from marin.inference.model_config import ModelConfig
1415
from marin.inference.vllm_server import VllmEnvironment
1516
from marin.test_time_scaling import (
1617
DEFAULT_REASONING_SELECTORS,
1718
CandidateGenerationConfig,
18-
OpenAIChatCompletionProvider,
1919
SelectorName,
2020
TestTimeScalingConfig,
2121
build_run_summary,

lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
from rigging.filesystem import filesystem as marin_filesystem
3737

3838
from marin.evaluation.evaluation_config import WANDB_PROJECT, EvalTaskConfig
39-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig, launch_evaluate_with_ray
39+
from marin.evaluation.evaluators.evaluator import Evaluator, launch_evaluate_with_ray
4040
from marin.evaluation.utils import is_remote_path, upload_to_gcs
41+
from marin.inference.model_config import ModelConfig
4142
from marin.inference.vllm_server import resolve_model_name_or_path
4243

4344
logger = logging.getLogger(__name__)

lib/marin/src/marin/evaluation/evaluators/evaluator.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from abc import ABC, abstractmethod
55
from collections.abc import Sequence
66
from dataclasses import dataclass
7-
from typing import Any
87

98
from fray.v1.cluster import Entrypoint, EnvironmentConfig, JobRequest, ResourceConfig, current_cluster
109

1110
from marin.evaluation.evaluation_config import EvalTaskConfig
11+
from marin.inference.model_config import ModelConfig
1212
from marin.utils import remove_tpu_lockfile_on_exit
1313
from rigging.log_setup import configure_logging as _init_logging
1414

@@ -27,35 +27,6 @@ def __str__(self):
2727
return f"{self.name}=={self.version}" if self.version else self.name
2828

2929

30-
@dataclass
31-
class ModelConfig:
32-
name: str
33-
"""The name of the model e.g., allenai/olmo-7b"""
34-
35-
path: str | None
36-
"""
37-
The path to the model checkpoint. Can be a local path or a path on GCS.
38-
"""
39-
40-
engine_kwargs: dict[str, Any]
41-
"""
42-
Additional keyword arguments to pass to the vLLM engine.
43-
"""
44-
45-
generation_params: dict | None = None
46-
"""
47-
Additional keyword arguments passed to the SamplingParams for the vLLM engine
48-
"""
49-
50-
apply_chat_template: bool = False
51-
"""
52-
Whether or not this model was trained with a Chat Template in the tokenizer
53-
"""
54-
55-
base_eval_run_name: str | None = None
56-
"""Custom base name for wandb runs."""
57-
58-
5930
class Evaluator(ABC):
6031
@abstractmethod
6132
def launch_evaluate_with_ray(

lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
from rigging.filesystem import open_url
2929

3030
from marin.evaluation.evaluation_config import EvalTaskConfig
31-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig, launch_evaluate_with_ray
31+
from marin.evaluation.evaluators.evaluator import Evaluator, launch_evaluate_with_ray
3232
from marin.evaluation.utils import download_from_gcs, is_remote_path, upload_to_gcs
33+
from marin.inference.model_config import ModelConfig
3334
from marin.inference.vllm_server import VLLM_NATIVE_PIP_PACKAGES, VllmEnvironment, resolve_vllm_mode
3435
from marin.utils import fsspec_exists, fsspec_glob
3536

lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from levanter.tracker.wandb import WandbConfig
1616
from levanter.trainer import TrainerConfig
1717

18+
from fray.v1.cluster.ray.deps import build_runtime_env_for_packages
1819
from marin.evaluation.evaluation_config import EvalTaskConfig, convert_to_levanter_task_config
19-
from marin.evaluation.evaluators.evaluator import ModelConfig
2020
from marin.evaluation.evaluators.levanter_tpu_evaluator import LevanterTpuEvaluator
21-
from fray.v1.cluster.ray.deps import build_runtime_env_for_packages
21+
from marin.inference.model_config import ModelConfig
2222

2323
logger = logging.getLogger(__name__)
2424

lib/marin/src/marin/evaluation/evaluators/levanter_tpu_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from fray.v1.cluster import ResourceConfig
77

88
from marin.evaluation.evaluation_config import EvalTaskConfig
9-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig, launch_evaluate_with_ray
9+
from marin.evaluation.evaluators.evaluator import Evaluator, launch_evaluate_with_ray
10+
from marin.inference.model_config import ModelConfig
1011

1112

1213
class LevanterTpuEvaluator(Evaluator, ABC):

lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from rigging.filesystem import open_url, url_to_fs
1616

1717
from marin.evaluation.evaluation_config import EvalTaskConfig
18-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig, launch_evaluate_with_ray
18+
from marin.evaluation.evaluators.evaluator import Evaluator, launch_evaluate_with_ray
1919
from marin.evaluation.utils import is_remote_path, upload_to_gcs
20+
from marin.inference.model_config import ModelConfig
2021
from marin.inference.vllm_server import VLLM_NATIVE_PIP_PACKAGES, VllmEnvironment, resolve_vllm_mode
2122

2223
logger = logging.getLogger(__name__)

lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from fray.v1.cluster import ResourceConfig
1010

1111
from marin.evaluation.evaluation_config import EvalTaskConfig
12-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig, launch_evaluate_with_ray
12+
from marin.evaluation.evaluators.evaluator import Evaluator, launch_evaluate_with_ray
13+
from marin.inference.model_config import ModelConfig
1314
from marin.inference.vllm_server import VLLM_NATIVE_PIP_PACKAGES, resolve_model_name_or_path, resolve_vllm_mode
1415

1516

lib/marin/src/marin/evaluation/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
from fray.v1.cluster import TpuConfig as V1TpuConfig
2222

2323
from marin.evaluation.evaluation_config import EvaluationConfig
24-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
24+
from marin.evaluation.evaluators.evaluator import Evaluator
2525
from marin.evaluation.evaluators.evalchemy_evaluator import EvalchemyEvaluator
2626
from marin.evaluation.evaluators.harbor_evaluator import HarborEvaluator
2727
from marin.evaluation.evaluators.levanter_lm_eval_evaluator import LevanterLmEvalEvaluator
2828
from marin.evaluation.evaluators.lm_evaluation_harness_evaluator import LMEvaluationHarnessEvaluator
2929
from marin.evaluation.evaluators.simple_evaluator import SimpleEvaluator
3030
from marin.evaluation.utils import discover_hf_checkpoints
31+
from marin.inference.model_config import ModelConfig
3132
from marin.utils import fsspec_exists
3233

3334
logger = logging.getLogger(__name__)
lib/marin/src/marin/inference/chat_completions.py (new file; path inferred from the `marin.inference.chat_completions` import above)

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
from typing import Any, Protocol
8+
9+
from openai import OpenAI
10+
from openai.types.chat import ChatCompletion
11+
12+
13+
@dataclass(frozen=True)
class ChatCompletionRequest:
    """OpenAI-compatible chat completion request parameters.

    Immutable value object; sampling parameters are validated eagerly at
    construction time so a bad request fails fast instead of at call time.
    """

    # Conversation as OpenAI-style role/content message dicts (tuple keeps it hashable/frozen).
    messages: tuple[dict[str, str], ...]
    # Number of completions to sample (the OpenAI `n` parameter).
    num_completions: int
    # Sampling temperature; must be non-negative.
    temperature: float
    top_p: float = 1.0
    max_tokens: int | None = None
    seed: int | None = None
    logprobs: bool = False

    def __post_init__(self) -> None:
        """Raise ValueError for the first invalid sampling parameter, in declaration order."""
        checks = (
            (self.num_completions <= 0, "num_completions must be positive"),
            (self.temperature < 0, "temperature must be non-negative"),
            (not 0 < self.top_p <= 1.0, "top_p must be in the interval (0, 1]"),
            (self.max_tokens is not None and self.max_tokens <= 0,
             "max_tokens must be positive when set"),
        )
        for invalid, message in checks:
            if invalid:
                raise ValueError(message)
34+
35+
36+
class CompletionProvider(Protocol):
    """Protocol for chat completion backends used by inference clients.

    Structural (duck-typed) interface: any object exposing a compatible
    ``complete_messages`` method satisfies it without subclassing.
    """

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Return an OpenAI-compatible chat completion response."""
41+
42+
43+
class OpenAIChatCompletionProvider:
    """Minimal synchronous OpenAI-compatible completion provider.

    Wraps an ``openai.OpenAI`` client pointed at ``server_url`` and
    translates :class:`ChatCompletionRequest` objects into
    ``chat.completions.create`` calls.
    """

    def __init__(
        self,
        *,
        server_url: str,
        model: str,
        api_key: str = "marin-tts",
        timeout: float | None = None,
        extra_request_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Create a provider bound to one server and one model.

        ``extra_request_kwargs`` are merged into every request; a defensive
        copy is taken so later caller mutations have no effect.
        """
        self._model = model
        self._extra_request_kwargs = dict(extra_request_kwargs or {})
        self._client = OpenAI(base_url=server_url, api_key=api_key, timeout=timeout)

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Send *request* to the server and return the raw ChatCompletion response."""
        payload: dict[str, Any] = {
            "model": self._model,
            "messages": list(request.messages),
            "n": request.num_completions,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "logprobs": request.logprobs,
        }
        # Provider-level extras may override any of the defaults above...
        payload.update(self._extra_request_kwargs)
        # ...but explicit per-request optional fields take final precedence
        # and are omitted entirely when unset.
        for key, value in (("max_tokens", request.max_tokens), ("seed", request.seed)):
            if value is not None:
                payload[key] = value
        return self._client.chat.completions.create(**payload)

0 commit comments

Comments
 (0)