diff --git a/experiments/reasoning_gym_curriculum_examples.py b/experiments/reasoning_gym_curriculum_examples.py new file mode 100644 index 0000000000..cf62bd50e6 --- /dev/null +++ b/experiments/reasoning_gym_curriculum_examples.py @@ -0,0 +1,66 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Small curriculum examples for Reasoning Gym-backed RL lessons. + +These helpers are intentionally minimal. They show the expected `EnvConfig` +shape for `ReasoningGymEnv` without introducing another full launcher script. +""" + +from marin.rl.curriculum import CurriculumConfig, LessonConfig, SamplingParams +from marin.rl.environments import EnvConfig +from marin.rl.rl_experiment_utils import RLExperimentConfig + +DEFAULT_REASONING_GYM_EVAL_N_EXAMPLES = 128 +REASONING_GYM_LEG_COUNTING_SEED = 42 + + +def build_leg_counting_curriculum( + run_id: str, + config: RLExperimentConfig, + eval_frequency: int, +) -> CurriculumConfig: + """Build a minimal single-lesson Reasoning Gym curriculum example.""" + sampling_params = SamplingParams( + temperature=1.0, + n_prompts=config.n_prompts, + n_generations_per_prompt=config.n_generations_per_prompt, + max_output_tokens=config.max_output_tokens, + top_k=config.inference_top_k, + stop_tokens=None, + ) + + return CurriculumConfig( + lessons={ + "rg_leg_counting": LessonConfig( + lesson_id="rg_leg_counting", + env_config=EnvConfig( + env_class="marin.rl.environments.reasoning_gym_env.ReasoningGymEnv", + env_args={ + "dataset_name": "leg_counting", + "train_dataset_args": { + "seed": REASONING_GYM_LEG_COUNTING_SEED, + "size": 10_000, + "min_animals": 2, + "max_animals": 4, + }, + "eval_dataset_args": { + "seed": REASONING_GYM_LEG_COUNTING_SEED + 1, + "size": DEFAULT_REASONING_GYM_EVAL_N_EXAMPLES, + "min_animals": 2, + "max_animals": 4, + }, + "success_threshold": 1.0, + "prompt_template": "{question}", + }, + ), + dependencies=[], + sampling_params=sampling_params, + ), + }, + eval_frequency=eval_frequency, + micro_eval_frequency=None, + actor_name=f"curriculum-{run_id}", + eval_n_examples=DEFAULT_REASONING_GYM_EVAL_N_EXAMPLES, + max_seq_len=config.max_input_tokens + config.max_output_tokens, + ) diff --git a/lib/marin/pyproject.toml b/lib/marin/pyproject.toml index b2bda71f4b..2e664b324b 100644 --- a/lib/marin/pyproject.toml +++ b/lib/marin/pyproject.toml @@ -143,6 +143,10 @@ rl = [ "verifiers==0.1.5", ] +reasoning-gym = [ + "reasoning-gym==0.1.19", +] + eval = [ "lm-eval[math,api]@git+https://github.com/stanford-crfm/lm-evaluation-harness@d5e3391f22cde186c827674d5c3ec7c5f4fe0cab", ] diff --git a/lib/marin/src/marin/rl/environments/reasoning_gym_env.py b/lib/marin/src/marin/rl/environments/reasoning_gym_env.py new file mode 100644 index 0000000000..e616d6f34a --- /dev/null +++ b/lib/marin/src/marin/rl/environments/reasoning_gym_env.py @@ -0,0 +1,293 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +import importlib +import logging +import math +import re +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from marin.rl.environments.inference_ctx.base import BaseInferenceContext +from marin.rl.types import Rollout, RolloutGroup + +from .base import MarinEnv, extract_seed + +logger = logging.getLogger(__name__) + +MODE_TRAIN = "train" +MODE_EVAL = "eval" +COMPOSITE_DATASET_NAME = "composite" +REQUIRED_DATASET_ARG_KEYS = ("seed", "size") +QUESTION_KEY = "question" +METADATA_KEY = "metadata" +ENTRY_ID_KEY = "entry_id" 
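+# When present in entry metadata, entry_id is preferred for building stable
+# example ids; source_dataset/source_index (set by composite datasets) are the
+# next fallback, then the dataset seed plus the sampled index.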
+SOURCE_DATASET_KEY = "source_dataset" +SOURCE_INDEX_KEY = "source_index" +DATASETS_KEY = "datasets" +ENV_NAME_PREFIX = "reasoning_gym" +DEFAULT_SUCCESS_THRESHOLD = 1.0 +QUESTION_TEMPLATE_FIELD = "{question}" +NON_ALNUM_UNDERSCORE_PATTERN = re.compile(r"[^A-Za-z0-9_]") + + +@dataclass(frozen=True) +class ReasoningGymExample: + """Normalized view of a single Reasoning Gym example.""" + + prompt: str + example_id: str + raw_entry: dict[str, Any] + source_dataset: str | None + + +class ReasoningGymEnv(MarinEnv): + """Marin RL environment backed by the Reasoning Gym Python API.""" + + def __init__( + self, + dataset_name: str, + train_dataset_args: dict[str, Any], + eval_dataset_args: dict[str, Any], + success_threshold: float = DEFAULT_SUCCESS_THRESHOLD, + sample_with_replacement: bool = False, + prompt_template: str = QUESTION_TEMPLATE_FIELD, + ) -> None: + if QUESTION_TEMPLATE_FIELD not in prompt_template: + raise ValueError("prompt_template must include '{question}'") + if not math.isfinite(success_threshold): + raise ValueError("success_threshold must be finite") + + self.dataset_name = dataset_name + self.env_name = f"{ENV_NAME_PREFIX}:{dataset_name}" + self.success_threshold = success_threshold + self.sample_with_replacement = sample_with_replacement + self.prompt_template = prompt_template + + reasoning_gym = self._ensure_reasoning_gym_installed() + self._train_dataset_args = self._normalize_dataset_args(dataset_name, train_dataset_args) + self._eval_dataset_args = self._normalize_dataset_args(dataset_name, eval_dataset_args) + self._validate_dataset_args(MODE_TRAIN, self._train_dataset_args) + self._validate_dataset_args(MODE_EVAL, self._eval_dataset_args) + self._train_dataset = reasoning_gym.create_dataset(dataset_name, **self._train_dataset_args) + self._eval_dataset = reasoning_gym.create_dataset(dataset_name, **self._eval_dataset_args) + + def sample( + self, + inference_ctx: BaseInferenceContext, + n_examples: int, + n_generations: int, + temperature: float, + prng_key, + mode: str = MODE_TRAIN, + max_tokens: int | None = None, + top_k: int | None = None, + stop: list[str] | None = None, + system_prompt: str | None = None, + ) -> tuple[list[RolloutGroup], dict[str, float]]: + """Sample prompts from Reasoning Gym, score completions, and build rollouts.""" + if n_examples <= 0: + raise ValueError("n_examples must be positive") + if n_generations <= 0: + raise ValueError("n_generations must be positive") + + dataset = self._dataset_for_mode(mode) + rng = np.random.default_rng(extract_seed(prng_key)) + indices = self._sample_indices(dataset, n_examples, rng) + sampled_examples = [self._build_example(dataset, mode, int(idx)) for idx in indices] + prompts = [example.prompt for example in sampled_examples] + + completions = inference_ctx.batch_completions( + prompts=prompts, + temperature=temperature, + n=n_generations, + max_tokens=max_tokens, + top_k=top_k, + stop=stop, + system_prompt=system_prompt, + ) + + rollout_groups: list[RolloutGroup] = [] + total_choices = 0 + reward_sum = 0.0 + solve_sum = 0.0 + response_token_count = 0 + truncated_count = 0 + source_counts: dict[str, int] = {} + + for example, completion in zip(sampled_examples, completions, strict=True): + group_rollouts: list[Rollout] = [] + source_name = example.source_dataset or self.dataset_name + source_counts[source_name] = source_counts.get(source_name, 0) + 1 + + for choice in completion.choices: + response_text = choice.message.content or "" + reward = self._score_choice(dataset, example.raw_entry, response_text) 
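+ # score_answer returns a continuous score (typically in [0, 1]); a completion
+ # counts as solved only at or above success_threshold.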
+ solved = float(reward >= self.success_threshold) + + rollout = inference_ctx.create_rollout_from_choice( + prompt=example.prompt, + choice=choice, + env_name=self.env_name, + env_example_id=example.example_id, + reward=reward, + correctness_reward=solved, + temperature=temperature, + top_k=top_k, + system_prompt=system_prompt, + ) + + group_rollouts.append(rollout) + total_choices += 1 + reward_sum += reward + solve_sum += solved + response_token_count += rollout.response_tokens.size + + if choice.finish_reason == "length": + truncated_count += 1 + + if group_rollouts: + rollout_groups.append(RolloutGroup(rollouts=group_rollouts)) + + if total_choices == 0: + raise RuntimeError("Inference context returned no choices; cannot compute metrics") + + prefix = self._metrics_prefix(mode) + metrics: dict[str, float] = { + f"{prefix}_mean_reward": reward_sum / total_choices, + f"{prefix}_solve_rate": solve_sum / total_choices, + f"{prefix}_mean_response_tokens": response_token_count / total_choices, + f"{prefix}_total_responses": float(total_choices), + f"{prefix}_sampled_examples": float(len(sampled_examples)), + f"{prefix}_truncated_percentage": float(truncated_count) / total_choices, + } + for source_name, count in sorted(source_counts.items()): + metrics[f"{prefix}_source_{self._metric_name_fragment(source_name)}_count"] = float(count) + + return rollout_groups, metrics + + def _dataset_for_mode(self, mode: str): + if mode == MODE_TRAIN: + dataset = self._train_dataset + elif mode == MODE_EVAL: + dataset = self._eval_dataset + else: + raise ValueError(f"Unsupported mode: {mode}") + + if len(dataset) == 0: + raise ValueError(f"No examples available for mode '{mode}'") + + return dataset + + def _sample_indices(self, dataset, n_examples: int, rng: np.random.Generator) -> np.ndarray: + dataset_size = len(dataset) + if self.sample_with_replacement: + return rng.choice(dataset_size, size=n_examples, replace=True) + + n_to_sample = min(n_examples, dataset_size) + return rng.choice(dataset_size, size=n_to_sample, replace=False) + + def _build_example(self, dataset, mode: str, idx: int) -> ReasoningGymExample: + entry = dataset[idx] + question = entry.get(QUESTION_KEY) + if not isinstance(question, str): + raise ValueError(f"Reasoning Gym entry at index {idx} is missing a string '{QUESTION_KEY}' field") + + metadata = entry.get(METADATA_KEY, {}) + if not isinstance(metadata, dict): + raise ValueError(f"Reasoning Gym entry at index {idx} has non-dict metadata: {type(metadata)!r}") + + source_dataset = metadata.get(SOURCE_DATASET_KEY) + if source_dataset is not None and not isinstance(source_dataset, str): + raise ValueError(f"Reasoning Gym metadata field '{SOURCE_DATASET_KEY}' must be a string") + + example_id = self._build_example_id(dataset, mode, idx, metadata) + prompt = self.prompt_template.format(question=question) + return ReasoningGymExample( + prompt=prompt, + example_id=example_id, + raw_entry=entry, + source_dataset=source_dataset, + ) + + def _build_example_id(self, dataset, mode: str, idx: int, metadata: dict[str, Any]) -> str: + if isinstance(metadata.get(ENTRY_ID_KEY), str): + return f"{self.env_name}:{mode}:{metadata[ENTRY_ID_KEY]}" + + source_dataset = metadata.get(SOURCE_DATASET_KEY) + source_index = metadata.get(SOURCE_INDEX_KEY) + if isinstance(source_dataset, str): + source_index_fragment = source_index if source_index is not None else idx + return f"{self.env_name}:{mode}:{source_dataset}:{source_index_fragment}" + + dataset_seed = getattr(dataset, "seed", "unknown") + if 
source_index is not None: + return f"{self.env_name}:{mode}:{dataset_seed}:{source_index}" + return f"{self.env_name}:{mode}:{dataset_seed}:{idx}" + + def _score_choice(self, dataset, entry: dict[str, Any], response_text: str) -> float: + score = float(dataset.score_answer(response_text, entry)) + if not math.isfinite(score): + raise ValueError(f"Reasoning Gym returned a non-finite score for dataset '{self.dataset_name}': {score}") + if score < 0.0 or score > 1.0: + logger.warning( + "Reasoning Gym score for dataset '%s' fell outside [0, 1]: %f", + self.dataset_name, + score, + ) + return score + + def _normalize_dataset_args(self, dataset_name: str, dataset_args: dict[str, Any]) -> dict[str, Any]: + normalized_args = copy.deepcopy(dataset_args) + if dataset_name == COMPOSITE_DATASET_NAME: + return self._normalize_composite_specs(normalized_args) + return normalized_args + + def _normalize_composite_specs(self, dataset_args: dict[str, Any]) -> dict[str, Any]: + datasets = dataset_args.get(DATASETS_KEY) + if datasets is None: + return dataset_args + if not isinstance(datasets, list): + raise ValueError(f"Composite dataset args field '{DATASETS_KEY}' must be a list") + + composite_module = importlib.import_module("reasoning_gym.composite") + dataset_spec_cls = composite_module.DatasetSpec + normalized_specs = [] + for dataset_spec in datasets: + if isinstance(dataset_spec, dict): + normalized_specs.append(dataset_spec_cls(**dataset_spec)) + else: + normalized_specs.append(dataset_spec) + dataset_args[DATASETS_KEY] = normalized_specs + return dataset_args + + def _validate_dataset_args(self, mode: str, dataset_args: dict[str, Any]) -> None: + missing_keys = [key for key in REQUIRED_DATASET_ARG_KEYS if key not in dataset_args] + if missing_keys: + missing = ", ".join(sorted(missing_keys)) + raise ValueError(f"{mode}_dataset_args must include explicit {missing}") + if self.dataset_name == COMPOSITE_DATASET_NAME and not dataset_args.get(DATASETS_KEY): + raise ValueError(f"{mode}_dataset_args must include a non-empty '{DATASETS_KEY}' list for composite") + + def _metrics_prefix(self, mode: str) -> str: + return f"{ENV_NAME_PREFIX}.{self._metric_name_fragment(self.dataset_name)}.{mode}" + + def _ensure_reasoning_gym_installed(self): + try: + return importlib.import_module("reasoning_gym") + except ModuleNotFoundError as exc: + if exc.name != "reasoning_gym": + raise + raise ImportError( + "The 'reasoning_gym' package is required to use ReasoningGymEnv. 
" + "Install it with: uv sync --extra reasoning-gym" + ) from exc + + @staticmethod + def _metric_name_fragment(name: str) -> str: + return NON_ALNUM_UNDERSCORE_PATTERN.sub("_", name) diff --git a/tests/rl/environments/reasoning_gym_test_support.py b/tests/rl/environments/reasoning_gym_test_support.py new file mode 100644 index 0000000000..9552e9dd0d --- /dev/null +++ b/tests/rl/environments/reasoning_gym_test_support.py @@ -0,0 +1,205 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +import sys +from dataclasses import dataclass +from types import ModuleType +from typing import Any + +import numpy as np +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from marin.rl.types import Rollout + + +class TestTokenizer: + """Tiny deterministic tokenizer for environment tests.""" + + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + del add_special_tokens + return [ord(char) for char in text] + + def decode(self, token_ids) -> str: + return "".join(chr(int(token_id)) for token_id in token_ids) + + def apply_chat_template(self, messages, tokenize: bool, add_generation_prompt: bool): + del tokenize, add_generation_prompt + return [ord(char) for char in messages[0]["content"]] + + +def create_test_chat_completion(prompt: str, responses: list[str]) -> ChatCompletion: + """Create a minimal chat completion with one choice per response.""" + choices = [ + Choice( + finish_reason="stop", + index=index, + message=ChatCompletionMessage(role="assistant", content=response_text), + logprobs=None, + ) + for index, response_text in enumerate(responses) + ] + completion_tokens = sum(len(response_text) for response_text in responses) + return ChatCompletion( + id=f"chatcmpl-test-{hash(prompt)}", + choices=choices, + created=1234567890, + model="test-model", + object="chat.completion", + usage=CompletionUsage( + completion_tokens=completion_tokens, + prompt_tokens=len(prompt), + total_tokens=len(prompt) + completion_tokens, + ), + ) + + +class DummyInferenceContext: + """Inference context with deterministic prompt -> response mapping.""" + + def __init__(self, responses_by_prompt: dict[str, list[str]], default_responses: list[str] | None = None): + self.tokenizer = TestTokenizer() + self.responses_by_prompt = responses_by_prompt + self.default_responses = default_responses or ["fallback"] + self.last_request: dict[str, Any] | None = None + + def batch_completions( + self, + prompts, + temperature, + n, + max_tokens=None, + top_k=None, + stop=None, + system_prompt=None, + ): + self.last_request = { + "prompts": prompts, + "temperature": temperature, + "n": n, + "max_tokens": max_tokens, + "top_k": top_k, + "stop": stop, + "system_prompt": system_prompt, + } + completions = [] + for prompt in prompts: + responses = list(self.responses_by_prompt.get(prompt, self.default_responses)) + if len(responses) < n: + responses.extend([responses[-1]] * (n - len(responses))) + completions.append(create_test_chat_completion(prompt, responses[:n])) + return completions + + def create_rollout_from_choice( + self, + prompt, + choice, + env_name, + env_example_id, + reward, + temperature, + top_k=None, + system_prompt=None, + correctness_reward=None, + ): + del system_prompt + prompt_tokens = np.array([ord(char) for char in prompt], dtype=np.int32) + response_text = choice.message.content or "" + 
response_tokens = np.array([ord(char) for char in response_text], dtype=np.int32) + response_logprobs = np.full(len(response_tokens), -1.0, dtype=np.float32) + token_rewards = np.full(len(response_tokens), reward, dtype=np.float32) + + return Rollout( + env_name=env_name, + env_example_id=env_example_id, + prompt_tokens=prompt_tokens, + response_tokens=response_tokens, + response_logprobs=response_logprobs, + token_rewards=token_rewards, + episode_reward=float(reward), + correctness_reward=correctness_reward, + temperature=temperature, + top_k=top_k, + is_truncated=choice.finish_reason == "length", + ) + + +class FakeReasoningGymDataset: + """Minimal procedural dataset for Reasoning Gym adapter tests.""" + + def __init__(self, entries: list[dict[str, Any]], *, seed: int, size: int): + self._entries = [copy.deepcopy(entry) for entry in entries] + self.seed = seed + self.size = size + + def __len__(self) -> int: + return min(self.size, len(self._entries)) + + def __getitem__(self, idx: int) -> dict[str, Any]: + return copy.deepcopy(self._entries[idx]) + + def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float: + metadata = entry.get("metadata", {}) + score_map = metadata.get("score_map") + if isinstance(score_map, dict): + return float(score_map.get(answer, 0.0)) + return 1.0 if answer == entry.get("answer") else 0.0 + + +@dataclass +class FakeReasoningGymModules: + dataset_spec_cls: type + create_calls: list[dict[str, Any]] + + +def install_fake_reasoning_gym( + monkeypatch, + *, + datasets_by_seed: dict[int, list[dict[str, Any]]] | None = None, +) -> FakeReasoningGymModules: + """Install fake reasoning_gym modules into sys.modules for testing.""" + + class DatasetSpec: + def __init__(self, name: str, weight: float, config: dict[str, Any]): + self.name = name + self.weight = weight + self.config = config + + def validate(self) -> None: + if not self.name: + raise ValueError("Dataset name cannot be empty") + if self.weight <= 0: + raise ValueError("Weight must be positive") + + create_calls: list[dict[str, Any]] = [] + seeded_entries = datasets_by_seed or {} + + def create_dataset(name: str, **kwargs): + create_calls.append({"name": name, "kwargs": copy.deepcopy(kwargs)}) + seed = kwargs["seed"] + size = kwargs["size"] + if seed in seeded_entries: + entries = seeded_entries[seed] + else: + entries = [ + { + "question": f"{name} question {seed}", + "answer": "ok", + "metadata": {"source_index": 0, "source_dataset": name}, + } + ] + return FakeReasoningGymDataset(entries, seed=seed, size=size) + + reasoning_gym_module = ModuleType("reasoning_gym") + reasoning_gym_module.create_dataset = create_dataset + composite_module = ModuleType("reasoning_gym.composite") + composite_module.DatasetSpec = DatasetSpec + + monkeypatch.setitem(sys.modules, "reasoning_gym", reasoning_gym_module) + monkeypatch.setitem(sys.modules, "reasoning_gym.composite", composite_module) + + return FakeReasoningGymModules(dataset_spec_cls=DatasetSpec, create_calls=create_calls) diff --git a/tests/rl/environments/test_load_environment.py b/tests/rl/environments/test_load_environment.py index 24dc6d0555..03609c95c5 100644 --- a/tests/rl/environments/test_load_environment.py +++ b/tests/rl/environments/test_load_environment.py @@ -3,6 +3,7 @@ """Tests for environment loading from EnvConfig.""" +from tests.rl.environments.reasoning_gym_test_support import install_fake_reasoning_gym from marin.rl.environments import EnvConfig, load_environment_from_spec from marin.rl.environments.mock_env import MockEnv from 
marin.rl.environments.math_env import MathEnv @@ -36,3 +37,22 @@ def test_load_math_environment(): assert isinstance(env, MathEnv) assert len(env.train_examples) > 0 assert len(env.eval_examples) > 0 + + +def test_load_reasoning_gym_environment(monkeypatch): + """Test loading ReasoningGymEnv via EnvConfig.""" + install_fake_reasoning_gym(monkeypatch) + from marin.rl.environments.reasoning_gym_env import ReasoningGymEnv + + config = EnvConfig( + env_class="marin.rl.environments.reasoning_gym_env.ReasoningGymEnv", + env_args={ + "dataset_name": "leg_counting", + "train_dataset_args": {"seed": 42, "size": 2}, + "eval_dataset_args": {"seed": 43, "size": 2}, + }, + ) + + env = load_environment_from_spec(config) + + assert isinstance(env, ReasoningGymEnv) diff --git a/tests/rl/environments/test_reasoning_gym_env.py b/tests/rl/environments/test_reasoning_gym_env.py new file mode 100644 index 0000000000..185dc57b3c --- /dev/null +++ b/tests/rl/environments/test_reasoning_gym_env.py @@ -0,0 +1,214 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +import importlib + +import jax.random +import pytest + +from marin.rl.environments.reasoning_gym_env import ReasoningGymEnv +from tests.rl.environments.reasoning_gym_test_support import DummyInferenceContext, install_fake_reasoning_gym + + +def test_reasoning_gym_env_sample_uses_scores_and_binary_correctness(monkeypatch): + modules = install_fake_reasoning_gym( + monkeypatch, + datasets_by_seed={ + 0: [ + { + "question": "How many legs?", + "answer": "4", + "metadata": { + "source_index": 3, + "score_map": {"4": 1.0, "5": 0.25}, + }, + } + ], + 1: [ + { + "question": "Eval legs?", + "answer": "6", + "metadata": { + "source_index": 8, + "score_map": {"6": 1.0}, + }, + } + ], + }, + ) + env = ReasoningGymEnv( + dataset_name="leg_counting", + train_dataset_args={"seed": 0, "size": 1}, + eval_dataset_args={"seed": 1, "size": 1}, + prompt_template="Solve carefully:\n{question}", + ) + inference_ctx = DummyInferenceContext({"Solve carefully:\nHow many legs?": ["4", "5"]}) + + rollout_groups, metrics = env.sample( + inference_ctx=inference_ctx, + n_examples=1, + n_generations=2, + temperature=0.7, + prng_key=jax.random.PRNGKey(0), + mode="train", + system_prompt="Use concise answers.", + ) + + assert len(modules.create_calls) == 2 + assert len(rollout_groups) == 1 + assert len(rollout_groups[0].rollouts) == 2 + + correct_rollout, partial_rollout = rollout_groups[0].rollouts + assert correct_rollout.env_name == "reasoning_gym:leg_counting" + assert correct_rollout.episode_reward == pytest.approx(1.0) + assert correct_rollout.correctness_reward == pytest.approx(1.0) + assert partial_rollout.episode_reward == pytest.approx(0.25) + assert partial_rollout.correctness_reward == pytest.approx(0.0) + assert correct_rollout.env_example_id.endswith(":0:3") + + assert inference_ctx.last_request is not None + assert inference_ctx.last_request["system_prompt"] == "Use concise answers." 
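+ # With score_map {"4": 1.0, "5": 0.25}, the two generations average to
+ # (1.0 + 0.25) / 2 = 0.625, and only the first clears the 1.0 threshold.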
+ assert metrics["reasoning_gym.leg_counting.train_mean_reward"] == pytest.approx(0.625) + assert metrics["reasoning_gym.leg_counting.train_solve_rate"] == pytest.approx(0.5) + assert metrics["reasoning_gym.leg_counting.train_sampled_examples"] == pytest.approx(1.0) + + +def test_reasoning_gym_env_normalizes_composite_dataset_specs(monkeypatch): + modules = install_fake_reasoning_gym( + monkeypatch, + datasets_by_seed={ + 7: [ + { + "question": "Composite question", + "answer": "ABC", + "metadata": { + "source_dataset": "tower_of_hanoi", + "source_index": 11, + "score_map": {"ABC": 1.0}, + }, + } + ], + 8: [ + { + "question": "Composite eval question", + "answer": "XYZ", + "metadata": { + "source_dataset": "leg_counting", + "source_index": 5, + "score_map": {"XYZ": 1.0}, + }, + } + ], + }, + ) + env = ReasoningGymEnv( + dataset_name="composite", + train_dataset_args={ + "seed": 7, + "size": 1, + "datasets": [ + {"name": "tower_of_hanoi", "weight": 1.0, "config": {"min_disks": 3, "max_disks": 4}}, + {"name": "leg_counting", "weight": 1.0, "config": {"min_animals": 2, "max_animals": 3}}, + ], + }, + eval_dataset_args={ + "seed": 8, + "size": 1, + "datasets": [ + {"name": "leg_counting", "weight": 1.0, "config": {"min_animals": 2, "max_animals": 3}}, + ], + }, + ) + inference_ctx = DummyInferenceContext({"Composite question": ["ABC"]}) + + rollout_groups, metrics = env.sample( + inference_ctx=inference_ctx, + n_examples=1, + n_generations=1, + temperature=1.0, + prng_key=jax.random.PRNGKey(1), + mode="train", + ) + + train_call = modules.create_calls[0] + assert train_call["name"] == "composite" + assert all(isinstance(spec, modules.dataset_spec_cls) for spec in train_call["kwargs"]["datasets"]) + assert len(rollout_groups) == 1 + rollout = rollout_groups[0].rollouts[0] + assert rollout.env_name == "reasoning_gym:composite" + assert "tower_of_hanoi" in rollout.env_example_id + assert metrics["reasoning_gym.composite.train_source_tower_of_hanoi_count"] == pytest.approx(1.0) + + +def test_reasoning_gym_env_sampling_is_deterministic_for_fixed_prng_key(monkeypatch): + install_fake_reasoning_gym( + monkeypatch, + datasets_by_seed={ + 10: [ + {"question": "Q0", "answer": "A0", "metadata": {"source_index": 0}}, + {"question": "Q1", "answer": "A1", "metadata": {"source_index": 1}}, + {"question": "Q2", "answer": "A2", "metadata": {"source_index": 2}}, + ], + 11: [{"question": "Eval", "answer": "A", "metadata": {"source_index": 0}}], + }, + ) + env = ReasoningGymEnv( + dataset_name="leg_counting", + train_dataset_args={"seed": 10, "size": 3}, + eval_dataset_args={"seed": 11, "size": 1}, + ) + inference_ctx = DummyInferenceContext({"Q0": ["A0"], "Q1": ["A1"], "Q2": ["A2"]}) + prng_key = jax.random.PRNGKey(123) + + first_groups, _ = env.sample( + inference_ctx=inference_ctx, + n_examples=2, + n_generations=1, + temperature=1.0, + prng_key=prng_key, + mode="train", + ) + second_groups, _ = env.sample( + inference_ctx=inference_ctx, + n_examples=2, + n_generations=1, + temperature=1.0, + prng_key=prng_key, + mode="train", + ) + + first_ids = [group.rollouts[0].env_example_id for group in first_groups] + second_ids = [group.rollouts[0].env_example_id for group in second_groups] + assert first_ids == second_ids + + +def test_reasoning_gym_env_requires_explicit_seed_and_size(monkeypatch): + install_fake_reasoning_gym(monkeypatch) + + with pytest.raises(ValueError, match="train_dataset_args must include explicit size"): + ReasoningGymEnv( + dataset_name="leg_counting", + train_dataset_args={"seed": 0}, + 
eval_dataset_args={"seed": 1, "size": 1}, + ) + + +def test_reasoning_gym_env_raises_informative_import_error(monkeypatch): + module = importlib.import_module("marin.rl.environments.reasoning_gym_env") + real_import_module = module.importlib.import_module + + def fake_import_module(name: str): + if name == "reasoning_gym": + error = ModuleNotFoundError("No module named 'reasoning_gym'") + error.name = "reasoning_gym" + raise error + return real_import_module(name) + + monkeypatch.setattr(module.importlib, "import_module", fake_import_module) + + with pytest.raises(ImportError, match="uv sync --extra reasoning-gym"): + ReasoningGymEnv( + dataset_name="leg_counting", + train_dataset_args={"seed": 0, "size": 1}, + eval_dataset_args={"seed": 1, "size": 1}, + ) diff --git a/tests/rl/test_rollout_worker.py b/tests/rl/test_rollout_worker.py index 636c61d806..8bf72b3211 100644 --- a/tests/rl/test_rollout_worker.py +++ b/tests/rl/test_rollout_worker.py @@ -5,6 +5,7 @@ from types import SimpleNamespace import fsspec +import numpy as np import pytest from marin.rl.environments.inference_ctx.staging import stage_vllm_metadata_locally @@ -15,10 +16,44 @@ RolloutTrackerConfig, RolloutTransferCounterSnapshot, RolloutWorker, + _compute_batch_stats, _should_run_curriculum_eval, _should_run_micro_eval, create_inference_context, ) +from marin.rl.types import Rollout, RolloutBatch, RolloutGroup, RolloutMetadata + + +def test_compute_batch_stats_uses_correctness_reward_for_pass_metrics(): + batch = RolloutBatch( + groups=[ + RolloutGroup( + rollouts=[ + Rollout( + env_name="reasoning_gym:leg_counting", + env_example_id="reasoning_gym:leg_counting:train:0:0", + prompt_tokens=np.array([1, 2], dtype=np.int32), + response_tokens=np.array([3, 4], dtype=np.int32), + response_logprobs=np.array([-1.0, -1.0], dtype=np.float32), + token_rewards=np.array([0.2, 0.2], dtype=np.float32), + episode_reward=0.2, + correctness_reward=0.0, + temperature=1.0, + top_k=None, + is_truncated=False, + ) + ] + ) + ], + metadata=RolloutMetadata(worker_id="worker", timestamp=0.0, weight_step=0), + ) + + stats = _compute_batch_stats(batch, "rg_lesson") + + assert stats.avg_reward == pytest.approx(0.2) + assert stats.pass_at_one == pytest.approx(0.0) + assert stats.pass_at_k == pytest.approx(0.0) + assert stats.avg_at_k == pytest.approx(0.0) def test_rollout_tracker_uses_explicit_name_when_provided(monkeypatch): diff --git a/uv.lock b/uv.lock index b64205fb77..247a2df58f 100644 --- a/uv.lock +++ b/uv.lock @@ -500,6 +500,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/96/32cb1e91e37f52967033101236190c44de3714666663e3c7e265454e1cc8/aqtp-0.9.0-py3-none-any.whl", hash = "sha256:5efdccd24657e149d3ac17599bb3d4eb88297fdc6aab14eaba0c70625057a257", size = 901643, upload-time = "2025-08-01T17:54:57.762Z" }, ] +[[package]] +name = "arckit" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "drawsvg" }, + { name = "numpy" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/1a/7f33e162a8a8f7974ebe73102cb0b7401a87d0b88cefacdb2e6670bd4b7e/arckit-0.1.0.tar.gz", hash = "sha256:06364596df7817ac63f70a1cfab8ed2ef48bea75f72f83377f4800d885c5d506", size = 708276, upload-time = "2024-06-30T18:52:49.964Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/9d/47866ae3ae517df13919a69f34a6c86e7ddf52ecc219f08e2cf8301ec5d5/arckit-0.1.0-py3-none-any.whl", hash = "sha256:5a22619309d2c29b04796eadf2a58792e563d6b811ee021e5d8ca97880ebbe80", size = 730262, 
upload-time = "2024-06-30T18:52:48.052Z" }, +] + [[package]] name = "argon2-cffi" version = "25.1.0" @@ -673,6 +687,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/96/2e2e02a4142448d278f855de3c48ec75317f6b50e715bf94c643948b5361/bespokelabs_curator-0.1.9.post1-py3-none-any.whl", hash = "sha256:c9c8099f2301c3d45062f61b42c821f23bbdb741e82a792f2c4b3272adb1afcd", size = 1205745, upload-time = "2024-11-19T20:20:19.955Z" }, ] +[[package]] +name = "bfi" +version = "1.0.4" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/47/6a303e4e0fadf5b8604b229b84f786a3f02bf8825468ca29c1e8a2201cea/bfi-1.0.4-py3-none-any.whl", hash = "sha256:d58a10b1f3ba2c7cfd9a9ebba3756455850a0cf9e63223c67e88438676c28579", size = 159200, upload-time = "2022-06-05T23:42:39.271Z" }, +] + [[package]] name = "black" version = "25.9.0" @@ -879,6 +901,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d6/4f/101071f880b4da05771128c0b89f41e334cff044dee05fb013c8f4be661c/cbor2-5.8.0-py3-none-any.whl", hash = "sha256:3727d80f539567b03a7aa11890e57798c67092c38df9e6c23abb059e0f65069c", size = 24374, upload-time = "2025-12-30T18:44:21.476Z" }, ] +[[package]] +name = "cellpylib" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/1e/c606ac9bf0277524cd935d6c2bb5b387c3bf401fa8ddbb59b622d6dd50af/cellpylib-2.4.0.tar.gz", hash = "sha256:e23ff7b5f647c23958748727c899b4e57b667c9f18cbc66d7577dbfaab3729ef", size = 38296, upload-time = "2023-02-15T03:48:53.06Z" } + [[package]] name = "certifi" version = "2026.1.4" @@ -1622,6 +1654,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/52/3144a4a8341f125d357a8f6e07aeb23b0bde7133dabbd0c78b984c079a1f/draccus-0.11.5-py3-none-any.whl", hash = "sha256:67da70bfbdea8fe67667322c64068d6d8cd832694a418b6c71d9d3b9f5c3c005", size = 78447, upload-time = "2025-04-22T20:41:47.414Z" }, ] +[[package]] +name = "drawsvg" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/07/a2c3db84e6af6fa761de905b39109fe24eff2c8d52653c1bff968b6b965d/drawsvg-2.4.1-py3-none-any.whl", hash = "sha256:241ff024968e03542bc8685b41a285427303c17f81eae1933229d26bb65b7fda", size = 44067, upload-time = "2026-01-04T00:06:09.817Z" }, +] + [[package]] name = "duckdb" version = "1.5.0" @@ -4591,6 +4631,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/4a/63a9540e3ca73709f4200564a737d63a4c8c9c4dd032bab8535f507c190a/lxml_html_clean-0.4.3-py3-none-any.whl", hash = "sha256:63fd7b0b9c3a2e4176611c2ca5d61c4c07ffca2de76c14059a81a2825833731e", size = 14177, upload-time = "2025-10-02T20:49:23.749Z" }, ] +[[package]] +name = "magiccube" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/aa/008a6dc9884d855cf5f7e3e02466b0b44cb1e64fd4b5ce906d4ce03030f9/magiccube-0.3.0.tar.gz", hash = "sha256:4fa47df1e99c19ac74597279a17de5a652083d4d68f0d5e52cf09202317e5d04", size = 19407, upload-time = "2024-07-15T06:58:03.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/8d/354a7357e03fb9277401989b7866b8f05ae407319e7bc450b958f9187308/magiccube-0.3.0-py3-none-any.whl", hash = "sha256:a8da38cac53dff1667949c9a12485bebb9d62c01ae73db41eb7aba1bc1c8a00c", size = 16931, upload-time = 
"2024-07-15T06:58:02.631Z" }, +] + [[package]] name = "marin" version = "0.99" @@ -4750,6 +4802,9 @@ math = [ { name = "pylatexenc" }, { name = "sympy" }, ] +reasoning-gym = [ + { name = "reasoning-gym" }, +] metrics = [ { name = "google-cloud-logging" }, ] @@ -4809,6 +4864,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=22" }, { name = "pylatexenc", marker = "extra == 'math'" }, { name = "ray", specifier = "==2.53.0" }, + { name = "reasoning-gym", marker = "extra == 'reasoning-gym'", specifier = "==0.1.19" }, { name = "regex" }, { name = "requests" }, { name = "resiliparse", specifier = ">=0.17.2", index = "https://marin-community.github.io/chatnoir-resiliparse/simple" }, @@ -4835,7 +4891,7 @@ requires-dist = [ { name = "wandb" }, { name = "warcio" }, ] -provides-extras = ["gpu", "tpu", "cpu", "rl", "eval", "evalchemy", "vizier", "vllm", "harbor", "dedup", "math"] +provides-extras = ["gpu", "tpu", "cpu", "rl", "reasoning-gym", "eval", "evalchemy", "vizier", "vllm", "harbor", "dedup", "math"] [package.metadata.requires-dev] dev = [ @@ -7987,6 +8043,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/8a/37362fc2b949d5f733a8b0f2ff51ba423914cabefe69f1d1b6aab710f5fe/pybind11-3.0.1-py3-none-any.whl", hash = "sha256:aa8f0aa6e0a94d3b64adfc38f560f33f15e589be2175e103c0a33c6bce55ee89", size = 293611, upload-time = "2025-08-22T20:09:25.235Z" }, ] +[[package]] +name = "pycosat" +version = "0.6.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/81/cf8ebf77fc4f06f680ad3ee43d0d01826f6d6054828f1cf3b42d944b82a1/pycosat-0.6.6.tar.gz", hash = "sha256:a376cfae20b16fcfbef24bf3c047a8a294c35032bb051fa98842c12bbab6f0ff", size = 71623, upload-time = "2023-10-03T15:45:48.058Z" } + [[package]] name = "pycountry" version = "24.6.1" @@ -8133,6 +8195,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9b/4d/b9add7c84060d4c1906abe9a7e5359f2a60f7a9a4f67268b2766673427d8/pyee-13.0.0-py3-none-any.whl", hash = "sha256:48195a3cddb3b1515ce0695ed76036b5ccc2ef3a9f963ff9f77aec0139845498", size = 15730, upload-time = "2025-03-17T18:53:14.532Z" }, ] +[[package]] +name = "pyfiglet" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/f2/2649b2acace54f861eccd4ab163bfd914236fc93ddb1df02dad2a2552b14/pyfiglet-1.0.2.tar.gz", hash = "sha256:758788018ab8faaddc0984e1ea05ff330d3c64be663c513cc1f105f6a3066dab", size = 832345, upload-time = "2023-09-13T20:56:21.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/03/bef6fff907e212d67a0003f8ea4819307bba91b2856074a0763dd483ccc4/pyfiglet-1.0.2-py3-none-any.whl", hash = "sha256:889b351d79c99e50a3f619c8f8e6ffdb27fd8c939fc43ecbd7559bd57d5f93ea", size = 1085824, upload-time = "2023-09-13T20:56:18.707Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -8973,6 +9044,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7a/01/e093a0270f33fad4cf8aa92849abb8db98b8bd9ede8d71a987faea368b02/realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f", size = 22219, upload-time = "2026-01-14T04:53:36.827Z" }, ] +[[package]] +name = "reasoning-gym" +version = "0.1.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "arckit" }, + { name = "bfi" }, + { name = "cellpylib" }, + { name = "magiccube" }, + { name = "pycosat" }, + { name = "pyfiglet" }, + { name = "pytz" }, + { name = "pyyaml" }, + { name = 
"sympy" }, + { name = "tabulate" }, + { name = "zss" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2d/38/fada242ffdcb1a04bc252ab1fb2f03ddc443c67de9fd3f03094a3cb447f2/reasoning_gym-0.1.19.tar.gz", hash = "sha256:74cfdbcb99fa56a0efeed5c5dd90417afbfaba5242de2115651f530206a2b84f", size = 6775247, upload-time = "2025-06-04T06:16:43.345Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/c0/82a283b8ccf2c8bcfae41936bb7e52611330c29759380aa5ed6db26a0a88/reasoning_gym-0.1.19-py3-none-any.whl", hash = "sha256:233f3b960b097b339e61e67cf64a5ac65ad2112b6b90c8f424c26ff97b80b1e1", size = 6969786, upload-time = "2025-06-04T06:16:33.391Z" }, +] + [[package]] name = "referencing" version = "0.37.0" @@ -12394,6 +12487,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] +[[package]] +name = "zss" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1e/d1/ed34d12f55d07cc1efb61d74fb2f64f46a705557f5bdd1ef1b810f0e2ec5/zss-1.2.0.tar.gz", hash = "sha256:07bb937441929ccb82961f4f7b80fbce9e2b20d0e46ddcbcbc1fcb094f585b50", size = 9790, upload-time = "2018-03-12T15:02:20.208Z" } + [[package]] name = "zstandard" version = "0.25.0"