Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/marin/src/marin/rl/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class MarinEnv(ABC):
Subclasses must implement sample() method.
"""

def prepare(self) -> None:
    """Hook for one-time, worker-local setup before sampling.

    The base implementation is a deliberate no-op; environments that need
    per-worker initialization override this.
    """
    return None

@abstractmethod
def sample(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from .base import BaseInferenceContext
from .base import BaseInferenceContext, PromptLike
from .levanter import LevanterInferenceContext, LevanterInferenceContextConfig
from .vllm import (
vLLMInferenceContext,
Expand All @@ -19,6 +19,7 @@
"BaseInferenceContext",
"LevanterInferenceContext",
"LevanterInferenceContextConfig",
"PromptLike",
"VLLMSamplingConfig",
"vLLMInferenceContext",
"vLLMInferenceContextConfig",
Expand Down
39 changes: 34 additions & 5 deletions lib/marin/src/marin/rl/environments/inference_ctx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,26 @@
from marin.rl.types import Rollout

from levanter.models.lm_model import LmHeadModel
from marin.rl.environments.inference_ctx.openai_compat import OpenAICompatClient

logger = logging.getLogger(__name__)

PromptMessage = dict[str, object]
PromptLike = str | list[PromptMessage]


def prompt_to_messages(prompt: PromptLike, system_prompt: str | None = None) -> list[PromptMessage]:
"""Normalize plain-string or chat-message prompts to a message list."""
if isinstance(prompt, str):
messages: list[PromptMessage] = [{"role": "user", "content": prompt}]
else:
messages = [dict(message) for message in prompt]

if system_prompt is None:
return messages

return [{"role": "system", "content": system_prompt}, *messages]


class BaseInferenceContext:
"""Base class for inference contexts."""
Expand All @@ -34,9 +51,13 @@ def get_metrics(self) -> dict[str, Any]:
"""Return implementation-specific metrics for tracker logging."""
return {}

def openai_client(self) -> OpenAICompatClient:
    """Build an AsyncOpenAI-compatible adapter bound to this context.

    Verifier environments that expect an AsyncOpenAI client can use the
    returned object in its place; each call constructs a fresh adapter.
    """
    adapter = OpenAICompatClient(self)
    return adapter

def batch_completions(
self,
prompts: list[str] | list[list[dict]],
prompts: list[PromptLike],
temperature: float,
n: int,
max_tokens: int | None = None,
Expand All @@ -47,17 +68,23 @@ def batch_completions(
"""Batch completions from the inference server."""
raise NotImplementedError

def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prompt: str | None = None) -> np.ndarray:
def tokenize_prompt(
self,
prompt: PromptLike,
choice: Choice | None = None,
system_prompt: str | None = None,
) -> np.ndarray:
"""Tokenize with chat template matching server behavior."""
messages = [{"role": "user", "content": prompt}]
messages = prompt_to_messages(prompt, system_prompt)
try:
tokens = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
except Exception as e:
logger.warning(f"Chat template failed: {e}")
prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
if not tokens:
raise ValueError(f"Failed to tokenize: {prompt[:100]}...") from None
prompt_preview = prompt[:100] if isinstance(prompt, str) else repr(messages)[:100]
raise ValueError(f"Failed to tokenize: {prompt_preview}...") from None

return np.array(tokens, dtype=np.int32)

Expand Down Expand Up @@ -93,7 +120,7 @@ def logprobs_from_choice(self, choice: Choice) -> np.ndarray:

def create_rollout_from_choice(
self,
prompt: str,
prompt: PromptLike,
choice: Choice,
env_name: str,
env_example_id: str,
Expand All @@ -118,6 +145,7 @@ def create_rollout_from_choice(
logger.error(f"Prompt tokenization failed for {env_example_id}")

token_rewards = np.full(len(response_tokens), reward, dtype=np.float32)
response_loss_mask = np.ones(len(response_tokens), dtype=np.float32)
is_truncated = choice.finish_reason == "length"

return Rollout(
Expand All @@ -126,6 +154,7 @@ def create_rollout_from_choice(
prompt_tokens=prompt_tokens,
response_tokens=response_tokens,
response_logprobs=response_logprobs,
response_loss_mask=response_loss_mask,
token_rewards=token_rewards,
episode_reward=float(reward),
correctness_reward=correctness_reward,
Expand Down
16 changes: 11 additions & 5 deletions lib/marin/src/marin/rl/environments/inference_ctx/levanter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from levanter.tokenizers import MarinTokenizer
from levanter.models.lm_model import LmHeadModel
import haliax as hax
from marin.rl.environments.inference_ctx.base import BaseInferenceContext
from marin.rl.environments.inference_ctx.base import BaseInferenceContext, PromptLike

# TODO(chris): use a different weight transfer method update model, take it out from here
from marin.rl.weight_transfer.arrow_flight import update_model
Expand Down Expand Up @@ -85,10 +85,9 @@ def start_server(self, model: LmHeadModel) -> None:
def shutdown(self) -> None:
self._inference_server.shutdown()

# TODO: add support for ChatCompletion style [ { role, content} ] messages
def batch_completions(
self,
prompts: list[str] | list[list[dict]],
prompts: list[PromptLike],
temperature: float,
n: int,
max_tokens: int | None = None,
Expand All @@ -108,10 +107,17 @@ def batch_completions(
asyncio.set_event_loop(loop)
client = self.openai_client()

async def create_completion(prompt: str) -> ChatCompletion:
async def create_completion(prompt: PromptLike) -> ChatCompletion:
if isinstance(prompt, list):
messages = prompt
elif system_prompt is not None:
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
else:
messages = [{"role": "user", "content": prompt}]

return await client.chat.completions.create(
model=getattr(self._inference_server.config, "model_name", "test-model"),
messages=[{"role": "user", "content": prompt}],
messages=messages,
logprobs=True,
max_tokens=max_tokens,
temperature=temperature,
Expand Down
170 changes: 170 additions & 0 deletions lib/marin/src/marin/rl/environments/inference_ctx/openai_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""AsyncOpenAI-compatible adapter built on top of Marin inference contexts."""

