From 77e523cfecf6369066395864779e555821023071 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Thu, 16 Apr 2026 01:37:03 -0700 Subject: [PATCH 1/3] [rl] Add env prepare hook and OpenAI compat client This gives RL environments an explicit one-time setup lifecycle instead of pushing provisioning into sample-time code. It also defines a verifier-facing inference contract that supports chat prompts across backends while keeping unsupported OpenAI features explicit. --- lib/marin/src/marin/rl/environments/base.py | 4 + .../rl/environments/inference_ctx/__init__.py | 3 +- .../rl/environments/inference_ctx/base.py | 37 +++- .../rl/environments/inference_ctx/levanter.py | 16 +- .../inference_ctx/openai_compat.py | 170 ++++++++++++++++ .../rl/environments/inference_ctx/vllm.py | 36 ++-- lib/marin/src/marin/rl/rollout_worker.py | 2 + .../rl/environments/test_load_environment.py | 28 +++ tests/rl/test_inference_ctx.py | 191 +++++++++++++++++- tests/rl/test_rollout_worker.py | 69 +++++++ 10 files changed, 528 insertions(+), 28 deletions(-) create mode 100644 lib/marin/src/marin/rl/environments/inference_ctx/openai_compat.py diff --git a/lib/marin/src/marin/rl/environments/base.py b/lib/marin/src/marin/rl/environments/base.py index 176a7469ae..523d4e8ed8 100644 --- a/lib/marin/src/marin/rl/environments/base.py +++ b/lib/marin/src/marin/rl/environments/base.py @@ -19,6 +19,10 @@ class MarinEnv(ABC): Subclasses must implement sample() method. """ + def prepare(self) -> None: + """Perform one-time worker-local setup before sampling.""" + return None + @abstractmethod def sample( self, diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/__init__.py b/lib/marin/src/marin/rl/environments/inference_ctx/__init__.py index 2e973796cb..44f779ba1e 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/__init__.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/__init__.py @@ -1,7 +1,7 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -from .base import BaseInferenceContext +from .base import BaseInferenceContext, PromptLike from .levanter import LevanterInferenceContext, LevanterInferenceContextConfig from .vllm import ( vLLMInferenceContext, @@ -19,6 +19,7 @@ "BaseInferenceContext", "LevanterInferenceContext", "LevanterInferenceContextConfig", + "PromptLike", "VLLMSamplingConfig", "vLLMInferenceContext", "vLLMInferenceContextConfig", diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/base.py b/lib/marin/src/marin/rl/environments/inference_ctx/base.py index 987aa91b5c..3cd0b7b4cb 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/base.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/base.py @@ -17,9 +17,26 @@ from marin.rl.types import Rollout from levanter.models.lm_model import LmHeadModel +from marin.rl.environments.inference_ctx.openai_compat import OpenAICompatClient logger = logging.getLogger(__name__) +PromptMessage = dict[str, object] +PromptLike = str | list[PromptMessage] + + +def prompt_to_messages(prompt: PromptLike, system_prompt: str | None = None) -> list[PromptMessage]: + """Normalize plain-string or chat-message prompts to a message list.""" + if isinstance(prompt, str): + messages: list[PromptMessage] = [{"role": "user", "content": prompt}] + else: + messages = [dict(message) for message in prompt] + + if system_prompt is None: + return messages + + return [{"role": "system", "content": system_prompt}, *messages] + class BaseInferenceContext: """Base class for inference 
contexts.""" @@ -34,9 +51,13 @@ def get_metrics(self) -> dict[str, Any]: """Return implementation-specific metrics for tracker logging.""" return {} + def openai_client(self) -> OpenAICompatClient: + """Return an AsyncOpenAI-compatible client for verifier environments.""" + return OpenAICompatClient(self) + def batch_completions( self, - prompts: list[str] | list[list[dict]], + prompts: list[PromptLike], temperature: float, n: int, max_tokens: int | None = None, @@ -47,9 +68,14 @@ def batch_completions( """Batch completions from the inference server.""" raise NotImplementedError - def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prompt: str | None = None) -> np.ndarray: + def tokenize_prompt( + self, + prompt: PromptLike, + choice: Choice | None = None, + system_prompt: str | None = None, + ) -> np.ndarray: """Tokenize with chat template matching server behavior.""" - messages = [{"role": "user", "content": prompt}] + messages = prompt_to_messages(prompt, system_prompt) try: tokens = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) except Exception as e: @@ -57,7 +83,8 @@ def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prom prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True) if not tokens: - raise ValueError(f"Failed to tokenize: {prompt[:100]}...") from None + prompt_preview = prompt[:100] if isinstance(prompt, str) else repr(messages)[:100] + raise ValueError(f"Failed to tokenize: {prompt_preview}...") from None return np.array(tokens, dtype=np.int32) @@ -93,7 +120,7 @@ def logprobs_from_choice(self, choice: Choice) -> np.ndarray: def create_rollout_from_choice( self, - prompt: str, + prompt: PromptLike, choice: Choice, env_name: str, env_example_id: str, diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/levanter.py b/lib/marin/src/marin/rl/environments/inference_ctx/levanter.py index 47900f56ba..6f603ad301 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/levanter.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/levanter.py @@ -19,7 +19,7 @@ from levanter.tokenizers import MarinTokenizer from levanter.models.lm_model import LmHeadModel import haliax as hax -from marin.rl.environments.inference_ctx.base import BaseInferenceContext +from marin.rl.environments.inference_ctx.base import BaseInferenceContext, PromptLike # TODO(chris): use a different weight transfer method update model, take it out from here from marin.rl.weight_transfer.arrow_flight import update_model @@ -85,10 +85,9 @@ def start_server(self, model: LmHeadModel) -> None: def shutdown(self) -> None: self._inference_server.shutdown() - # TODO: add support for ChatCompletion style [ { role, content} ] messages def batch_completions( self, - prompts: list[str] | list[list[dict]], + prompts: list[PromptLike], temperature: float, n: int, max_tokens: int | None = None, @@ -108,10 +107,17 @@ def batch_completions( asyncio.set_event_loop(loop) client = self.openai_client() - async def create_completion(prompt: str) -> ChatCompletion: + async def create_completion(prompt: PromptLike) -> ChatCompletion: + if isinstance(prompt, list): + messages = prompt + elif system_prompt is not None: + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}] + else: + messages = [{"role": "user", "content": prompt}] + return await client.chat.completions.create( 
model=getattr(self._inference_server.config, "model_name", "test-model"), - messages=[{"role": "user", "content": prompt}], + messages=messages, logprobs=True, max_tokens=max_tokens, temperature=temperature, diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/openai_compat.py b/lib/marin/src/marin/rl/environments/inference_ctx/openai_compat.py new file mode 100644 index 0000000000..1b7a855de6 --- /dev/null +++ b/lib/marin/src/marin/rl/environments/inference_ctx/openai_compat.py @@ -0,0 +1,170 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""AsyncOpenAI-compatible adapter built on top of Marin inference contexts.""" + +import asyncio +from collections.abc import Mapping +from typing import Any + +from openai.types.chat import ChatCompletion + +_SUPPORTED_EXTRA_BODY_KEYS = frozenset({"top_k", "return_tokens_as_token_ids"}) +_SUPPORTED_TOP_LOGPROBS = {None, 1} + + +def _supports_logprobs(value: object) -> bool: + return value is None or value is True or value == 1 + + +def _normalize_messages(messages: object) -> list[dict[str, object]]: + if not isinstance(messages, list): + raise TypeError("messages must be a list of chat message dicts") + + normalized_messages: list[dict[str, object]] = [] + for index, message in enumerate(messages): + if not isinstance(message, Mapping): + raise TypeError(f"messages[{index}] must be a mapping") + if not isinstance(message.get("role"), str): + raise TypeError(f"messages[{index}] is missing a string role") + if not isinstance(message.get("content"), str): + raise TypeError(f"messages[{index}] is missing a string content") + normalized_messages.append(dict(message)) + + return normalized_messages + + +def _normalize_extra_body(extra_body: object) -> dict[str, object]: + if extra_body is None: + return {} + if not isinstance(extra_body, Mapping): + raise TypeError("extra_body must be a mapping when provided") + + unsupported_keys = sorted(set(extra_body) - _SUPPORTED_EXTRA_BODY_KEYS) + if unsupported_keys: + raise NotImplementedError(f"Unsupported OpenAI compatibility extra_body keys: {', '.join(unsupported_keys)}") + + return dict(extra_body) + + +def _extract_top_k(top_k: object, extra_body: dict[str, object]) -> int | None: + body_top_k = extra_body.get("top_k") + if top_k is not None and body_top_k is not None and top_k != body_top_k: + raise ValueError("top_k and extra_body['top_k'] must match when both are provided") + if body_top_k is not None and not isinstance(body_top_k, int): + raise TypeError("extra_body['top_k'] must be an int") + if top_k is not None and not isinstance(top_k, int): + raise TypeError("top_k must be an int") + return top_k if top_k is not None else body_top_k + + +def _validate_generation_kwargs( + *, + tools: object, + tool_choice: object, + logprobs: object, + top_logprobs: object, + timeout: object, + kwargs: dict[str, object], +) -> None: + if tools not in (None, []): + raise NotImplementedError("Tool-enabled verifier environments are not supported yet") + if tool_choice is not None: + raise NotImplementedError("tool_choice is not supported by Marin's OpenAI compatibility adapter") + if not _supports_logprobs(logprobs): + raise NotImplementedError("Only logprobs=True is supported by Marin's OpenAI compatibility adapter") + if top_logprobs not in _SUPPORTED_TOP_LOGPROBS: + raise NotImplementedError("Only top_logprobs=1 is supported by Marin's OpenAI compatibility adapter") + if timeout is not None and not isinstance(timeout, (int, float)): + raise TypeError("timeout must be numeric when 
provided") + if kwargs: + unsupported = ", ".join(sorted(kwargs)) + raise NotImplementedError(f"Unsupported OpenAI compatibility kwargs: {unsupported}") + + +class _CompatChatCompletions: + def __init__(self, ctx: Any): + self._ctx = ctx + + async def create( + self, + *, + messages: object, + model: str, + temperature: float | None = None, + n: int = 1, + max_tokens: int | None = None, + max_completion_tokens: int | None = None, + stop: list[str] | None = None, + top_k: int | None = None, + logprobs: bool | int | None = None, + top_logprobs: int | None = None, + extra_body: object = None, + timeout: float | int | None = None, + tools: object = None, + tool_choice: object = None, + **kwargs: object, + ) -> ChatCompletion: + del model + + if not isinstance(n, int) or n < 1: + raise ValueError("n must be a positive integer") + + _validate_generation_kwargs( + tools=tools, + tool_choice=tool_choice, + logprobs=logprobs, + top_logprobs=top_logprobs, + timeout=timeout, + kwargs=kwargs, + ) + + if max_tokens is not None and max_completion_tokens is not None and max_tokens != max_completion_tokens: + raise ValueError("max_tokens and max_completion_tokens must match when both are provided") + + normalized_messages = _normalize_messages(messages) + normalized_extra_body = _normalize_extra_body(extra_body) + resolved_top_k = _extract_top_k(top_k, normalized_extra_body) + resolved_max_tokens = max_tokens if max_tokens is not None else max_completion_tokens + resolved_temperature = ( + temperature + if temperature is not None + else getattr(getattr(self._ctx, "sampling_config", None), "temperature", 1.0) + ) + + completions = await asyncio.to_thread( + self._ctx.batch_completions, + prompts=[normalized_messages], + temperature=resolved_temperature, + n=n, + max_tokens=resolved_max_tokens, + top_k=resolved_top_k, + stop=stop, + system_prompt=None, + ) + + if len(completions) != 1: + raise ValueError(f"Expected exactly one completion for one prompt, got {len(completions)}") + + return completions[0] + + +class _CompatChatNamespace: + def __init__(self, ctx: Any): + self.completions = _CompatChatCompletions(ctx) + + +class _CompatCompletionsNamespace: + async def create(self, **_kwargs: object) -> ChatCompletion: + raise NotImplementedError("Completion-format requests are not supported by Marin's OpenAI compatibility adapter") + + +class OpenAICompatClient: + """AsyncOpenAI-compatible client for verifier environments.""" + + def __init__(self, ctx: Any): + self.chat = _CompatChatNamespace(ctx) + self.completions = _CompatCompletionsNamespace() + + async def close(self) -> None: + return None diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/vllm.py b/lib/marin/src/marin/rl/environments/inference_ctx/vllm.py index 35c34e9728..53ea13b095 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/vllm.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/vllm.py @@ -16,7 +16,7 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.completion_usage import CompletionUsage from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob, TopLogprob -from marin.rl.environments.inference_ctx.base import BaseInferenceContext +from marin.rl.environments.inference_ctx.base import BaseInferenceContext, PromptLike from marin.rl.environments.inference_ctx.inflight.worker import SyncVLLMWrapper from marin.rl.environments.inference_ctx.render import Llama3Renderer, Qwen3Renderer, Renderer, Message from 
marin.rl.environments.inference_ctx.vllm_utils import MODEL_MAPPINGS, MODEL_TRANSPOSE_KEYS @@ -194,7 +194,12 @@ def _get_llm_engine(inference_config: vLLMInferenceContextConfig): kv_cache_metrics=inference_config.kv_cache_metrics, ) - def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prompt: str | None = None) -> np.ndarray: + def tokenize_prompt( + self, + prompt: PromptLike, + choice: Choice | None = None, + system_prompt: str | None = None, + ) -> np.ndarray: """Tokenize the prompt with the choice's prompt token IDs. NOTE(chris): This is a hack to get the prompt token IDs the same since @@ -202,6 +207,9 @@ def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prom This is a known issue documented here: https://github.com/vllm-project/vllm/issues/27486 """ + del prompt, system_prompt + if choice is None: + raise ValueError("vLLMInferenceContext.tokenize_prompt requires a choice with prompt_token_ids") return np.array(choice.prompt_token_ids, dtype=np.int32) def response_tokens_from_choice(self, choice: Choice) -> np.ndarray: @@ -319,7 +327,7 @@ def shutdown(self) -> None: def batch_completions( self, - prompts: list[str] | list[list[dict]], + prompts: list[PromptLike], temperature: float, n: int, max_tokens: int | None = None, @@ -354,19 +362,15 @@ def batch_completions( # Convert prompts to message lists if they aren't already message_lists: list[list[Message]] = [] - if prompts and isinstance(prompts[0], list): - # Prompts are already message lists with few-shot examples - message_lists = prompts # type: ignore - elif system_prompt: - # Plain string prompts with system prompt - assert all(isinstance(p, str) for p in prompts), "prompts must be strings when system_prompt is provided" - message_lists = [ - [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}] for prompt in prompts # type: ignore - ] - else: - # Plain string prompts without system prompt - assert all(isinstance(p, str) for p in prompts), "prompts must be strings when no system_prompt is provided" - message_lists = [[{"role": "user", "content": prompt}] for prompt in prompts] # type: ignore + for prompt in prompts: + if isinstance(prompt, list): + message_list = [dict(message) for message in prompt] + elif system_prompt: + message_list = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}] + else: + message_list = [{"role": "user", "content": prompt}] + + message_lists.append(message_list) # Render messages to token IDs using the appropriate renderer prompt_token_ids = [] diff --git a/lib/marin/src/marin/rl/rollout_worker.py b/lib/marin/src/marin/rl/rollout_worker.py index a0ee18ee72..8348f59750 100644 --- a/lib/marin/src/marin/rl/rollout_worker.py +++ b/lib/marin/src/marin/rl/rollout_worker.py @@ -496,6 +496,8 @@ def _load_environment(self, lesson_id: str) -> MarinEnv: lesson_config = self.config.curriculum_config.lessons[lesson_id] env = load_environment_from_spec(lesson_config.env_config) + # Only cache environments after one-time setup succeeds. 
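+        # If prepare() raises, the exception propagates and the env is never cached,
+        # so the next _load_environment call retries setup with a fresh instance.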
+ env.prepare() self._environments[lesson_id] = env return env diff --git a/tests/rl/environments/test_load_environment.py b/tests/rl/environments/test_load_environment.py index 24dc6d0555..c99f1aabcb 100644 --- a/tests/rl/environments/test_load_environment.py +++ b/tests/rl/environments/test_load_environment.py @@ -3,6 +3,9 @@ """Tests for environment loading from EnvConfig.""" +import sys +from types import ModuleType + from marin.rl.environments import EnvConfig, load_environment_from_spec from marin.rl.environments.mock_env import MockEnv from marin.rl.environments.math_env import MathEnv @@ -36,3 +39,28 @@ def test_load_math_environment(): assert isinstance(env, MathEnv) assert len(env.train_examples) > 0 assert len(env.eval_examples) > 0 + + +def test_load_environment_from_spec_does_not_call_prepare(): + module_name = "test_prepare_env_module" + test_module = ModuleType(module_name) + + class PreparingEnv: + prepare_calls = 0 + + def __init__(self): + pass + + def prepare(self): + type(self).prepare_calls += 1 + + test_module.PreparingEnv = PreparingEnv + sys.modules[module_name] = test_module + + try: + env = load_environment_from_spec(EnvConfig(env_class=f"{module_name}.PreparingEnv", env_args={})) + finally: + sys.modules.pop(module_name, None) + + assert isinstance(env, PreparingEnv) + assert PreparingEnv.prepare_calls == 0 diff --git a/tests/rl/test_inference_ctx.py b/tests/rl/test_inference_ctx.py index 2b7b1599dd..e2fd2a91c7 100644 --- a/tests/rl/test_inference_ctx.py +++ b/tests/rl/test_inference_ctx.py @@ -3,6 +3,7 @@ """Tests for InferenceContext utilities and chat template handling.""" +import asyncio import sys from dataclasses import dataclass from types import SimpleNamespace @@ -11,8 +12,10 @@ import numpy as np import pytest from levanter.inference.openai import ChatMessage -from openai.types.chat import ChatCompletionMessage +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import ChatCompletionTokenLogprob, Choice, ChoiceLogprobs +from openai.types.completion_usage import CompletionUsage from transformers import AutoTokenizer from marin.rl.environments.inference_ctx import ( @@ -26,6 +29,7 @@ ) from marin.rl.environments.inference_ctx.vllm import InferenceMode from marin.rl.environments.inference_ctx.inflight.worker import WorkerExtension +from marin.rl.environments.inference_ctx.openai_compat import OpenAICompatClient _LLAMA3_MODEL_ID = "NousResearch/Meta-Llama-3-8B-Instruct" @@ -106,6 +110,24 @@ def create_choice_with_logprobs(tokenizer, response_text: str, logprobs_values: ) +def create_chat_completion(response_text: str) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=response_text), + logprobs=ChoiceLogprobs(content=[]), + ) + ], + created=1234567890, + model="test-model", + object="chat.completion", + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + + def test_apply_chat_template(llama3_tokenizer): messages = [ ChatMessage(role="system", content="You are a helpful assistant."), @@ -147,6 +169,20 @@ def test_tokenize_prompt_adds_special_tokens(inference_ctx, llama3_tokenizer): assert prompt in decoded +def test_tokenize_prompt_supports_message_lists(inference_ctx, llama3_tokenizer): + messages = [ + {"role": "system", "content": "You are careful."}, + {"role": "user", "content": "What is 2+2?"}, + ] + + 
tokens = inference_ctx.tokenize_prompt(messages) + decoded = llama3_tokenizer.decode(tokens) + + assert "You are careful." in decoded + assert "What is 2+2?" in decoded + assert "<|start_header_id|>assistant<|end_header_id|>" in decoded + + def test_tokenize_prompt_fallback_no_template(gpt2_tokenizer, dummy_server): """Test fallback when tokenizer has no chat template.""" ctx = LevanterInferenceContext( @@ -169,6 +205,29 @@ def test_tokenize_prompt_fallback_no_template(gpt2_tokenizer, dummy_server): assert prompt in decoded +def test_tokenize_prompt_fallback_message_list_no_template(gpt2_tokenizer): + ctx = LevanterInferenceContext( + LevanterInferenceContextConfig( + inference_server_config=None, + tokenizer=gpt2_tokenizer, + stop_tokens=None, + max_tokens=100, + mesh=None, + axis_mapping={}, + ) + ) + + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Test prompt"}, + ] + tokens = ctx.tokenize_prompt(messages) + + decoded = gpt2_tokenizer.decode(tokens) + assert "system: You are helpful." in decoded + assert "user: Test prompt" in decoded + + def test_response_tokens_from_choice(inference_ctx, llama3_tokenizer): """Test extracting token IDs from Choice using BPE round-trip.""" response_text = "The answer is 42" @@ -330,6 +389,136 @@ def __init__(self, **kwargs): assert calls["kv_cache_metrics"] is True +def test_levanter_openai_client_stays_native(inference_ctx, dummy_server): + inference_ctx._inference_server = dummy_server + + client = inference_ctx.openai_client() + + assert isinstance(client, AsyncOpenAI) + assert str(client.base_url).endswith("/v1/") + + +def test_vllm_openai_client_delegates_to_batch_completions(monkeypatch): + monkeypatch.setattr( + vLLMInferenceContext, + "_get_llm_engine", + staticmethod(lambda _config: object()), + ) + monkeypatch.setattr( + "marin.rl.environments.inference_ctx.vllm.load_tokenizer", + lambda _path: SimpleNamespace(get_vocab=lambda: {}), + ) + monkeypatch.setattr( + vLLMInferenceContext, + "_get_renderer", + staticmethod(lambda model_name, _tokenizer: model_name), + ) + + ctx = vLLMInferenceContext( + vLLMInferenceContextConfig( + model_name="test-model", + canonical_model_name="meta-llama/Llama-3.1-8B-Instruct", + max_model_len=1024, + tensor_parallel_size=1, + gpu_memory_utilization=0.9, + sampling_params=VLLMSamplingConfig(temperature=0.2, top_k=13), + ) + ) + + completion = create_chat_completion("hello") + completion.choices[0].prompt_token_ids = [11, 12, 13] + completion.choices[0].response_token_ids = [21, 22] + captured = {} + + def _fake_batch_completions(*, prompts, temperature, n, max_tokens, top_k, stop, system_prompt): + captured.update( + prompts=prompts, + temperature=temperature, + n=n, + max_tokens=max_tokens, + top_k=top_k, + stop=stop, + system_prompt=system_prompt, + ) + return [completion] + + monkeypatch.setattr(ctx, "batch_completions", _fake_batch_completions) + + client = ctx.openai_client() + result = asyncio.run( + client.chat.completions.create( + model="marin-model", + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Solve 2+2"}, + ], + temperature=0.7, + n=2, + max_completion_tokens=32, + stop=[""], + extra_body={"top_k": 17, "return_tokens_as_token_ids": True}, + logprobs=True, + top_logprobs=1, + ) + ) + + assert isinstance(client, OpenAICompatClient) + assert result is completion + assert captured == { + "prompts": [[{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Solve 2+2"}]], + 
"temperature": 0.7, + "n": 2, + "max_tokens": 32, + "top_k": 17, + "stop": [""], + "system_prompt": None, + } + assert result.choices[0].prompt_token_ids == [11, 12, 13] + assert result.choices[0].response_token_ids == [21, 22] + + +def test_vllm_openai_client_rejects_unsupported_kwargs(): + class _DummyCtx: + def batch_completions(self, **_kwargs): + raise AssertionError("batch_completions should not be called for unsupported kwargs") + + client = OpenAICompatClient(_DummyCtx()) + + with pytest.raises(NotImplementedError, match="Tool-enabled verifier environments"): + asyncio.run( + client.chat.completions.create( + model="marin-model", + messages=[{"role": "user", "content": "hello"}], + tools=[{"type": "function"}], + ) + ) + + with pytest.raises(NotImplementedError, match="Unsupported OpenAI compatibility extra_body keys"): + asyncio.run( + client.chat.completions.create( + model="marin-model", + messages=[{"role": "user", "content": "hello"}], + extra_body={"min_p": 0.1}, + ) + ) + + with pytest.raises(NotImplementedError, match="Unsupported OpenAI compatibility kwargs"): + asyncio.run( + client.chat.completions.create( + model="marin-model", + messages=[{"role": "user", "content": "hello"}], + response_format={"type": "json_object"}, + ) + ) + + +def test_vllm_openai_client_completion_endpoint_is_not_supported(): + client = OpenAICompatClient(SimpleNamespace()) + + with pytest.raises(NotImplementedError, match="Completion-format requests are not supported"): + asyncio.run(client.completions.create(model="marin-model", prompt="hello")) + + def test_worker_extension_uses_public_sync_weights(): calls = {} diff --git a/tests/rl/test_rollout_worker.py b/tests/rl/test_rollout_worker.py index 636c61d806..112159d0e1 100644 --- a/tests/rl/test_rollout_worker.py +++ b/tests/rl/test_rollout_worker.py @@ -245,6 +245,75 @@ def __init__(self, *, inference_config): assert captured["config"].kv_cache_metrics is True +def test_load_environment_prepares_once_before_caching(monkeypatch): + class _FakeEnv: + def __init__(self): + self.prepare_calls = 0 + + def prepare(self) -> None: + self.prepare_calls += 1 + + env = _FakeEnv() + monkeypatch.setattr("marin.rl.rollout_worker.load_environment_from_spec", lambda _config: env) + + worker = object.__new__(RolloutWorker) + worker._environments = {} + worker.config = SimpleNamespace( + curriculum_config=SimpleNamespace( + lessons={"lesson-a": SimpleNamespace(env_config=object())}, + ) + ) + + loaded_env = worker._load_environment("lesson-a") + cached_env = worker._load_environment("lesson-a") + + assert loaded_env is env + assert cached_env is env + assert env.prepare_calls == 1 + + +def test_load_environment_does_not_cache_failed_prepare(monkeypatch): + class _FailingEnv: + def __init__(self): + self.prepare_calls = 0 + + def prepare(self) -> None: + self.prepare_calls += 1 + raise RuntimeError("prepare failed") + + class _HealthyEnv: + def __init__(self): + self.prepare_calls = 0 + + def prepare(self) -> None: + self.prepare_calls += 1 + + failing_env = _FailingEnv() + healthy_env = _HealthyEnv() + created_envs = iter([failing_env, healthy_env]) + monkeypatch.setattr("marin.rl.rollout_worker.load_environment_from_spec", lambda _config: next(created_envs)) + + worker = object.__new__(RolloutWorker) + worker._environments = {} + worker.config = SimpleNamespace( + curriculum_config=SimpleNamespace( + lessons={"lesson-a": SimpleNamespace(env_config=object())}, + ) + ) + + with pytest.raises(RuntimeError, match="prepare failed"): + 
worker._load_environment("lesson-a") + + recovered_env = worker._load_environment("lesson-a") + cached_env = worker._load_environment("lesson-a") + + assert recovered_env is healthy_env + assert cached_env is healthy_env + assert worker._environments["lesson-a"] is healthy_env + assert failing_env.prepare_calls == 1 + assert healthy_env.prepare_calls == 1 + + @pytest.mark.parametrize( ("current_train_step", "last_eval_train_step", "eval_frequency", "worker_index", "expected"), [ From 227cade3b30384cea3ffebe37e22cbb5384b8023 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Thu, 16 Apr 2026 02:15:37 -0700 Subject: [PATCH 2/3] [rl] Rewrite Prime verifier env for Phase 1 --- .../rl/environments/prime_intellect_env.py | 387 +++++++---- .../rl/environments/process_vllm_results.py | 180 ----- .../environments/test_prime_intellect_env.py | 622 +++++++++++++++--- .../environments/test_process_vllm_results.py | 107 --- 4 files changed, 777 insertions(+), 519 deletions(-) delete mode 100644 lib/marin/src/marin/rl/environments/process_vllm_results.py delete mode 100644 tests/rl/environments/test_process_vllm_results.py diff --git a/lib/marin/src/marin/rl/environments/prime_intellect_env.py b/lib/marin/src/marin/rl/environments/prime_intellect_env.py index edebc3db38..d00626026e 100644 --- a/lib/marin/src/marin/rl/environments/prime_intellect_env.py +++ b/lib/marin/src/marin/rl/environments/prime_intellect_env.py @@ -1,74 +1,244 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -""" -Environment Wrapper for the Environments Hub by Prime-Intellect, which contains a collection of environments. -https://app.primeintellect.ai/dashboard/environments?ex_sort=most_stars -""" +"""Environment wrapper for Prime Intellect verifier environments.""" + import logging -from typing import Any, ClassVar, cast +import shutil +import subprocess +from collections.abc import Mapping +from typing import Any, ClassVar -import jax.numpy as jnp import numpy as np from marin.rl.environments import MarinEnv -from marin.rl.environments.process_vllm_results import process_vllm_chat_results -from marin.rl.environments.inference_ctx import BaseInferenceContext -from marin.rl.types import Rollout, RolloutGroup +from marin.rl.environments.inference_ctx import BaseInferenceContext, PromptLike +from marin.rl.types import RolloutGroup logger = logging.getLogger(__name__) +_SUPPORTED_OWNER = "primeintellect" +_ENV_NAME_PREFIX = "prime_intellect:" + + +def _freeze_cache_value(value: object) -> object: + if isinstance(value, (str, int, float, bool, type(None))): + return value + if isinstance(value, list | tuple): + return tuple(_freeze_cache_value(item) for item in value) + if isinstance(value, Mapping): + frozen_items = [] + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError(f"PrimeIntellectEnv env_args keys must be strings, got {type(key).__name__}") + frozen_items.append((key, _freeze_cache_value(item))) + return tuple(sorted(frozen_items)) + raise TypeError( + "PrimeIntellectEnv env_args must contain only JSON-like values, " f"got unsupported type {type(value).__name__}" + ) + + +def _scalarize_metric(metric_name: str, values: object) -> float: + if isinstance(values, (int, float, np.number, bool)): + return float(values) + + if not isinstance(values, list): + raise TypeError(f"Metric {metric_name!r} must be numeric or a list of numeric values") + if not values: + raise ValueError(f"Metric {metric_name!r} cannot be an empty list") + + try: + 
return float(np.mean(np.asarray(values, dtype=np.float32))) + except (TypeError, ValueError) as exc: + raise TypeError(f"Metric {metric_name!r} must contain only numeric values") from exc + class PrimeIntellectEnv(MarinEnv): - """ - Environment Wrapper for the Environments Hub by Prime-Intellect, which contains a collection of environments. - """ + """Adapter for Phase 1 Prime Intellect verifier environments.""" - ENVS: ClassVar[dict[str, Any]] = {} + INSTALLED_ENV_IDS: ClassVar[set[str]] = set() + LOADED_ENVIRONMENTS: ClassVar[dict[tuple[str, object], Any]] = {} def __init__( self, env_id: str, - env_args: dict = {}, # noqa: B006 + env_args: dict[str, object] | None = None, max_tokens: int = 1024, max_concurrent: int = 32, ): - """Initialize PrimeIntellect environment. - - Args: - env_id: Environment ID like "primeintellect/gsm8k" - env_args: Dict with verifier-specific args (num_train_examples, etc.) - max_tokens: Maximum tokens for generation - max_concurrent: Maximum concurrent requests - """ self.env_id = env_id - self.env_args = env_args + self.env_args = dict(env_args or {}) self.max_tokens = max_tokens self.max_concurrent = max_concurrent + self._normalized_env_args = _freeze_cache_value(self.env_args) + self._is_prepared = False - def _ensure_verifiers_installed(self): - """Ensure verifiers package is installed.""" + def _verifiers_module(self) -> Any: try: - import verifiers # noqa: F401 - except ImportError as e: + import verifiers as vf + except ImportError as exc: raise ImportError( "The 'verifiers' package is required to use PrimeIntellectEnv. " "Please install it with: uv pip install 'marin[rl]' or uv pip install verifiers" - ) from e - - def load_prime_intellect_env(self, env_id: str, env_args: dict) -> Any: - """ - Get the Verifiers environment for the environment ID. - """ - self._ensure_verifiers_installed() - import verifiers as vf + ) from exc + + return vf + + def _short_env_id(self) -> str: + owner, separator, slug = self.env_id.partition("/") + if owner != _SUPPORTED_OWNER or separator == "" or not slug: + raise ValueError(f"PrimeIntellectEnv Phase 1 only supports '{_SUPPORTED_OWNER}/*' IDs, got {self.env_id!r}") + return slug + + def _verifier_cache_key(self) -> tuple[str, object]: + return self.env_id, self._normalized_env_args + + def prepare(self) -> None: + self._verifiers_module() + self._short_env_id() + + prime_executable = shutil.which("prime") + if prime_executable is None: + raise RuntimeError( + "PrimeIntellectEnv requires the 'prime' executable on PATH. " + "Install the Prime CLI before running Prime verifier environments." 
+ ) + + if self.env_id not in self.INSTALLED_ENV_IDS: + subprocess.run([prime_executable, "env", "install", self.env_id], check=True) + self.INSTALLED_ENV_IDS.add(self.env_id) + + self._is_prepared = True + + def _load_verifier_env(self) -> Any: + cache_key = self._verifier_cache_key() + if cache_key in self.LOADED_ENVIRONMENTS: + return self.LOADED_ENVIRONMENTS[cache_key] + + vf = self._verifiers_module() + short_env_id = self._short_env_id() + verifier_env = vf.load_environment(env_id=short_env_id, **self.env_args) + self.LOADED_ENVIRONMENTS[cache_key] = verifier_env + return verifier_env + + def _validate_sample_request(self, mode: str, system_prompt: str | None) -> None: + if not self._is_prepared: + raise RuntimeError("PrimeIntellectEnv.prepare() must be called before sample()") + if mode not in ("train", "eval"): + raise ValueError(f"Unsupported mode: {mode!r}") + if system_prompt is not None: + raise ValueError("PrimeIntellectEnv Phase 1 does not support Marin-level system prompts") + + def _validate_verifier_env(self, verifier_env: Any) -> None: + message_type = getattr(verifier_env, "message_type", None) + if message_type != "chat": + raise ValueError( + f"PrimeIntellectEnv Phase 1 only supports chat-format verifier environments, got {message_type!r}" + ) + + if getattr(verifier_env, "oai_tools", None): + raise ValueError("PrimeIntellectEnv Phase 1 does not support tool-enabled verifier environments") + + def _select_inputs(self, verifier_env: Any, mode: str, n_examples: int) -> Any: + if mode == "train": + return verifier_env.get_dataset(n=n_examples) + return verifier_env.get_eval_dataset(n=n_examples) + + def _repeat_inputs(self, inputs: Any, n_generations: int) -> Any: + if n_generations == 1: + return inputs + if hasattr(inputs, "repeat"): + return inputs.repeat(n_generations) + raise TypeError("PrimeIntellectEnv expects verifier datasets to expose repeat()") + + def _extract_example_ids(self, inputs: Any) -> list[str]: + if hasattr(inputs, "column_names") and "id" in inputs.column_names: + return [str(example_id) for example_id in inputs["id"]] + return [f"example_{index}" for index in range(len(inputs))] + + def _validate_generate_outputs(self, result: Any, expected_rollouts: int) -> None: + for field_name in ("prompt", "completion", "state", "reward"): + field_value = getattr(result, field_name, None) + if not isinstance(field_value, list): + raise ValueError(f"PrimeIntellectEnv expected result.{field_name} to be a list") + if len(field_value) != expected_rollouts: + raise ValueError( + f"PrimeIntellectEnv expected {expected_rollouts} {field_name} entries, got {len(field_value)}" + ) - logger.debug(f"Loading Verifiers environment for {env_id} with arguments: {env_args}") + metrics = getattr(result, "metrics", {}) + if metrics is None: + return + if not isinstance(metrics, Mapping): + raise ValueError("PrimeIntellectEnv expected result.metrics to be a mapping") - if env_id not in self.ENVS: - self.ENVS[env_id] = vf.load_environment(env_id=env_id, **env_args) + def _scalarize_metrics(self, raw_metrics: Mapping[str, object]) -> dict[str, float]: + metrics = {} + for metric_name, values in raw_metrics.items(): + metrics[f"{self.env_id}.{metric_name}"] = _scalarize_metric(metric_name, values) + return metrics - return self.ENVS[env_id] + def _extract_phase1_rollout( + self, + rollout_index: int, + prompt: PromptLike, + completion: object, + state: object, + ) -> tuple[list[dict[str, object]], Any]: + if not isinstance(prompt, list): + raise ValueError( + f"PrimeIntellectEnv 
Phase 1 only supports chat prompts, got {type(prompt).__name__} " + f"for rollout {rollout_index}" + ) + prompt_messages = [dict(message) for message in prompt] + + if not isinstance(completion, list): + raise ValueError( + f"PrimeIntellectEnv Phase 1 only supports chat completions, got {type(completion).__name__} " + f"for rollout {rollout_index}" + ) + + completion_messages: list[dict[str, object]] = [] + for message in completion: + if not isinstance(message, Mapping): + raise TypeError( + f"PrimeIntellectEnv expected completion messages to be mappings, got {type(message).__name__}" + ) + completion_messages.append(dict(message)) + + if any(message.get("role") != "assistant" for message in completion_messages): + raise ValueError("PrimeIntellectEnv Phase 1 does not support non-assistant turns in completions") + if len(completion_messages) != 1: + raise ValueError("PrimeIntellectEnv Phase 1 requires exactly one assistant completion turn") + + assistant_content = completion_messages[0].get("content") + if not isinstance(assistant_content, str): + raise ValueError("PrimeIntellectEnv Phase 1 expects assistant completion content to be a string") + + if not isinstance(state, Mapping): + raise TypeError(f"PrimeIntellectEnv expected rollout state to be a mapping, got {type(state).__name__}") + + responses = state.get("responses") + if not isinstance(responses, list): + raise ValueError("PrimeIntellectEnv Phase 1 expected state['responses'] to be a list") + if len(responses) != 1: + raise ValueError("PrimeIntellectEnv Phase 1 requires exactly one response object per rollout") + + response = responses[0] + if not hasattr(response, "choices"): + raise ValueError("PrimeIntellectEnv Phase 1 expected state['responses'] entries to be ChatCompletion-like") + if len(response.choices) != 1: + raise ValueError("PrimeIntellectEnv Phase 1 requires exactly one assistant choice per rollout") + + choice = response.choices[0] + if choice.message.role != "assistant": + raise ValueError("PrimeIntellectEnv Phase 1 only supports assistant response choices") + if choice.message.content is None: + raise ValueError("PrimeIntellectEnv Phase 1 requires assistant responses with text content") + if choice.message.content != assistant_content: + raise ValueError("PrimeIntellectEnv Phase 1 requires completion messages to match response choices") + + return prompt_messages, choice def sample( self, @@ -83,114 +253,77 @@ def sample( stop: list[str] | None = None, system_prompt: str | None = None, ) -> tuple[list[RolloutGroup], dict[str, float]]: - """Sample problems and generate responses using the model.""" - del prng_key, system_prompt - self._ensure_verifiers_installed() - from verifiers.types import GenerateOutputs - import subprocess + del prng_key - # Download/install the environment - subprocess.run(["prime", "env", "install", self.env_id], check=True) + self._validate_sample_request(mode, system_prompt) + verifier_env = self._load_verifier_env() + self._validate_verifier_env(verifier_env) - # Extract just the env name after slash - env_id = self.env_id.split("/", 1)[-1] - vf_env = self.load_prime_intellect_env(env_id, self.env_args) + base_inputs = self._select_inputs(verifier_env, mode, n_examples) + if base_inputs is None: + raise ValueError(f"PrimeIntellectEnv could not load any inputs for mode {mode!r}") + + example_ids = self._extract_example_ids(base_inputs) + repeated_inputs = self._repeat_inputs(base_inputs, n_generations) + expected_rollouts = len(example_ids) * n_generations - # Prepare sampling arguments 
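+        # logprobs are always requested so per-token logprobs can be recovered for each rollout.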
sampling_args = { "max_tokens": max_tokens or self.max_tokens, "temperature": temperature, "top_k": top_k, "logprobs": True, "stop": stop, - # Note: return_tokens_as_token_ids is not supported by current vLLM version - # We use convert_tokens_to_ids() in process_vllm_results.py instead } - logger.info( - f"Starting evaluation: n_examples={n_examples}, " - f"n_generations={n_generations}, max_concurrent={self.max_concurrent}" + result = verifier_env.generate( + inputs=repeated_inputs, + client=inference_ctx.openai_client(), + model="marin-model", + sampling_args=sampling_args, + max_concurrent=self.max_concurrent, ) + self._validate_generate_outputs(result, expected_rollouts) - # Get dataset based on mode - if mode == "train": - if vf_env.dataset is None: - raise ValueError(f"Train dataset missing for {env_id}") - inputs = vf_env.get_dataset(n=n_examples) - else: - if vf_env.eval_dataset is None: - raise ValueError(f"Eval dataset missing for {env_id}") - inputs = vf_env.get_eval_dataset(n=n_examples) - - # Repeat inputs for multiple generations - if n_generations > 1: - inputs = inputs.repeat(n_generations) - - result = cast( - GenerateOutputs, - vf_env.generate( - inputs=inputs, - client=inference_ctx.openai_client(), - model="marin-model", - sampling_args=sampling_args, - max_concurrent=self.max_concurrent, - ), - ) + raw_metrics = getattr(result, "metrics", {}) or {} + metrics = self._scalarize_metrics(raw_metrics) - logger.info("Result:") - logger.info(f"Prompt: {result.prompt[0]}") - logger.info(f"Completion: {result.completion[0]}") - logger.info(f"State: {result.state[0]}") - logger.info(f"Reward: {result.reward[0]}") + if expected_rollouts == 0: + metrics[f"{self.env_id}.total_rollouts"] = 0.0 + return [], metrics - # Use custom processing function to handle vLLM output format correctly - processed_outputs = process_vllm_chat_results( - result.prompt, result.completion, result.state, result.reward, inference_ctx.tokenizer - ) - - # Convert to RolloutGroups rollout_groups = [] - for prompt_idx in range(len(processed_outputs.prompt_ids)): - rollouts = [] - for gen_idx in range(n_generations): - overall_idx = prompt_idx * n_generations + gen_idx - if overall_idx >= len(processed_outputs.completion_ids): - break - - reward = processed_outputs.rewards[overall_idx] if overall_idx < len(processed_outputs.rewards) else 0.0 - - # Use tokenizer from inference context - prompt_tokens = processed_outputs.prompt_ids[prompt_idx] - response_tokens = processed_outputs.completion_ids[overall_idx] - - token_rewards = jnp.full(len(response_tokens), reward, dtype=jnp.float32) - response_logprobs = processed_outputs.completion_logprobs[overall_idx] - - rollout = Rollout( - env_name=f"prime_intellect:{self.env_id}", - env_example_id=f"{self.env_id}_example_{prompt_idx}", - prompt_tokens=jnp.array(prompt_tokens, dtype=jnp.int32), - response_tokens=jnp.array(response_tokens, dtype=jnp.int32), - response_logprobs=response_logprobs, - token_rewards=token_rewards, - episode_reward=float(reward), + reward_sum = 0.0 + + # Dataset.repeat(n) orders rows as [ex0, ex1, ex0, ex1, ...], so regroup + # by generation first and example second. 
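+        # e.g. with 2 examples and 2 generations: rollout_index = generation_index * n_sampled_examples + example_index.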
+ n_sampled_examples = len(example_ids) + for example_index, example_id in enumerate(example_ids): + group_rollouts = [] + for generation_index in range(n_generations): + rollout_index = generation_index * n_sampled_examples + example_index + prompt_messages, choice = self._extract_phase1_rollout( + rollout_index=rollout_index, + prompt=result.prompt[rollout_index], + completion=result.completion[rollout_index], + state=result.state[rollout_index], + ) + reward = float(result.reward[rollout_index]) + rollout = inference_ctx.create_rollout_from_choice( + prompt=prompt_messages, + choice=choice, + env_name=f"{_ENV_NAME_PREFIX}{self.env_id}", + env_example_id=f"{self.env_id}:{example_id}", + reward=reward, temperature=temperature, top_k=top_k, - is_truncated=False, # prime intellect doesn't seem to report this easily ) - rollouts.append(rollout) + group_rollouts.append(rollout) + reward_sum += reward - if rollouts: - rollout_groups.append(RolloutGroup(rollouts=rollouts)) - - # Extract metrics - metrics = {} - if hasattr(result, "metrics") and result.metrics: - metrics.update(result.metrics) + rollout_groups.append(RolloutGroup(rollouts=group_rollouts)) - if result.reward: - metrics[f"{self.env_id}.mean_reward"] = float(np.mean(result.reward)) - metrics[f"{self.env_id}.total_rollouts"] = len(result.reward) + metrics[f"{self.env_id}.mean_reward"] = reward_sum / expected_rollouts + metrics[f"{self.env_id}.total_rollouts"] = float(expected_rollouts) - logger.info(f"Generated {len(rollout_groups)} rollout groups") + logger.info("Generated %d rollout groups for %s", len(rollout_groups), self.env_id) return rollout_groups, metrics diff --git a/lib/marin/src/marin/rl/environments/process_vllm_results.py b/lib/marin/src/marin/rl/environments/process_vllm_results.py deleted file mode 100644 index 23d1df9015..0000000000 --- a/lib/marin/src/marin/rl/environments/process_vllm_results.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Custom processing functions for vLLM environment results.""" - -import logging -from dataclasses import dataclass -from typing import Any - -from openai.types.chat import ChatCompletion - -logger = logging.getLogger(__name__) - - -@dataclass -class ProcessedVLLMOutputs: - """Processed outputs from vLLM environment.""" - - prompt_ids: list[list[int]] - """Tokenized prompts, one per example.""" - - completion_ids: list[list[int]] - """Tokenized completions, one per generation.""" - - completion_logprobs: list[list[float]] - """Log probabilities for completion tokens, one per generation.""" - - rewards: list[float] - """Reward for each generation.""" - - -def parse_chat_completion_tokens_from_bytes(chat_completion: ChatCompletion, tokenizer: Any) -> list[int]: - """ - Parse token IDs from chat completion. - - vLLM returns tokens as their string representations (from convert_ids_to_tokens()). - We need to convert them back via the vocabulary to get the correct token IDs. 
- - Args: - chat_completion: ChatCompletion object from vLLM - tokenizer: Tokenizer (MarinTokenizer) to use for converting token strings to IDs - - Returns: - List of token IDs - """ - assert len(chat_completion.choices) == 1, f"Expected 1 choice, got {len(chat_completion.choices)}: {chat_completion}" - assert chat_completion.choices[0].logprobs is not None, f"Logprobs should not be None: {chat_completion}" - assert ( - chat_completion.choices[0].logprobs.content is not None - ), f"Logprob content should not be None: {chat_completion}" - - vocab = tokenizer.get_vocab() - tokens = [] - logprob_content = chat_completion.choices[0].logprobs.content - - for token_logprob in logprob_content: - token_str = token_logprob.token - - # Case 1: Token is in format "token_id:" (when return_tokens_as_token_ids=True works) - if token_str.startswith("token_id:"): - try: - token_id = int(token_str.split(":", 1)[1]) - tokens.append(token_id) - continue - except (ValueError, IndexError): - pass - - # Case 2: Token is a string representation (the standard case with vLLM) - # Look up in vocab for correct BPE round-trip - token_id = vocab.get(token_str) - if token_id is None: - logger.warning(f"Token '{token_str}' not found in vocabulary, may indicate tokenizer mismatch") - tokens.append(0) - else: - tokens.append(token_id) - - return tokens - - -def parse_chat_completion_logprobs(chat_completion: ChatCompletion) -> list[float]: - """ - Parse log probabilities from chat completion. - - Args: - chat_completion: ChatCompletion object from vLLM - - Returns: - List of log probabilities - """ - assert len(chat_completion.choices) == 1, "Response should always have one choice" - assert chat_completion.choices[0].logprobs is not None, "Logprobs should not be None" - assert chat_completion.choices[0].logprobs.content is not None, "Logprob content should not be None" - - logprobs = [logprob.logprob for logprob in chat_completion.choices[0].logprobs.content] - return logprobs - - -def process_vllm_chat_results( - prompts: list[list[dict[str, str]]], - completions: list[list[dict[str, str]]], - states: list[dict[str, Any]], - rewards: list[float], - tokenizer: Any, -) -> ProcessedVLLMOutputs: - """ - Process vLLM results for chat format conversations. 
- - Args: - prompts: List of chat message prompts - completions: List of chat message completions - states: List of state dicts containing responses - rewards: List of rewards - tokenizer: Tokenizer to use for encoding - - Returns: - ProcessedVLLMOutputs with tokenized data and logprobs - """ - all_prompt_ids = [] - all_completion_ids = [] - all_completion_logprobs = [] - all_rewards = [] - - for idx, (prompt, completion, state, reward) in enumerate(zip(prompts, completions, states, rewards, strict=False)): - try: - # Tokenize the prompt using chat template - prompt_ids = tokenizer.apply_chat_template( - conversation=prompt, - add_generation_prompt=True, - ) - - # Extract responses from state - responses = state.get("responses", []) - - # Process completion messages and extract tokens/logprobs from responses - completion_ids = [] - completion_logprobs = [] - - response_idx = 0 - for msg_idx, message in enumerate(completion): - # This is a model-generated response - if response_idx < len(responses): - response = responses[response_idx] - - # Parse tokens and logprobs from the response - tokens = parse_chat_completion_tokens_from_bytes(response, tokenizer) - logprobs = parse_chat_completion_logprobs(response) - - completion_ids.extend(tokens) - completion_logprobs.extend(logprobs) - logger.debug(f"Example {idx}, message {msg_idx}: Parsed {len(tokens)} tokens") - - response_idx += 1 - else: - # No response available, tokenize the message content - logger.warning( - f"Example {idx}, message {msg_idx}: No response available, using fallback tokenization" - ) - content = message.get("content", "") - tokens = tokenizer.encode(content, add_special_tokens=False) - completion_ids.extend(tokens) - completion_logprobs.extend([0.0] * len(tokens)) - - all_prompt_ids.append(prompt_ids) - all_completion_ids.append(completion_ids) - all_completion_logprobs.append(completion_logprobs) - all_rewards.append(reward) - - except Exception as e: - logger.error(f"Example {idx}: Failed to process: {e}", exc_info=True) - # Skip this example - continue - - logger.info(f"Processed {len(all_prompt_ids)} examples successfully out of {len(prompts)} total") - - return ProcessedVLLMOutputs( - prompt_ids=all_prompt_ids, - completion_ids=all_completion_ids, - completion_logprobs=all_completion_logprobs, - rewards=all_rewards, - ) diff --git a/tests/rl/environments/test_prime_intellect_env.py b/tests/rl/environments/test_prime_intellect_env.py index 20abbdc3a5..176c990339 100644 --- a/tests/rl/environments/test_prime_intellect_env.py +++ b/tests/rl/environments/test_prime_intellect_env.py @@ -1,168 +1,580 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Tests for PrimeIntellectEnv integration with verifiers library.""" +"""Tests for the Phase 1 PrimeIntellectEnv verifier adapter.""" -from unittest.mock import AsyncMock, patch +from collections import defaultdict +from dataclasses import dataclass +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import Mock -import jax.random -import numpy as np import pytest +from datasets import Dataset +from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion import ChatCompletionTokenLogprob, Choice, ChoiceLogprobs from openai.types.completion_usage import CompletionUsage -from levanter.tokenizers import load_tokenizer -try: - import verifiers as vf +from marin.rl.environments.inference_ctx import ( + 
LevanterInferenceContext, + LevanterInferenceContextConfig, + VLLMSamplingConfig, + vLLMInferenceContext, + vLLMInferenceContextConfig, +) +from marin.rl.environments.inference_ctx.openai_compat import OpenAICompatClient +from marin.rl.environments.inference_ctx.vllm import InferenceMode +from marin.rl.environments.prime_intellect_env import PrimeIntellectEnv + + +@dataclass +class DummyInferenceServer: + """Minimal inference server for Levanter OpenAI client construction.""" + + host: str = "localhost" + port: int = 8000 + + def address(self) -> str: + return f"{self.host}:{self.port}" + + @property + def config(self): + @dataclass + class Config: + model_name: str = "test-model" + + return Config() + + +class FakeVerifierEnv: + """Verifier env test double with real Dataset behavior.""" + + def __init__( + self, + dataset: Dataset, + eval_dataset: Dataset | None = None, + *, + message_type: str = "chat", + oai_tools: list[dict[str, object]] | None = None, + generate_result_factory=None, + ): + self.dataset = dataset + self.eval_dataset = eval_dataset + self.message_type = message_type + self.oai_tools = oai_tools + self.generate_result_factory = generate_result_factory + self.generate_calls: list[dict[str, object]] = [] + + def get_dataset(self, n: int = -1): + if n > 0: + return self.dataset.select(range(min(n, len(self.dataset)))) + return self.dataset + + def get_eval_dataset(self, n: int = -1): + dataset = self.eval_dataset if self.eval_dataset is not None else self.dataset + if n > 0: + return dataset.select(range(min(n, len(dataset)))) + return dataset + + def generate(self, *, inputs, client, model, sampling_args, max_concurrent): + self.generate_calls.append( + { + "inputs": inputs, + "client": client, + "model": model, + "sampling_args": dict(sampling_args), + "max_concurrent": max_concurrent, + } + ) + return self.generate_result_factory(inputs=inputs) + - from marin.rl.environments.prime_intellect_env import PrimeIntellectEnv -except ImportError: - pytest.skip("verifiers library not installed", allow_module_level=True) +def _prompt_dataset(example_ids: list[str], prefix: str) -> Dataset: + return Dataset.from_dict( + { + "id": example_ids, + "prompt": [[{"role": "user", "content": f"{prefix} prompt {example_id}"}] for example_id in example_ids], + "answer": [""] * len(example_ids), + } + ) -from marin.rl.environments.inference_ctx import LevanterInferenceContext +def _chat_completion(tokenizer, response_text: str, prompt_token_ids: list[int]) -> ChatCompletion: + response_token_ids = tokenizer.encode(response_text, add_special_tokens=False) + logprobs_content = [ + ChatCompletionTokenLogprob( + token=tokenizer.convert_ids_to_tokens(token_id), + logprob=-0.1 * (index + 1), + bytes=None, + top_logprobs=[], + ) + for index, token_id in enumerate(response_token_ids) + ] -def create_mock_chat_completion(response_text: str = "42") -> ChatCompletion: - """Create a mock ChatCompletion for testing.""" - return ChatCompletion( + completion = ChatCompletion( id="chatcmpl-test", choices=[ Choice( finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=response_text), - logprobs=None, + logprobs=ChoiceLogprobs(content=logprobs_content), ) ], created=1234567890, model="test-model", object="chat.completion", usage=CompletionUsage( - completion_tokens=len(response_text.split()), prompt_tokens=10, total_tokens=10 + len(response_text.split()) + completion_tokens=len(response_token_ids), + prompt_tokens=len(prompt_token_ids), + total_tokens=len(prompt_token_ids) + 
len(response_token_ids), ), ) + completion.choices[0].prompt_token_ids = prompt_token_ids + completion.choices[0].response_token_ids = response_token_ids + return completion + + +def _single_turn_generate_outputs(inputs, tokenizer, *, metric_scale: float = 1.0): + counts_by_example_id: dict[str, int] = defaultdict(int) + prompts = [] + completions = [] + states = [] + rewards = [] + metric_values = [] + + for rollout_index, row in enumerate(inputs): + example_id = str(row["id"]) + generation_index = counts_by_example_id[example_id] + counts_by_example_id[example_id] += 1 + + response_text = f"resp-{example_id}-{generation_index}" + prompt_token_ids = [100 + rollout_index, 200 + rollout_index] + response = _chat_completion(tokenizer, response_text, prompt_token_ids) + + prompts.append(row["prompt"]) + completions.append([{"role": "assistant", "content": response_text}]) + states.append({"responses": [response]}) + rewards.append(float(rollout_index + 1)) + metric_values.append(metric_scale * float((rollout_index + 1) * 10)) + + return SimpleNamespace( + prompt=prompts, + completion=completions, + state=states, + reward=rewards, + metrics={"score": metric_values}, + ) -@pytest.fixture -def tokenizer(): - """Create a real tokenizer for testing.""" - return load_tokenizer("gpt2") +def _install_fake_verifiers(monkeypatch, loader): + fake_verifiers = ModuleType("verifiers") + load_calls: list[tuple[str, dict[str, object]]] = [] + def load_environment(env_id: str, **env_args): + load_calls.append((env_id, dict(env_args))) + return loader(env_id, env_args) -@pytest.fixture -def inference_ctx(tokenizer): - """Create a real inference context with mock openai_client.""" + fake_verifiers.load_environment = load_environment + monkeypatch.setitem(sys.modules, "verifiers", fake_verifiers) + return load_calls + + +def _levanter_inference_ctx(gpt2_tokenizer): + ctx = LevanterInferenceContext( + LevanterInferenceContextConfig( + inference_server_config=None, + tokenizer=gpt2_tokenizer, + stop_tokens=None, + max_tokens=128, + mesh=None, + axis_mapping={}, + ) + ) + ctx._inference_server = DummyInferenceServer() + return ctx - class DummyInferenceContext(LevanterInferenceContext): - def __init__(self, tokenizer): - self.tokenizer = tokenizer - self._stop_tokens = None - self.max_tokens = 1024 - def openai_client(self): - """Return a mock AsyncOpenAI client that returns proper ChatCompletion objects.""" - mock_client = AsyncMock() - # Configure the mock to return a proper ChatCompletion - mock_client.chat.completions.create = AsyncMock(return_value=create_mock_chat_completion()) - return mock_client +def _vllm_inference_ctx(monkeypatch, gpt2_tokenizer): + monkeypatch.setattr( + vLLMInferenceContext, + "_get_llm_engine", + staticmethod(lambda _config: object()), + ) + monkeypatch.setattr( + "marin.rl.environments.inference_ctx.vllm.load_tokenizer", + lambda _path: gpt2_tokenizer, + ) + monkeypatch.setattr( + vLLMInferenceContext, + "_get_renderer", + staticmethod(lambda _model_name, _tokenizer: object()), + ) + + return vLLMInferenceContext( + vLLMInferenceContextConfig( + model_name="test-model", + canonical_model_name="meta-llama/Llama-3.1-8B-Instruct", + max_model_len=1024, + tensor_parallel_size=1, + gpu_memory_utilization=0.9, + sampling_params=VLLMSamplingConfig(), + mode=InferenceMode.SYNC, + ) + ) - return DummyInferenceContext(tokenizer) + +@pytest.fixture(autouse=True) +def clear_prime_intellect_env_caches(): + PrimeIntellectEnv.INSTALLED_ENV_IDS.clear() + PrimeIntellectEnv.LOADED_ENVIRONMENTS.clear() 
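+    # Both caches are class-level, so clear them again once the test finishes.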
+ yield + PrimeIntellectEnv.INSTALLED_ENV_IDS.clear() + PrimeIntellectEnv.LOADED_ENVIRONMENTS.clear() @pytest.fixture -def vf_env(): - """Create a real verifiers SingleTurnEnv with example dataset.""" - dataset = vf.load_example_dataset("gsm8k", n=2) - return vf.SingleTurnEnv(dataset=dataset) +def prime_cli(monkeypatch): + subprocess_run = Mock() + monkeypatch.setattr("marin.rl.environments.prime_intellect_env.shutil.which", lambda executable: "/usr/bin/prime") + monkeypatch.setattr("marin.rl.environments.prime_intellect_env.subprocess.run", subprocess_run) + return subprocess_run + + +def test_prime_intellect_env_sample_supports_levanter_single_turn_chat(monkeypatch, prime_cli, gpt2_tokenizer): + train_dataset = _prompt_dataset(["train-0", "train-1"], "train") + eval_dataset = _prompt_dataset(["eval-0", "eval-1"], "eval") + verifier_env = FakeVerifierEnv( + train_dataset, + eval_dataset, + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs( + inputs, gpt2_tokenizer, metric_scale=0.5 + ), + ) + load_calls = _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv( + env_id="primeintellect/gsm8k", + env_args={"difficulty": "easy"}, + max_tokens=128, + max_concurrent=7, + ) + inference_ctx = _levanter_inference_ctx(gpt2_tokenizer) + + env.prepare() + rollout_groups, metrics = env.sample( + inference_ctx=inference_ctx, + n_examples=2, + n_generations=2, + temperature=0.7, + prng_key=None, + mode="eval", + top_k=11, + stop=[""], + ) + assert prime_cli.call_count == 1 + assert prime_cli.call_args.args == (["/usr/bin/prime", "env", "install", "primeintellect/gsm8k"],) + assert prime_cli.call_args.kwargs == {"check": True} + assert load_calls == [("gsm8k", {"difficulty": "easy"})] + assert isinstance(verifier_env.generate_calls[0]["client"], AsyncOpenAI) + assert verifier_env.generate_calls[0]["model"] == "marin-model" + assert verifier_env.generate_calls[0]["max_concurrent"] == 7 + assert verifier_env.generate_calls[0]["sampling_args"] == { + "max_tokens": 128, + "temperature": 0.7, + "top_k": 11, + "logprobs": True, + "stop": [""], + } + + assert [rollout.env_example_id for rollout in rollout_groups[0].rollouts] == [ + "primeintellect/gsm8k:eval-0", + "primeintellect/gsm8k:eval-0", + ] + assert [rollout.env_example_id for rollout in rollout_groups[1].rollouts] == [ + "primeintellect/gsm8k:eval-1", + "primeintellect/gsm8k:eval-1", + ] + assert [gpt2_tokenizer.decode(rollout.response_tokens.tolist()) for rollout in rollout_groups[0].rollouts] == [ + "resp-eval-0-0", + "resp-eval-0-1", + ] + assert [gpt2_tokenizer.decode(rollout.response_tokens.tolist()) for rollout in rollout_groups[1].rollouts] == [ + "resp-eval-1-0", + "resp-eval-1-1", + ] + assert all( + rollout.env_name == "prime_intellect:primeintellect/gsm8k" + for group in rollout_groups + for rollout in group.rollouts + ) + assert metrics == { + "primeintellect/gsm8k.score": pytest.approx(12.5), + "primeintellect/gsm8k.mean_reward": pytest.approx(2.5), + "primeintellect/gsm8k.total_rollouts": 4.0, + } + + +def test_prime_intellect_env_sample_supports_vllm_single_turn_chat(monkeypatch, prime_cli, gpt2_tokenizer): + train_dataset = _prompt_dataset(["0", "1"], "train") + verifier_env = FakeVerifierEnv( + train_dataset, + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + load_calls = _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k", 
max_tokens=64) + inference_ctx = _vllm_inference_ctx(monkeypatch, gpt2_tokenizer) + + env.prepare() + rollout_groups, metrics = env.sample( + inference_ctx=inference_ctx, + n_examples=2, + n_generations=2, + temperature=0.2, + prng_key=None, + mode="train", + ) + + assert load_calls == [("gsm8k", {})] + assert isinstance(verifier_env.generate_calls[0]["client"], OpenAICompatClient) + assert [rollout.prompt_tokens.tolist() for rollout in rollout_groups[0].rollouts] == [[100, 200], [102, 202]] + assert [rollout.prompt_tokens.tolist() for rollout in rollout_groups[1].rollouts] == [[101, 201], [103, 203]] + assert [gpt2_tokenizer.decode(rollout.response_tokens.tolist()) for rollout in rollout_groups[0].rollouts] == [ + "resp-0-0", + "resp-0-1", + ] + assert [gpt2_tokenizer.decode(rollout.response_tokens.tolist()) for rollout in rollout_groups[1].rollouts] == [ + "resp-1-0", + "resp-1-1", + ] + assert metrics["primeintellect/gsm8k.score"] == pytest.approx(25.0) + assert metrics["primeintellect/gsm8k.mean_reward"] == pytest.approx(2.5) + assert metrics["primeintellect/gsm8k.total_rollouts"] == 4.0 + + +def test_prime_intellect_env_prepare_installs_once_per_env_id(monkeypatch, prime_cli, gpt2_tokenizer): + load_calls = _install_fake_verifiers( + monkeypatch, + lambda _env_id, _env_args: FakeVerifierEnv( + _prompt_dataset(["0"], "train"), + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ), + ) + env_one = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={"difficulty": "easy"}) + env_two = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={"difficulty": "hard"}) + + env_one.prepare() + env_two.prepare() + + assert prime_cli.call_count == 1 + assert load_calls == [] + + +def test_prime_intellect_env_load_cache_keys_include_env_args(monkeypatch, prime_cli, gpt2_tokenizer): + load_calls = [] + + def loader(_env_id: str, env_args: dict[str, object]): + load_calls.append(env_args) + return FakeVerifierEnv( + _prompt_dataset([str(env_args["difficulty"])], "train"), + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + + _install_fake_verifiers(monkeypatch, loader) + inference_ctx = _levanter_inference_ctx(gpt2_tokenizer) -def test_prime_intellect_env_sample(tokenizer, inference_ctx, vf_env): - """Test sampling from PrimeIntellectEnv with real components.""" - env = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={}, max_tokens=1024, max_concurrent=32) + easy_one = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={"difficulty": "easy"}) + easy_two = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={"difficulty": "easy"}) + hard = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={"difficulty": "hard"}) - # Patch only external dependencies - with patch.object(env, "load_prime_intellect_env", return_value=vf_env), patch("subprocess.run"): - prng_key = jax.random.PRNGKey(0) - rollout_groups, metrics = env.sample( + for env in (easy_one, easy_two, hard): + env.prepare() + env.sample( inference_ctx=inference_ctx, - n_examples=2, - n_generations=2, - temperature=0.7, - prng_key=prng_key, - mode="train", + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, ) - # Verify rollout groups structure - assert len(rollout_groups) == 2, "Should have 2 rollout groups (one per prompt)" + assert load_calls == [{"difficulty": "easy"}, {"difficulty": "hard"}] - for group in rollout_groups: - assert len(group.rollouts) == 2, "Each group should have 2 
rollouts" - for rollout in group.rollouts: - assert rollout.env_name == "prime_intellect:primeintellect/gsm8k" - assert rollout.env_example_id.startswith("primeintellect/gsm8k_example_") - assert len(rollout.prompt_tokens) > 0 - assert len(rollout.response_tokens) > 0 - assert len(rollout.response_logprobs) == len(rollout.response_tokens) - assert len(rollout.token_rewards) == len(rollout.response_tokens) - assert 0.0 <= rollout.episode_reward <= 1.0 +def test_prime_intellect_env_prepare_rejects_non_primeintellect_ids(monkeypatch, prime_cli): + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: None) + env = PrimeIntellectEnv(env_id="someone-else/gsm8k") - # Verify tokens are valid int32 - assert rollout.prompt_tokens.dtype == np.int32 - assert rollout.response_tokens.dtype == np.int32 + with pytest.raises(ValueError, match="only supports 'primeintellect/\\*' IDs"): + env.prepare() - # Verify metrics exist - assert "primeintellect/gsm8k.mean_reward" in metrics - assert "primeintellect/gsm8k.total_rollouts" in metrics + assert prime_cli.call_count == 0 -def test_prime_intellect_env_openai_client_called(tokenizer, inference_ctx, vf_env): - """Test that the OpenAI client is properly requested from inference context.""" - env = PrimeIntellectEnv(env_id="primeintellect/gsm8k", env_args={}) +def test_prime_intellect_env_prepare_requires_prime_cli(monkeypatch): + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: None) + monkeypatch.setattr("marin.rl.environments.prime_intellect_env.shutil.which", lambda executable: None) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") - with patch.object(env, "load_prime_intellect_env", return_value=vf_env), patch("subprocess.run"): - prng_key = jax.random.PRNGKey(42) - rollout_groups, _ = env.sample( - inference_ctx=inference_ctx, + with pytest.raises(RuntimeError, match="requires the 'prime' executable"): + env.prepare() + + +def test_prime_intellect_env_sample_rejects_invalid_mode(monkeypatch, prime_cli, gpt2_tokenizer): + verifier_env = FakeVerifierEnv( + _prompt_dataset(["0"], "train"), + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + load_calls = _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="Unsupported mode"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, + mode="debug", + ) + + assert load_calls == [] + + +def test_prime_intellect_env_sample_rejects_system_prompt(monkeypatch, prime_cli, gpt2_tokenizer): + verifier_env = FakeVerifierEnv( + _prompt_dataset(["0"], "train"), + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + load_calls = _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="does not support Marin-level system prompts"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), n_examples=1, n_generations=1, temperature=1.0, - prng_key=prng_key, + prng_key=None, + system_prompt="You are helpful.", ) - # Verify we got rollout groups (which means generate() was called successfully) - assert len(rollout_groups) >= 0 + assert load_calls == [] -def test_prime_intellect_env_tokenization(tokenizer, inference_ctx, 
vf_env): - """Test that prompts include chat template and completions are properly tokenized.""" - env = PrimeIntellectEnv(env_id="primeintellect/tokenize", env_args={}) +def test_prime_intellect_env_sample_rejects_non_chat_verifier_env(monkeypatch, prime_cli, gpt2_tokenizer): + verifier_env = FakeVerifierEnv( + _prompt_dataset(["0"], "train"), + message_type="completion", + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") - with patch.object(env, "load_prime_intellect_env", return_value=vf_env), patch("subprocess.run"): - prng_key = jax.random.PRNGKey(123) - rollout_groups, _ = env.sample( - inference_ctx=inference_ctx, - n_examples=2, - n_generations=2, - temperature=0.8, - prng_key=prng_key, + env.prepare() + with pytest.raises(ValueError, match="only supports chat-format verifier environments"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, + ) + + +def test_prime_intellect_env_sample_rejects_tool_enabled_verifier_env(monkeypatch, prime_cli, gpt2_tokenizer): + verifier_env = FakeVerifierEnv( + _prompt_dataset(["0"], "train"), + oai_tools=[{"type": "function"}], + generate_result_factory=lambda *, inputs: _single_turn_generate_outputs(inputs, gpt2_tokenizer), + ) + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="does not support tool-enabled verifier environments"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, + ) + + +def test_prime_intellect_env_sample_rejects_non_assistant_completion_turns(monkeypatch, prime_cli, gpt2_tokenizer): + dataset = _prompt_dataset(["0"], "train") + + def generate_result_factory(*, inputs): + output = _single_turn_generate_outputs(inputs, gpt2_tokenizer) + output.completion[0] = [ + {"role": "assistant", "content": "resp-0-0"}, + {"role": "user", "content": "tool feedback"}, + ] + return output + + verifier_env = FakeVerifierEnv(dataset, generate_result_factory=generate_result_factory) + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="does not support non-assistant turns in completions"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, + ) + + +def test_prime_intellect_env_sample_rejects_multiple_assistant_completion_turns(monkeypatch, prime_cli, gpt2_tokenizer): + dataset = _prompt_dataset(["0"], "train") + + def generate_result_factory(*, inputs): + output = _single_turn_generate_outputs(inputs, gpt2_tokenizer) + output.completion[0] = [ + {"role": "assistant", "content": "resp-0-0"}, + {"role": "assistant", "content": "resp-0-1"}, + ] + return output + + verifier_env = FakeVerifierEnv(dataset, generate_result_factory=generate_result_factory) + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="requires exactly one assistant completion turn"): + env.sample( + 
inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, ) - # Verify tokenization and chat template application - for group in rollout_groups: - for rollout in group.rollouts: - # Tokens should be valid int32 arrays - assert rollout.prompt_tokens.dtype == np.int32 - assert rollout.response_tokens.dtype == np.int32 - # Decode the prompt tokens - decoded_prompt = tokenizer.decode(rollout.prompt_tokens.tolist()) - decoded_response = tokenizer.decode(rollout.response_tokens.tolist()) +def test_prime_intellect_env_sample_rejects_multiple_response_objects(monkeypatch, prime_cli, gpt2_tokenizer): + dataset = _prompt_dataset(["0"], "train") - # Verify chat template was applied: prompt tokens should contain "user:" prefix - # from the fallback chat template (GPT-2 doesn't have a native chat template) - assert "user:" in decoded_prompt.lower(), "Chat template should add 'user:' prefix to prompt" + def generate_result_factory(*, inputs): + output = _single_turn_generate_outputs(inputs, gpt2_tokenizer) + response = output.state[0]["responses"][0] + output.state[0] = {"responses": [response, response]} + return output - # Verify response is non-empty - assert len(decoded_response) > 0 + verifier_env = FakeVerifierEnv(dataset, generate_result_factory=generate_result_factory) + _install_fake_verifiers(monkeypatch, lambda _env_id, _env_args: verifier_env) + env = PrimeIntellectEnv(env_id="primeintellect/gsm8k") + + env.prepare() + with pytest.raises(ValueError, match="requires exactly one response object per rollout"): + env.sample( + inference_ctx=_levanter_inference_ctx(gpt2_tokenizer), + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=None, + ) diff --git a/tests/rl/environments/test_process_vllm_results.py b/tests/rl/environments/test_process_vllm_results.py deleted file mode 100644 index 2a7ab911d9..0000000000 --- a/tests/rl/environments/test_process_vllm_results.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Tests for vLLM result processing functions.""" - -import pytest -from openai.types.chat import ChatCompletion, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice, ChoiceLogprobs, ChatCompletionTokenLogprob -from openai.types.completion_usage import CompletionUsage - -from marin.rl.environments.process_vllm_results import ( - parse_chat_completion_tokens_from_bytes, - parse_chat_completion_logprobs, -) - - -@pytest.fixture -def tokenizer(gpt2_tokenizer): - """Alias the session-scoped local GPT-2 tokenizer.""" - return gpt2_tokenizer - - -def create_mock_chat_completion_with_logprobs( - response_text: str, token_strs: list[str], logprobs: list[float] -) -> ChatCompletion: - """Create a mock ChatCompletion with logprobs for testing.""" - assert len(token_strs) == len(logprobs), "Token strings and logprobs must have same length" - - logprob_content = [ - ChatCompletionTokenLogprob( - token=token_str, - logprob=logprob, - bytes=list(token_str.encode("utf-8")), - top_logprobs=[], - ) - for token_str, logprob in zip(token_strs, logprobs, strict=False) - ] - - return ChatCompletion( - id="chatcmpl-test", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=response_text), - logprobs=ChoiceLogprobs(content=logprob_content, refusal=None), - ) - ], - created=1234567890, - model="test-model", - object="chat.completion", - 
usage=CompletionUsage(completion_tokens=len(token_strs), prompt_tokens=10, total_tokens=10 + len(token_strs)), - ) - - -def test_parse_chat_completion_tokens_special_tokens(tokenizer): - """Test token parsing with special tokens like newlines.""" - # Ċ is how GPT-2 represents newlines in BPE - # "Hello" and "World" are not raw vocab entries in byte-level BPE, so they - # fall back to 0 (unknown). Ċ (newline byte) is a real vocab token. - token_strs = ["Hello", "Ċ", "World"] - logprobs = [-0.1, -0.2, -0.3] - - chat_completion = create_mock_chat_completion_with_logprobs("Hello\nWorld", token_strs, logprobs) - - parsed_tokens = parse_chat_completion_tokens_from_bytes(chat_completion, tokenizer) - - assert len(parsed_tokens) == 3 - assert all(isinstance(t, int) for t in parsed_tokens) - - vocab = tokenizer.get_vocab() - expected_ids = [vocab.get(t, 0) for t in token_strs] - assert parsed_tokens == expected_ids - - -def test_parse_chat_completion_tokens_token_id_format(tokenizer): - """Test token parsing when tokens are in 'token_id:' format.""" - # Some vLLM configurations might return tokens in this format - token_strs = ["token_id:123", "token_id:456", "token_id:789"] - logprobs = [-0.1, -0.2, -0.3] - - chat_completion = create_mock_chat_completion_with_logprobs("test", token_strs, logprobs) - - parsed_tokens = parse_chat_completion_tokens_from_bytes(chat_completion, tokenizer) - - assert parsed_tokens == [123, 456, 789] - - -def test_parse_chat_completion_logprobs(tokenizer): - """Test logprob extraction.""" - token_strs = ["Hello", "Ġworld"] - logprobs = [-0.123, -0.456] - - chat_completion = create_mock_chat_completion_with_logprobs("Hello world", token_strs, logprobs) - - parsed_logprobs = parse_chat_completion_logprobs(chat_completion) - - assert len(parsed_logprobs) == 2 - assert parsed_logprobs == logprobs - - -def test_parse_chat_completion_tokens_empty_response(gpt2_tokenizer): - """Test handling of empty response.""" - chat_completion = create_mock_chat_completion_with_logprobs("", [], []) - - parsed_tokens = parse_chat_completion_tokens_from_bytes(chat_completion, gpt2_tokenizer) - assert parsed_tokens == [] From 90c3bdad98821ffebd1ef2ab28538b03921ddbea Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Thu, 16 Apr 2026 02:39:26 -0700 Subject: [PATCH 3/3] [rl] Add rollout response loss masks --- .../rl/environments/inference_ctx/base.py | 2 ++ lib/marin/src/marin/rl/train_batch.py | 14 ++++++---- lib/marin/src/marin/rl/types.py | 12 ++++++++- tests/rl/environments/test_mock_env.py | 4 +++ tests/rl/integration/tasks.py | 8 ++++++ tests/rl/test_inference_ctx.py | 1 + tests/rl/test_loss.py | 23 ++++++++++++++++ tests/rl/test_replay_buffer.py | 1 + tests/rl/test_rollout_storage.py | 1 + tests/rl/test_train_batch.py | 26 +++++++++++++++++++ 10 files changed, 86 insertions(+), 6 deletions(-) diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/base.py b/lib/marin/src/marin/rl/environments/inference_ctx/base.py index 3cd0b7b4cb..c539fb9e6d 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/base.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/base.py @@ -145,6 +145,7 @@ def create_rollout_from_choice( logger.error(f"Prompt tokenization failed for {env_example_id}") token_rewards = np.full(len(response_tokens), reward, dtype=np.float32) + response_loss_mask = np.ones(len(response_tokens), dtype=np.float32) is_truncated = choice.finish_reason == "length" return Rollout( @@ -153,6 +154,7 @@ def 
create_rollout_from_choice( prompt_tokens=prompt_tokens, response_tokens=response_tokens, response_logprobs=response_logprobs, + response_loss_mask=response_loss_mask, token_rewards=token_rewards, episode_reward=float(reward), correctness_reward=correctness_reward, diff --git a/lib/marin/src/marin/rl/train_batch.py b/lib/marin/src/marin/rl/train_batch.py index c49ea8c077..21440533df 100644 --- a/lib/marin/src/marin/rl/train_batch.py +++ b/lib/marin/src/marin/rl/train_batch.py @@ -49,25 +49,29 @@ def convert_rollout_to_training_format( """ input_tokens = np.concatenate([rollout.prompt_tokens, rollout.response_tokens]) position_ids = np.arange(len(input_tokens), dtype=np.int32) + response_loss_mask = rollout.response_loss_mask.astype(np.float32) - # Loss mask (only on response tokens) + # Loss mask (only on response tokens selected by the rollout) loss_mask = np.concatenate( [ np.zeros(len(rollout.prompt_tokens), dtype=np.float32), - np.ones(len(rollout.response_tokens), dtype=np.float32), + response_loss_mask, ] ) - # Loss weights (advantage for all response tokens) + # Loss weights (advantage for response tokens selected by the rollout) loss_weight = np.concatenate( [ np.zeros(len(rollout.prompt_tokens), dtype=np.float32), - np.full(len(rollout.response_tokens), advantage, dtype=np.float32), + response_loss_mask * np.float32(advantage), ] ) policy_logprob = np.concatenate( - [np.zeros(len(rollout.prompt_tokens), dtype=np.float32), rollout.response_logprobs.astype(np.float32)] + [ + np.zeros(len(rollout.prompt_tokens), dtype=np.float32), + rollout.response_logprobs.astype(np.float32) * response_loss_mask, + ] ) max_seq_len = max_tokens diff --git a/lib/marin/src/marin/rl/types.py b/lib/marin/src/marin/rl/types.py index ed93bb667a..778e4a561c 100644 --- a/lib/marin/src/marin/rl/types.py +++ b/lib/marin/src/marin/rl/types.py @@ -63,6 +63,9 @@ class Rollout(eqx.Module): response_logprobs: np.ndarray """Array of (response_length,) log probabilities for each generated token.""" + response_loss_mask: np.ndarray + """Array of (response_length,) mask values indicating which response tokens receive loss.""" + token_rewards: np.ndarray """The reward assigned to each generated token.""" @@ -84,6 +87,13 @@ class Rollout(eqx.Module): correctness_reward: float | None = None """The reward for the correctness of the response.""" + def __post_init__(self): + if len(self.response_loss_mask) != len(self.response_tokens): + raise ValueError( + "response_loss_mask length must match response_tokens length, " + f"got {len(self.response_loss_mask)} and {len(self.response_tokens)}" + ) + class RolloutGroup(eqx.Module): """Multiple rollouts for the same prompt (e.g., n_generations samples).""" @@ -112,7 +122,7 @@ class TrainingBatch(eqx.Module): input_ids: ht.Int[NamedArray, "batch position"] position_ids: ht.Int[NamedArray, "batch position"] loss_weights: ht.Float[NamedArray, "batch position"] - loss_masks: ht.Int[NamedArray, "batch position"] + loss_masks: ht.Float[NamedArray, "batch position"] policy_logprobs: ht.Float[NamedArray, "batch position"] temperature: ht.Float[NamedArray, "batch"] # noqa: F821 top_k: ht.Int[NamedArray, "batch"] # noqa: F821 diff --git a/tests/rl/environments/test_mock_env.py b/tests/rl/environments/test_mock_env.py index dd3b20cb89..e7d21595b5 100644 --- a/tests/rl/environments/test_mock_env.py +++ b/tests/rl/environments/test_mock_env.py @@ -120,6 +120,7 @@ def create_rollout_from_choice( env_example_id, reward, temperature, + top_k=None, system_prompt=None, correctness_reward=None, ): 
@@ -127,6 +128,7 @@ def create_rollout_from_choice( response_tokens = self.get_choice_tokens(choice) response_logprobs = self.get_choice_logprobs(choice) token_rewards = jnp.full(len(response_tokens), reward, dtype=jnp.float32) + response_loss_mask = jnp.ones(len(response_tokens), dtype=jnp.float32) return Rollout( env_name=env_name, @@ -134,9 +136,11 @@ def create_rollout_from_choice( prompt_tokens=jnp.array(prompt_tokens, dtype=jnp.int32), response_tokens=jnp.array(response_tokens, dtype=jnp.int32), response_logprobs=jnp.array(response_logprobs, dtype=jnp.float32), + response_loss_mask=response_loss_mask, token_rewards=token_rewards, episode_reward=float(reward), temperature=temperature, + top_k=top_k, is_truncated=False, correctness_reward=correctness_reward, ) diff --git a/tests/rl/integration/tasks.py b/tests/rl/integration/tasks.py index caed1e39e0..31e0589d3d 100644 --- a/tests/rl/integration/tasks.py +++ b/tests/rl/integration/tasks.py @@ -110,8 +110,12 @@ def create_cats_rollout_batch( prompt_tokens=individual_prompt, response_tokens=individual_response, response_logprobs=individual_logprobs, + response_loss_mask=np.ones(len(individual_response), dtype=np.float32), token_rewards=token_rewards, episode_reward=episode_reward, + temperature=1.0, + top_k=None, + is_truncated=False, metadata=RolloutMetadata( worker_id=worker_id, timestamp=time.time(), @@ -300,8 +304,12 @@ def create_sequential_digits_rollout_batch( prompt_tokens=individual_prompt, response_tokens=individual_response, response_logprobs=individual_logprobs, + response_loss_mask=np.ones(len(individual_response), dtype=np.float32), token_rewards=token_rewards, episode_reward=episode_reward, + temperature=1.0, + top_k=None, + is_truncated=False, metadata=RolloutMetadata( worker_id=worker_id, timestamp=time.time(), diff --git a/tests/rl/test_inference_ctx.py b/tests/rl/test_inference_ctx.py index e2fd2a91c7..cf092a9823 100644 --- a/tests/rl/test_inference_ctx.py +++ b/tests/rl/test_inference_ctx.py @@ -306,6 +306,7 @@ def test_create_rollout_from_choice_end_to_end(inference_ctx, llama3_tokenizer): # Verify token rewards assert len(rollout.token_rewards) == len(expected_response_tokens) np.testing.assert_array_equal(rollout.token_rewards, np.full(len(expected_response_tokens), reward)) + np.testing.assert_array_equal(rollout.response_loss_mask, np.ones(len(expected_response_tokens), dtype=np.float32)) def test_vllm_inference_context_uses_canonical_model_name(monkeypatch): diff --git a/tests/rl/test_loss.py b/tests/rl/test_loss.py index e0440f2958..1375ac3c6c 100644 --- a/tests/rl/test_loss.py +++ b/tests/rl/test_loss.py @@ -13,11 +13,14 @@ def create_test_rollout( env_name: str = "test_env", episode_reward: float = 1.0, unique_id: int = 12345, + response_loss_mask: np.ndarray | None = None, ) -> Rollout: """Create a test rollout with predictable token values.""" prompt_tokens = np.full(prompt_len, unique_id, dtype=np.int32) response_tokens = np.arange(response_len, dtype=np.int32) + 1000 response_logprobs = np.full(response_len, -0.5, dtype=np.float32) + if response_loss_mask is None: + response_loss_mask = np.ones(response_len, dtype=np.float32) token_rewards = np.full(response_len, 0.1, dtype=np.float32) return Rollout( @@ -26,6 +29,7 @@ def create_test_rollout( prompt_tokens=prompt_tokens, response_tokens=response_tokens, response_logprobs=response_logprobs, + response_loss_mask=response_loss_mask, token_rewards=token_rewards, episode_reward=episode_reward, temperature=1.0, @@ -141,3 +145,22 @@ def 
test_rloo_loss_needs_reference_model_only_when_kl_enabled(): def test_rloo_loss_rejects_missing_reference_model_when_kl_enabled(): with pytest.raises(ValueError, match="reference_model is required"): RLOOLoss(kl_coef=0.01).create_loss_fn(reference_model=None, train_model=None) + + +def test_ppo_objective_stays_finite_with_sparse_masks(): + importance_sampling_ratio = np.array([[1.0, 1.0, 1.0, 1.2, 0.0, 0.8]], dtype=np.float32) + loss_weights = np.array([[0.0, 0.0, 0.0, 0.5, 0.0, 1.0]], dtype=np.float32) + loss_masks = np.array([[0.0, 0.0, 0.0, 1.0, 0.0, 1.0]], dtype=np.float32) + + loss, metadata = compute_ppo_loss_objective( + importance_sampling_ratio, + loss_weights, + loss_masks, + clip_epsilon_low=0.2, + clip_epsilon_high=0.2, + max_output_tokens=loss_masks.shape[-1], + trainer_inference_importance_sampling_ratio=None, + ) + + assert np.isfinite(loss) + assert set(metadata) == {"loss_max_over_batch", "loss_std_over_batch"} diff --git a/tests/rl/test_replay_buffer.py b/tests/rl/test_replay_buffer.py index 93cf92218c..7155351fde 100644 --- a/tests/rl/test_replay_buffer.py +++ b/tests/rl/test_replay_buffer.py @@ -71,6 +71,7 @@ def create_test_batch( prompt_tokens=prompt_tokens, response_tokens=response_tokens, response_logprobs=response_logprobs, + response_loss_mask=np.ones(response_len, dtype=np.float32), token_rewards=token_rewards, episode_reward=episode_reward, temperature=1.0, diff --git a/tests/rl/test_rollout_storage.py b/tests/rl/test_rollout_storage.py index d6ae899855..d936c4dc96 100644 --- a/tests/rl/test_rollout_storage.py +++ b/tests/rl/test_rollout_storage.py @@ -42,6 +42,7 @@ def create_test_rollout(idx: int) -> Rollout: prompt_tokens=prompt_tokens, response_tokens=response_tokens, response_logprobs=response_logprobs, + response_loss_mask=np.ones(response_len, dtype=np.float32), token_rewards=token_rewards, episode_reward=episode_reward, temperature=1.0, diff --git a/tests/rl/test_train_batch.py b/tests/rl/test_train_batch.py index 24c243979c..9c91cda569 100644 --- a/tests/rl/test_train_batch.py +++ b/tests/rl/test_train_batch.py @@ -16,11 +16,14 @@ def create_test_rollout( env_name: str = "test_env", episode_reward: float = 1.0, unique_id: int = 12345, + response_loss_mask: np.ndarray | None = None, ) -> Rollout: """Create a test rollout with predictable token values.""" prompt_tokens = np.full(prompt_len, unique_id, dtype=np.int32) response_tokens = np.arange(response_len, dtype=np.int32) + 1000 response_logprobs = np.full(response_len, -0.5, dtype=np.float32) + if response_loss_mask is None: + response_loss_mask = np.ones(response_len, dtype=np.float32) token_rewards = np.full(response_len, 0.1, dtype=np.float32) return Rollout( @@ -29,6 +32,7 @@ def create_test_rollout( prompt_tokens=prompt_tokens, response_tokens=response_tokens, response_logprobs=response_logprobs, + response_loss_mask=response_loss_mask, token_rewards=token_rewards, episode_reward=episode_reward, temperature=1.0, @@ -110,6 +114,28 @@ def test_loss_weights_have_advantage(): np.testing.assert_array_equal(loss_weights, expected_weights) +def test_sparse_response_loss_mask_is_preserved(): + """Test that sparse response masks propagate into training masks and weights.""" + response_loss_mask = np.array([1.0, 0.0, 1.0], dtype=np.float32) + rollout = create_test_rollout(prompt_len=4, response_len=3, response_loss_mask=response_loss_mask) + + result = train_batch.convert_rollout_to_training_format(rollout, 2.5, max_tokens=16, pad_token_id=0, pad_to=16) + + expected_loss_mask = np.array([0, 0, 0, 0, 1, 0, 
1] + [0] * 9, dtype=np.float32) + expected_loss_weights = np.array([0, 0, 0, 0, 2.5, 0, 2.5] + [0] * 9, dtype=np.float32) + expected_policy_logprobs = np.array([0, 0, 0, 0, -0.5, 0, -0.5] + [0] * 9, dtype=np.float32) + + np.testing.assert_array_equal(result["loss_masks"], expected_loss_mask) + np.testing.assert_array_equal(result["loss_weights"], expected_loss_weights) + np.testing.assert_array_equal(result["policy_logprobs"], expected_policy_logprobs) + + +def test_rollout_rejects_mismatched_response_loss_mask(): + """Test that rollout construction rejects mismatched response loss masks.""" + with pytest.raises(ValueError, match="response_loss_mask length must match response_tokens length"): + create_test_rollout(response_len=3, response_loss_mask=np.array([1.0, 0.0], dtype=np.float32)) + + def test_token_sequence_shifted_correctly(): """Test that input sequence contains full prompt+response (shifting now happens in rl_losses.py).""" rollout = create_test_rollout(prompt_len=3, response_len=2)