Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/marin/src/marin/rl/rl_experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _build_rl_job_config(
tags = [*config.tags, config.model_config.name.split("/")[-1]]
checkpoints_path = join_path(output_path, "checkpoints")
rollout_storage_path = join_path(output_path, "rollouts")
metadata_path = join_path(output_path, "metadata")

trainer_config = TrainerConfig(
tracker=WandbConfig(
Expand Down Expand Up @@ -328,6 +329,12 @@ def _build_rl_job_config(
name=f"{name}-rollout",
tags=[*config.tags, "rollout", config.model_config.name.split("/")[-1]],
),
eval_tracker=RolloutTrackerConfig(
project=config.project_name,
name=f"{name}-eval",
tags=[*config.tags, "eval", config.model_config.name.split("/")[-1]],
),
metadata_path=metadata_path,
pip_dependency_groups=(
config.model_config.pip_dependency_groups if config.model_config.pip_dependency_groups else ["vllm", "math"]
),
Expand Down
17 changes: 17 additions & 0 deletions lib/marin/src/marin/rl/rl_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ class RLJobConfig:
rollout_tracker: RolloutTrackerConfig | None = None
"""Tracker configuration for rollout workers. Uses a standalone tracker to avoid JAX deadlocks."""

eval_tracker: RolloutTrackerConfig | None = None
"""Dedicated tracker configuration for eval and micro-eval streams."""

eval_owner_worker_index: int = 0
"""Worker index that owns the eval tracker stream."""

metadata_path: str | None = None
"""Base path for durable RL telemetry artifacts and event shards."""

pip_dependency_groups: list[str] = field(default_factory=list)
"""Extra pip dependency groups to include for all workers."""

Expand Down Expand Up @@ -292,8 +301,11 @@ def to_worker_configs(self) -> tuple[TrainWorkerConfig, RolloutWorkerConfig]:
initial_checkpoint=self.config.initial_checkpoint,
vocab_size=self.config.vocab_size,
run_id=self.config.run_id,
root_run_id=self.config.run_id,
curriculum_config=self.config.curriculum,
seed=self.config.seed,
metadata_path=self.config.metadata_path,
instance_id=self.config.resolved_instance_id,
)

# Create rollout worker config
Expand All @@ -309,12 +321,17 @@ def to_worker_configs(self) -> tuple[TrainWorkerConfig, RolloutWorkerConfig]:
weight_transfer=weight_transfer_config,
rollout_storage=self.config.rollout_storage,
run_id=self.config.run_id,
root_run_id=self.config.run_id,
seed=self.config.seed + 1000,
inference_type=self.config.inference_type,
inference_config=inference_config,
system_prompt=self.config.system_prompt,
inflight_weight_updates=self.config.inflight_weight_updates,
tracker_config=self.config.rollout_tracker,
eval_tracker_config=self.config.eval_tracker,
eval_owner_worker_index=self.config.eval_owner_worker_index,
metadata_path=self.config.metadata_path,
instance_id=self.config.resolved_instance_id,
)

return train_worker_config, rollout_worker_config
160 changes: 157 additions & 3 deletions lib/marin/src/marin/rl/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Literal

import equinox as eqx
import haliax as hax
Expand All @@ -31,10 +31,9 @@
from jax.experimental import multihost_utils
from levanter.inference.openai import InferenceServer
from levanter.models.lm_model import LmConfig
from levanter.tokenizers import MarinTokenizer
from levanter.trainer import TrainerConfig
from levanter.utils.jax_utils import barrier_sync
from levanter.tokenizers import MarinTokenizer
from typing import Literal

from levanter.utils.mesh import MeshConfig
from marin.rl.curriculum import CurriculumConfig
Expand All @@ -55,6 +54,7 @@
)
from marin.rl.metrics import pass_at_k_estimator
from marin.rl.model_utils import load_model_from_checkpoint
from marin.rl.telemetry import EventShardWriter, StepProvenance, TelemetryEvent, TrackerRunRef, TrackerStream