import asyncio
from collections.abc import Mapping
from typing import Any

from openai.types.chat import ChatCompletion

_SUPPORTED_EXTRA_BODY_KEYS = frozenset({"top_k", "return_tokens_as_token_ids"})
_SUPPORTED_TOP_LOGPROBS = {None, 1}


def _supports_logprobs(value: object) -> bool:
return value is None or value is True or value == 1


def _normalize_messages(messages: object) -> list[dict[str, object]]:
if not isinstance(messages, list):
raise TypeError("messages must be a list of chat message dicts")

normalized_messages: list[dict[str, object]] = []
for index, message in enumerate(messages):
if not isinstance(message, Mapping):
raise TypeError(f"messages[{index}] must be a mapping")
if not isinstance(message.get("role"), str):
raise TypeError(f"messages[{index}] is missing a string role")
if not isinstance(message.get("content"), str):
raise TypeError(f"messages[{index}] is missing a string content")
normalized_messages.append(dict(message))

return normalized_messages


def _normalize_extra_body(extra_body: object) -> dict[str, object]:
if extra_body is None:
return {}
if not isinstance(extra_body, Mapping):
raise TypeError("extra_body must be a mapping when provided")

unsupported_keys = sorted(set(extra_body) - _SUPPORTED_EXTRA_BODY_KEYS)
if unsupported_keys:
raise NotImplementedError(f"Unsupported OpenAI compatibility extra_body keys: {', '.join(unsupported_keys)}")

return dict(extra_body)


def _extract_top_k(top_k: object, extra_body: dict[str, object]) -> int | None:
body_top_k = extra_body.get("top_k")
if top_k is not None and body_top_k is not None and top_k != body_top_k:
raise ValueError("top_k and extra_body['top_k'] must match when both are provided")
if body_top_k is not None and not isinstance(body_top_k, int):
raise TypeError("extra_body['top_k'] must be an int")
if top_k is not None and not isinstance(top_k, int):
raise TypeError("top_k must be an int")
return top_k if top_k is not None else body_top_k


def _validate_generation_kwargs(
*,
tools: object,
tool_choice: object,
logprobs: object,
top_logprobs: object,
timeout: object,
kwargs: dict[str, object],
) -> None:
if tools not in (None, []):
raise NotImplementedError("Tool-enabled verifier environments are not supported yet")
if tool_choice is not None:
raise NotImplementedError("tool_choice is not supported by Marin's OpenAI compatibility adapter")
if not _supports_logprobs(logprobs):
raise NotImplementedError("Only logprobs=True is supported by Marin's OpenAI compatibility adapter")
if top_logprobs not in _SUPPORTED_TOP_LOGPROBS:
raise NotImplementedError("Only top_logprobs=1 is supported by Marin's OpenAI compatibility adapter")
if timeout is not None and not isinstance(timeout, (int, float)):
raise TypeError("timeout must be numeric when provided")
if kwargs:
unsupported = ", ".join(sorted(kwargs))
raise NotImplementedError(f"Unsupported OpenAI compatibility kwargs: {unsupported}")


