Skip to content

Commit 89aac6f

Browse files
committed
[tts] Refactor inference and test-time scaling ownership
1 parent eb148e2 commit 89aac6f

17 files changed

Lines changed: 132 additions & 109 deletions

experiments/evals/run_reasoning_tts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
from rigging.log_setup import configure_logging
1212

13-
from marin.evaluation.evaluators.evaluator import ModelConfig
13+
from marin.inference.chat_completions import OpenAIChatCompletionProvider
14+
from marin.inference.model_config import ModelConfig
1415
from marin.inference.vllm_server import VllmEnvironment
1516
from marin.test_time_scaling import (
1617
DEFAULT_REASONING_SELECTORS,
1718
CandidateGenerationConfig,
18-
OpenAIChatCompletionProvider,
1919
SelectorName,
2020
TestTimeScalingConfig,
2121
build_run_summary,

lib/marin/src/marin/evaluation/evaluators/evalchemy_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
from rigging.filesystem import filesystem as marin_filesystem
3636

3737
from marin.evaluation.evaluation_config import WANDB_PROJECT, EvalTaskConfig
38-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
38+
from marin.evaluation.evaluators.evaluator import Evaluator
3939
from marin.inference.vllm_server import resolve_model_name_or_path
4040
from marin.evaluation.utils import is_remote_path, upload_to_gcs
41+
from marin.inference.model_config import ModelConfig
4142

4243
logger = logging.getLogger(__name__)
4344

lib/marin/src/marin/evaluation/evaluators/evaluator.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from abc import ABC, abstractmethod
5-
from dataclasses import dataclass
6-
from typing import Any
75

86
from marin.evaluation.evaluation_config import EvalTaskConfig
9-
10-
11-
@dataclass
12-
class ModelConfig:
13-
name: str
14-
"""The name of the model e.g., allenai/olmo-7b"""
15-
16-
path: str | None
17-
"""
18-
The path to the model checkpoint. Can be a local path or a path on GCS.
19-
"""
20-
21-
engine_kwargs: dict[str, Any]
22-
"""
23-
Additional keyword arguments to pass to the vLLM engine.
24-
"""
25-
26-
generation_params: dict | None = None
27-
"""
28-
Additional keyword arguments passed to the SamplingParams for the vLLM engine
29-
"""
30-
31-
apply_chat_template: bool = False
32-
"""
33-
Whether or not this model was trained with a Chat Template in the tokenizer
34-
"""
35-
36-
base_eval_run_name: str | None = None
37-
"""Custom base name for wandb runs."""
7+
from marin.inference.model_config import ModelConfig
388

399

4010
class Evaluator(ABC):

lib/marin/src/marin/evaluation/evaluators/harbor_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727
from rigging.filesystem import open_url
2828

2929
from marin.evaluation.evaluation_config import EvalTaskConfig
30-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
30+
from marin.evaluation.evaluators.evaluator import Evaluator
3131
from marin.evaluation.utils import download_from_gcs, is_remote_path, upload_to_gcs
32+
from marin.inference.model_config import ModelConfig
3233
from marin.inference.vllm_server import VllmEnvironment
3334
from marin.utils import fsspec_exists, fsspec_glob
3435

lib/marin/src/marin/evaluation/evaluators/levanter_lm_eval_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from levanter.trainer import TrainerConfig
1616

1717
from marin.evaluation.evaluation_config import EvalTaskConfig, convert_to_levanter_task_config
18-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
18+
from marin.evaluation.evaluators.evaluator import Evaluator
19+
from marin.inference.model_config import ModelConfig
1920

2021
logger = logging.getLogger(__name__)
2122

lib/marin/src/marin/evaluation/evaluators/lm_evaluation_harness_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from rigging.filesystem import open_url, url_to_fs
1313

1414
from marin.evaluation.evaluation_config import EvalTaskConfig
15-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
15+
from marin.evaluation.evaluators.evaluator import Evaluator
1616
from marin.evaluation.utils import is_remote_path, upload_to_gcs
17+
from marin.inference.model_config import ModelConfig
1718
from marin.inference.vllm_server import VllmEnvironment
1819

1920
logger = logging.getLogger(__name__)

lib/marin/src/marin/evaluation/evaluators/simple_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from typing import ClassVar
88

99
from marin.evaluation.evaluation_config import EvalTaskConfig
10-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
10+
from marin.evaluation.evaluators.evaluator import Evaluator
11+
from marin.inference.model_config import ModelConfig
1112
from marin.inference.vllm_server import resolve_model_name_or_path
1213

1314

lib/marin/src/marin/evaluation/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
import draccus
1818

1919
from marin.evaluation.evaluation_config import EvaluationConfig
20-
from marin.evaluation.evaluators.evaluator import Evaluator, ModelConfig
20+
from marin.evaluation.evaluators.evaluator import Evaluator
2121
from marin.evaluation.evaluators.evalchemy_evaluator import EvalchemyEvaluator
2222
from marin.evaluation.evaluators.harbor_evaluator import HarborEvaluator
2323
from marin.evaluation.evaluators.levanter_lm_eval_evaluator import LevanterLmEvalEvaluator
2424
from marin.evaluation.evaluators.lm_evaluation_harness_evaluator import LMEvaluationHarnessEvaluator
2525
from marin.evaluation.evaluators.simple_evaluator import SimpleEvaluator
2626
from marin.evaluation.utils import discover_hf_checkpoints
27+
from marin.inference.model_config import ModelConfig
2728
from marin.utils import fsspec_exists
2829

2930
logger = logging.getLogger(__name__)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
from typing import Any, Protocol
8+
9+
from openai import OpenAI
10+
from openai.types.chat import ChatCompletion
11+
12+
13+
@dataclass(frozen=True)
14+
class ChatCompletionRequest:
15+
"""OpenAI-compatible chat completion request parameters."""
16+
17+
messages: tuple[dict[str, str], ...]
18+
num_completions: int
19+
temperature: float
20+
top_p: float = 1.0
21+
max_tokens: int | None = None
22+
seed: int | None = None
23+
logprobs: bool = False
24+
25+
def __post_init__(self) -> None:
26+
if self.num_completions <= 0:
27+
raise ValueError("num_completions must be positive")
28+
if self.temperature < 0:
29+
raise ValueError("temperature must be non-negative")
30+
if not 0 < self.top_p <= 1.0:
31+
raise ValueError("top_p must be in the interval (0, 1]")
32+
if self.max_tokens is not None and self.max_tokens <= 0:
33+
raise ValueError("max_tokens must be positive when set")
34+
35+
36+
class CompletionProvider(Protocol):
    """Structural protocol for chat completion backends used by inference clients.

    Any object exposing a compatible ``complete_messages`` method satisfies
    this protocol; no inheritance is required.
    """

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Return an OpenAI-compatible chat completion response.

        Args:
            request: Fully validated request parameters to send to the backend.

        Returns:
            The backend's response as an OpenAI ``ChatCompletion`` object.
        """
41+
42+
43+
class OpenAIChatCompletionProvider:
    """Minimal synchronous OpenAI-compatible completion provider.

    Wraps an ``openai.OpenAI`` client pointed at an arbitrary
    OpenAI-compatible server (e.g. a vLLM endpoint).
    """

    def __init__(
        self,
        *,
        server_url: str,
        model: str,
        api_key: str = "marin-tts",
        timeout: float | None = None,
        extra_request_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Create a client bound to *server_url* that always queries *model*.

        Args:
            server_url: Base URL of the OpenAI-compatible server.
            model: Model identifier sent with every request.
            api_key: API key passed to the client; local servers typically
                accept any value.
            timeout: Per-request timeout forwarded to the OpenAI client.
            extra_request_kwargs: Additional keyword arguments merged into
                every chat-completion request.
        """
        self._model = model
        # Copy defensively so later mutation by the caller cannot leak
        # into outgoing requests.
        if extra_request_kwargs is None:
            self._extra_request_kwargs = {}
        else:
            self._extra_request_kwargs = dict(extra_request_kwargs)
        self._client = OpenAI(base_url=server_url, api_key=api_key, timeout=timeout)

    def complete_messages(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Send *request* to the server and return the raw completion response."""
        payload: dict[str, Any] = {
            "model": self._model,
            "messages": list(request.messages),
            "n": request.num_completions,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "logprobs": request.logprobs,
        }
        # Extra kwargs may override any of the base fields above...
        payload.update(self._extra_request_kwargs)
        # ...but an explicit max_tokens/seed on the request always wins.
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.seed is not None:
            payload["seed"] = request.seed

        return self._client.chat.completions.create(**payload)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
from typing import Any
8+
9+
10+
@dataclass
class ModelConfig:
    """Configuration for launching or querying an inference model."""

    # The name of the model, e.g. "allenai/olmo-7b".
    name: str
    # Path to the model checkpoint. Can be a local path or a path on GCS.
    path: str | None
    # Additional keyword arguments to pass to the vLLM engine.
    engine_kwargs: dict[str, Any]
    # Additional keyword arguments passed to the SamplingParams for the
    # vLLM engine.
    generation_params: dict | None = None
    # Whether or not this model was trained with a chat template in the
    # tokenizer.
    apply_chat_template: bool = False
    # Custom base name for wandb runs.
    base_eval_run_name: str | None = None

0 commit comments

Comments
 (0)