from .rollout_storage import RolloutStorageConfig, RolloutWriter
from .types import (
Expand Down Expand Up @@ -148,6 +148,35 @@ class _NoOpTracker:
def log(self, metrics, step=None):
    """Accept a metrics payload and discard it; nothing is recorded."""
    return None

def log_summary(self, metrics):
    """Accept summary metrics and discard them; nothing is recorded."""
    return None

def log_artifact(self, artifact_path, *, name=None, artifact_type=None):
    """Accept an artifact description and discard it; nothing is uploaded."""
    return None

@property
def run_id(self) -> str | None:
return None

@property
def run_name(self) -> str | None:
return None

@property
def run_url(self) -> str | None:
return None

@property
def entity(self) -> str | None:
return None

@property
def project(self) -> str | None:
return None

def as_tracker_ref(self, *, stream: TrackerStream, worker_index: int | None = None) -> TrackerRunRef | None:
    """Return ``None`` regardless of arguments; this tracker has no run to reference.

    Both parameters are keyword-only and are ignored.
    """
    return None

def finish(self):
    """Do nothing; there is no run to close."""
    return None

Expand Down Expand Up @@ -184,6 +213,7 @@ class RolloutTracker:
"""

def __init__(self, config: RolloutTrackerConfig, run_id: str):
self._config = config
run_name = config.name or run_id
self._run = wandb.init(
entity=config.entity,
Expand All @@ -201,6 +231,46 @@ def log(self, metrics: Mapping[str, Any], *, step: int | None = None):
return
self._run.log(metrics, step=step)

def log_summary(self, metrics: Mapping[str, Any]) -> None:
    """Merge ``metrics`` into the underlying run's ``summary`` via its ``update`` method."""
    summary = self._run.summary
    summary.update(metrics)

def log_artifact(self, artifact_path, *, name: str | None = None, artifact_type: str | None = None):
self._run.log_artifact(artifact_path, name=name, type=artifact_type)

@property
def run_id(self) -> str | None:
return getattr(self._run, "id", None)

@property
def run_name(self) -> str | None:
return getattr(self._run, "name", None)

@property
def run_url(self) -> str | None:
return getattr(self._run, "url", None)

@property
def entity(self) -> str | None:
return getattr(self._run, "entity", None) or self._config.entity

@property
def project(self) -> str | None:
return getattr(self._run, "project", None) or self._config.project

def as_tracker_ref(self, *, stream: TrackerStream, worker_index: int | None = None) -> TrackerRunRef:
    """Build a ``TrackerRunRef`` describing this tracker's run.

    Args:
        stream: Stream value stored on the reference.
        worker_index: Optional worker index stored on the reference.

    Returns:
        A ``TrackerRunRef`` populated from this tracker's ``run_id``,
        ``project``, ``entity``, ``run_name`` and ``run_url`` properties.

    Raises:
        RuntimeError: If ``run_id`` is ``None``.
    """
    if (resolved_run_id := self.run_id) is None:
        raise RuntimeError("Tracker run ID is unavailable")

    # Read the descriptive properties in the same order the reference lists them.
    project, entity = self.project, self.entity
    run_name, run_url = self.run_name, self.run_url
    return TrackerRunRef(
        stream=stream,
        tracker_run_id=resolved_run_id,
        project=project,
        entity=entity,
        run_name=run_name,
        run_url=run_url,
        worker_index=worker_index,
    )

def finish(self):
    """Call ``finish()`` on the underlying run."""
    run = self._run
    run.finish()

Expand All @@ -222,9 +292,15 @@ class RolloutWorkerConfig:
inference_config: LevanterInferenceContextConfig | vLLMInferenceContextConfig
"""Configuration for inference context."""

root_run_id: str | None = None
"""Stable RL job run id used for shared telemetry and artifact paths."""

tracker_config: RolloutTrackerConfig | None = None
"""Configuration for the rollout worker's tracker. If None, tracking is disabled."""

eval_tracker_config: RolloutTrackerConfig | None = None
"""Configuration for the dedicated eval tracker."""

seed: int = 0
"""Random seed to use for sampling."""
max_rollouts: int | None = None
Expand All @@ -248,6 +324,15 @@ class RolloutWorkerConfig:
worker_index: int = 0
"""Index of this worker among all rollout workers."""

eval_owner_worker_index: int = 0
"""Worker index responsible for owning the eval tracker stream."""

metadata_path: str | None = None
"""Base metadata path for local RL telemetry artifacts."""

instance_id: str | None = None
"""Coordinator invocation identifier used for per-attempt artifact naming."""


def find_open_port() -> int:
"""Find an open port on localhost."""
Expand Down Expand Up @@ -415,6 +500,8 @@ class RolloutWorker:
_tokenizer: MarinTokenizer
_environments: dict[str, MarinEnv]
tracker: Any # levanter.Tracker or RolloutTracker
_event_writer: EventShardWriter | None
_eval_event_writer: EventShardWriter | None

def __init__(self, config: RolloutWorkerConfig, runtime: RLRuntimeHandles):
config.trainer.id = f"{config.run_id}-rollout"
Expand Down Expand Up @@ -447,12 +534,15 @@ def __init__(self, config: RolloutWorkerConfig, runtime: RLRuntimeHandles):
self._current_train_step: int = -1
self._last_transfer_counters = RolloutTransferCounterSnapshot()
self._last_eval_train_step: int | None = None
self._event_writer = None
self._eval_event_writer = None

self._tokenizer = config.tokenizer

# Event to signal that the first weight transfer has completed.
# For inflight weight updates, we block inference until initial weights are received.
self._first_weights_received = threading.Event()
self._initialize_telemetry()

logger.info("Starting rollout policy context with weight transfer config %s", self.config.weight_transfer)

Expand Down Expand Up @@ -489,6 +579,70 @@ def __init__(self, config: RolloutWorkerConfig, runtime: RLRuntimeHandles):
)
self.weight_transfer_thread.start()