class _CompatChatCompletions:
    # Adapter exposing AsyncOpenAI's `chat.completions.create` surface on top
    # of a Marin inference context's synchronous `batch_completions`.

    def __init__(self, ctx: Any):
        # ctx is duck-typed; it must provide batch_completions() and may
        # expose a sampling_config with a `temperature` attribute.
        self._ctx = ctx

    async def create(
        self,
        *,
        messages: object,
        model: str,
        temperature: float | None = None,
        n: int = 1,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        stop: list[str] | None = None,
        top_k: int | None = None,
        logprobs: bool | int | None = None,
        top_logprobs: int | None = None,
        extra_body: object = None,
        timeout: float | int | None = None,
        tools: object = None,
        tool_choice: object = None,
        **kwargs: object,
    ) -> ChatCompletion:
        """Serve one chat-completion request through the wrapped context.

        Validation runs in a fixed order (n, unsupported features,
        max-token agreement, message/extra_body normalization) so callers see
        deterministic first errors. Raises ValueError, TypeError, or
        NotImplementedError accordingly; returns the single ChatCompletion
        produced for the one prompt submitted.
        """
        # The served model is fixed by the underlying context; the caller's
        # model name is accepted for API compatibility and ignored.
        del model

        if not isinstance(n, int) or n < 1:
            raise ValueError("n must be a positive integer")

        _validate_generation_kwargs(
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            timeout=timeout,
            kwargs=kwargs,
        )

        # OpenAI accepts either spelling of the output-token cap; if the
        # caller supplies both they must agree.
        if max_tokens is not None and max_completion_tokens is not None and max_tokens != max_completion_tokens:
            raise ValueError("max_tokens and max_completion_tokens must match when both are provided")

        normalized_messages = _normalize_messages(messages)
        normalized_extra_body = _normalize_extra_body(extra_body)
        resolved_top_k = _extract_top_k(top_k, normalized_extra_body)
        resolved_max_tokens = max_tokens if max_tokens is not None else max_completion_tokens
        # Fall back to the context's configured sampling temperature, or 1.0
        # when no sampling_config/temperature is available.
        resolved_temperature = (
            temperature
            if temperature is not None
            else getattr(getattr(self._ctx, "sampling_config", None), "temperature", 1.0)
        )

        # batch_completions is synchronous; run it off the event loop so the
        # async caller is not blocked.
        completions = await asyncio.to_thread(
            self._ctx.batch_completions,
            prompts=[normalized_messages],
            temperature=resolved_temperature,
            n=n,
            max_tokens=resolved_max_tokens,
            top_k=resolved_top_k,
            stop=stop,
            system_prompt=None,
        )

        # One prompt in, one ChatCompletion out (n choices live inside it).
        if len(completions) != 1:
            raise ValueError(f"Expected exactly one completion for one prompt, got {len(completions)}")

        return completions[0]


class _CompatChatNamespace:
    """Mirror of the AsyncOpenAI ``client.chat`` namespace."""

    def __init__(self, ctx: Any):
        # Expose `chat.completions.create(...)` exactly like AsyncOpenAI does.
        self.completions = _CompatChatCompletions(ctx)


class _CompatCompletionsNamespace:
    """Stub for the legacy ``client.completions`` namespace.

    Only the chat-completions surface is implemented; any call here fails
    loudly rather than silently misbehaving.
    """

    async def create(self, **_kwargs: object) -> ChatCompletion:
        """Always raise: legacy text-completion requests are unsupported."""
        raise NotImplementedError("Completion-format requests are not supported by Marin's OpenAI compatibility adapter")


class OpenAICompatClient:
    """AsyncOpenAI-compatible client for verifier environments."""

    def __init__(self, ctx: Any):
        # Only the chat namespace is functional; the legacy completions
        # namespace exists for API shape and always raises.
        self.chat = _CompatChatNamespace(ctx)
        self.completions = _CompatCompletionsNamespace()

    async def close(self) -> None:
        # Nothing to release: the adapter owns no network resources of its own.
        return None
Loading