def _initialize_telemetry(self) -> None:
    """Create telemetry event writers and register this worker's references.

    Returns immediately, leaving both writers untouched, unless the config
    provides both ``metadata_path`` and ``instance_id``. Otherwise:

    1. Opens the rollout-stream ``EventShardWriter``, appends a
       ``worker_started`` event and registers the writer's artifact ref.
    2. If the tracker exposes ``as_tracker_ref`` and the run state exposes
       ``register_tracker_ref``, registers the rollout tracker ref (skipped
       when the tracker returns ``None``).
    3. Only when ``worker_index`` equals ``eval_owner_worker_index``: opens
       the eval-stream writer, appends a ``stream_initialized`` event and
       registers that writer's artifact ref.
    """
    cfg = self.config
    metadata_path = cfg.metadata_path
    instance_id = cfg.instance_id
    if metadata_path is None or instance_id is None:
        return

    # Prefer the job-level run id; fall back to this worker's own run id.
    telemetry_run_id = cfg.root_run_id or cfg.run_id

    rollout_writer = EventShardWriter(
        metadata_path=metadata_path,
        run_id=telemetry_run_id,
        stream=TrackerStream.ROLLOUT,
        instance_id=instance_id,
        worker_index=self.config.worker_index,
    )
    self._event_writer = rollout_writer
    rollout_writer.append(
        TelemetryEvent(
            run_id=telemetry_run_id,
            stream=TrackerStream.ROLLOUT,
            event_type="worker_started",
            provenance=StepProvenance(
                instance_id=instance_id,
                worker_index=self.config.worker_index,
            ),
            payload={"worker_role": "rollout"},
        )
    )
    self._register_artifact_ref(rollout_writer.artifact_ref())

    # Both sides are looked up defensively: either may lack the hook.
    make_tracker_ref = getattr(self.tracker, "as_tracker_ref", None)
    register_tracker = getattr(self._runtime.run_state, "register_tracker_ref", None)
    if make_tracker_ref is not None and register_tracker is not None:
        rollout_ref = make_tracker_ref(stream=TrackerStream.ROLLOUT, worker_index=self.config.worker_index)
        if rollout_ref is not None:
            register_tracker.remote(rollout_ref).result()

    # Everything below applies only to the worker that owns the eval stream.
    if self.config.worker_index != self.config.eval_owner_worker_index:
        return

    eval_writer = EventShardWriter(
        metadata_path=metadata_path,
        run_id=telemetry_run_id,
        stream=TrackerStream.EVAL,
        instance_id=instance_id,
    )
    self._eval_event_writer = eval_writer
    eval_writer.append(
        TelemetryEvent(
            run_id=telemetry_run_id,
            stream=TrackerStream.EVAL,
            event_type="stream_initialized",
            provenance=StepProvenance(
                instance_id=instance_id,
                worker_index=self.config.worker_index,
            ),
            payload={"worker_role": "eval_owner"},
        )
    )
    self._register_artifact_ref(eval_writer.artifact_ref())

def _register_artifact_ref(self, artifact_ref) -> None:
    """Register ``artifact_ref`` with the run state, if it supports that.

    Looks up ``register_artifact_ref`` on the runtime's run state; when
    present, invokes it through ``.remote(...)`` and waits on ``.result()``.
    When the attribute is absent, nothing happens.
    """
    run_state = self._runtime.run_state
    register = getattr(run_state, "register_artifact_ref", None)
    if register is None:
        return
    register.remote(artifact_ref).result()

def _load_environment(self, lesson_id: str) -> MarinEnv:
"""Load environment from lesson ID."""
if lesson_id in self._environments:
Expand Down
31 changes: 31 additions & 0 deletions lib/marin/src/marin/rl/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from dataclasses import dataclass
from enum import StrEnum

from marin.rl.telemetry import ArtifactRef, TrackerRunRef, TrackerStream

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -51,7 +53,10 @@ def __init__(self):
self._status: RunStatus = RunStatus.RUNNING
self._failure_message: str | None = None
self._train_step: int = -1
self._next_eval_sequence: int = 0
self._rollout_transfer_counters: dict[int, RolloutTransferCounters] = {}
self._tracker_refs: dict[tuple[str, int | None], TrackerRunRef] = {}
self._artifact_refs: dict[str, ArtifactRef] = {}

def get_status(self) -> str:
    """Return the ``value`` of the current run status."""
    status = self._status
    return status.value
Expand All @@ -73,6 +78,11 @@ def update_train_step(self, step: int) -> None:
def get_train_step(self) -> int:
    """Return the stored train step."""
    current_step = self._train_step
    return current_step

def next_eval_sequence(self) -> int:
    """Return the current eval sequence number and advance the counter by one."""
    issued = self._next_eval_sequence
    self._next_eval_sequence = issued + 1
    return issued

def get_rollout_transfer_counters(self, worker_index: int) -> RolloutTransferCounters:
    """Return the transfer counters stored for ``worker_index``.

    When no entry exists for that worker, a freshly constructed
    ``RolloutTransferCounters()`` is returned; it is not stored.
    """
    fallback = RolloutTransferCounters()
    per_worker = self._rollout_transfer_counters
    return per_worker.get(worker_index, fallback)

Expand All @@ -96,6 +106,27 @@ def add_rollout_transfer_counters(
self._rollout_transfer_counters[worker_index] = updated
return updated

def register_tracker_ref(self, tracker_ref: TrackerRunRef) -> TrackerRunRef:
    """Store ``tracker_ref`` keyed by ``(stream.value, worker_index)`` and return it.

    A later registration under the same key replaces the earlier one.
    """
    stream_name = tracker_ref.stream.value
    self._tracker_refs[(stream_name, tracker_ref.worker_index)] = tracker_ref
    return tracker_ref

def get_tracker_ref(self, stream: TrackerStream, worker_index: int | None = None) -> TrackerRunRef | None:
    """Look up the tracker ref stored under ``(stream.value, worker_index)``.

    Returns ``None`` when nothing is stored under that key.
    """
    lookup_key = (stream.value, worker_index)
    return self._tracker_refs.get(lookup_key)

def list_tracker_refs(self) -> list[TrackerRunRef]:
    """Return a new list holding every stored tracker ref, in dict order."""
    return [*self._tracker_refs.values()]

def register_artifact_ref(self, artifact_ref: ArtifactRef) -> ArtifactRef:
    """Store ``artifact_ref`` under its ``name`` and return it.

    A later registration with the same name replaces the earlier one.
    """
    ref_name = artifact_ref.name
    self._artifact_refs[ref_name] = artifact_ref
    return artifact_ref

def get_artifact_ref(self, name: str) -> ArtifactRef | None:
    """Return the artifact ref stored under ``name``, or ``None`` if there is none."""
    stored = self._artifact_refs
    return stored.get(name)

def list_artifact_refs(self) -> list[ArtifactRef]:
    """Return a new list holding every stored artifact ref, in dict order."""
    return [*self._artifact_refs.values()]

def mark_completed(self) -> None:
if self._status == RunStatus.RUNNING:
self._status = RunStatus.COMPLETED
Expand Down
Loading