Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Mapping, Optional, Tuple, Union

import numpy as np
import ray
Expand Down Expand Up @@ -341,6 +341,50 @@ def get(cls, name: str) -> Callable:
raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {available}")
return cls._functions[name]

@classmethod
def get_local(cls, name: Union[str, StrEnum]) -> Callable:
    """Look up a registered function in this process only, never touching Ray.

    Args:
        name: Registry key, either a plain string or a StrEnum member
            (the member's ``value`` is used as the key).

    Returns:
        The callable registered locally under ``name``.

    Raises:
        ValueError: If ``name`` is not present in the local registry.
    """
    key = name.value if isinstance(name, StrEnum) else name
    try:
        return cls._functions[key]
    except KeyError:
        available = list(cls._functions.keys())
        raise ValueError(
            f"Unknown local {cls._function_type.lower()} '{key}'. Available locally: {available}"
        ) from None

@classmethod
def snapshot_serialized(
    cls,
    names: Optional[Iterable[Union[str, StrEnum]]] = None,
    *,
    sync_with_actor: bool = True,
) -> dict[str, bytes]:
    """Create a serialized registry snapshot for worker-local installation.

    Intended for driver/control paths: the snapshot may be produced after a
    sync with the named Ray actor, but installing it in a worker requires no
    Ray actor calls at all.

    Args:
        names: Keys to include; ``None`` snapshots every locally registered
            function. StrEnum members are normalized to their ``value``.
        sync_with_actor: When True and Ray is initialized, refresh the local
            registry from the Ray actor before serializing.

    Returns:
        Mapping of registry key to cloudpickle-serialized callable.
    """
    if sync_with_actor and ray.is_initialized():
        cls.sync_with_actor()

    if names is None:
        keys = list(cls._functions.keys())
    else:
        keys = [n.value if isinstance(n, StrEnum) else n for n in names]

    snapshot: dict[str, bytes] = {}
    for key in keys:
        snapshot[key] = cloudpickle.dumps(cls.get_local(key))
    return snapshot

@classmethod
def install_serialized(cls, snapshot: Mapping[str, bytes]) -> None:
    """Install a serialized registry snapshot locally without syncing with Ray.

    Args:
        snapshot: Mapping of registry key to cloudpickle-serialized callable,
            as produced by :meth:`snapshot_serialized`.

    Raises:
        Exception: Re-raises any deserialization failure after logging it.
    """
    for name, func_serialized in snapshot.items():
        try:
            cls._functions[name] = cloudpickle.loads(func_serialized)
        except Exception as e:
            # Lazy %-args avoid building the message unless it is emitted.
            logger.error(
                "Error deserializing %s into local %s registry: %s",
                name,
                cls._function_type,
                e,
            )
            # Bare `raise` preserves the original traceback; `raise e` would
            # re-anchor it at this line.
            raise

@classmethod
def list_available(cls) -> List[str]:
"""List all registered functions."""
Expand Down Expand Up @@ -471,26 +515,34 @@ class PolicyLossRegistry(BaseFunctionRegistry):
_actor_name = "policy_loss_registry"
_function_type = "policy loss"

@classmethod
def _default_policy_losses(cls) -> dict[str, tuple[PolicyLossType, Callable]]:
    """Return the built-in policy losses keyed by registry name.

    Each value pairs the :class:`PolicyLossType` member with its
    implementation. Insertion order is preserved so callers that iterate
    the mapping register defaults deterministically.
    """
    defaults: dict[str, tuple[PolicyLossType, Callable]] = {}
    defaults["regular"] = (PolicyLossType.REGULAR, ppo_policy_loss)
    defaults["dual_clip"] = (PolicyLossType.DUAL_CLIP, ppo_policy_loss)
    defaults["gspo"] = (PolicyLossType.GSPO, gspo_policy_loss)
    defaults["clip_cov"] = (PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov)
    defaults["kl_cov"] = (PolicyLossType.KL_COV, compute_policy_loss_kl_cov)
    defaults["sapo"] = (PolicyLossType.SAPO, sapo_policy_loss)
    defaults["cross_entropy"] = (PolicyLossType.CROSS_ENTROPY, cross_entropy_loss)
    defaults["importance_sampling"] = (PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss)
    defaults["rollout_is"] = (PolicyLossType.ROLLOUT_IS, rollout_is_policy_loss)
    return defaults

@classmethod
def repopulate_registry(cls):
    """Repopulate the registry with default policy loss functions.

    Registers each built-in loss from :meth:`_default_policy_losses` unless a
    function is already registered under that name, so user overrides
    installed earlier are never clobbered.
    """
    pl_avail = set(cls.list_available())
    for pl_name, (pl_type, pl_func) in cls._default_policy_losses().items():
        if pl_name not in pl_avail:
            cls.register(pl_type, pl_func)

@classmethod
def repopulate_local_registry(cls):
    """Install default policy losses locally without syncing with Ray.

    Existing entries win: a default is only added when no function is
    already registered locally under the same name.
    """
    for name, (_, func) in cls._default_policy_losses().items():
        if name not in cls._functions:
            cls._functions[name] = func


def register_advantage_estimator(name: Union[str, AdvantageEstimator]):
"""Decorator to register an advantage estimator function."""
Expand Down Expand Up @@ -529,6 +581,12 @@ def sync_registries():
logger.info("Synced registries to ray actor")


def snapshot_policy_loss_registry_for_workers() -> dict[str, bytes]:
    """Snapshot policy losses so workers can look them up without actor sync.

    Ensures the built-in defaults are present in the local registry first,
    then serializes everything (syncing with the Ray actor when available) so
    the snapshot can be shipped to workers at construction time.
    """
    PolicyLossRegistry.repopulate_local_registry()
    snapshot = PolicyLossRegistry.snapshot_serialized(sync_with_actor=True)
    return snapshot


@register_policy_loss(PolicyLossType.REGULAR)
@register_policy_loss(PolicyLossType.DUAL_CLIP)
def ppo_policy_loss(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
from_parallel_logits_to_logprobs,
vocab_parallel_entropy,
)
from skyrl.backends.skyrl_train.utils.ppo_utils import (
PolicyLossRegistry,
compute_approx_kl,
)
from skyrl.backends.skyrl_train.utils.ppo_utils import compute_approx_kl
from skyrl.backends.skyrl_train.utils.replay_utils import (
setup_per_microbatch_replay_backward,
setup_per_microbatch_replay_forward,
Expand All @@ -40,11 +37,13 @@ def __init__(
actor_module: List[nn.Module],
actor_optimizer: Optional[torch.optim.Optimizer] = None,
policy_loss_fn: Optional[Callable] = None,
policy_loss_fn_resolver: Optional[Callable[[str], Callable]] = None,
):
self.cfg = config
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.policy_loss_fn = policy_loss_fn
self.policy_loss_fn_resolver = policy_loss_fn_resolver
self.use_sample_packing = self.cfg.use_sample_packing

config = get_model_config(self.actor_module[0])
Expand Down Expand Up @@ -225,7 +224,9 @@ def forward_backward_mini_batch(
# Resolve loss function
resolved_loss_name = loss_fn if loss_fn is not None else self.cfg.algorithm.policy_loss_type
if loss_fn is not None:
current_loss_fn = PolicyLossRegistry.get(loss_fn)
if self.policy_loss_fn_resolver is None:
raise ValueError("MegatronModelWrapper requires policy_loss_fn_resolver for loss overrides")
current_loss_fn = self.policy_loss_fn_resolver(loss_fn)
else:
current_loss_fn = self.policy_loss_fn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def init_model(self, model_path, num_training_steps: int = 1e9):
actor_module=self.actor_module,
actor_optimizer=self.optimizer,
policy_loss_fn=self.policy_loss_fn,
policy_loss_fn_resolver=self._get_policy_loss_fn,
)

self.empty_cuda_cache = self.cfg.policy.megatron_config.empty_cuda_cache
Expand Down
80 changes: 56 additions & 24 deletions skyrl/backends/skyrl_train/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from collections import defaultdict
from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Type,
Union,
)

import ray
import torch
Expand Down Expand Up @@ -486,6 +496,7 @@ def __init__(
colocate_all: bool = False,
sequence_parallel_size: int = 1,
record_memory: bool = False,
worker_init_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Args:
Expand All @@ -504,6 +515,7 @@ def __init__(
self.colocate_all = colocate_all
self.sequence_parallel_size = sequence_parallel_size
self.record_memory = record_memory
self.worker_init_kwargs = dict(worker_init_kwargs or {})
self._initiate_actors(pg, num_gpus_per_actor)

def _initiate_actors(self, pg: Optional[ResolvedPlacementGroup], num_gpus_per_actor: float):
Expand Down Expand Up @@ -572,16 +584,18 @@ def _scheduling_strategy_for_rank(rank):
if sched is not None:
actor_options["scheduling_strategy"] = sched

master_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=0,
local_rank=0,
master_addr=None,
master_port=None,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
master_actor_kwargs = {
**self.worker_init_kwargs,
"cfg": self.cfg,
"world_size": world_size,
"rank": 0,
"local_rank": 0,
"master_addr": None,
"master_port": None,
"sequence_parallel_size": self.sequence_parallel_size,
"record_memory": self.record_memory,
}
master_actor = self.ray_actor_type.options(**actor_options).remote(**master_actor_kwargs)
self._actor_handlers = [master_actor]

if world_size > 1:
Expand All @@ -598,16 +612,18 @@ def _scheduling_strategy_for_rank(rank):
if sched is not None:
actor_options["scheduling_strategy"] = sched

worker_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=rank,
local_rank=local_rank,
master_addr=master_addr,
master_port=master_port,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
worker_actor_kwargs = {
**self.worker_init_kwargs,
"cfg": self.cfg,
"world_size": world_size,
"rank": rank,
"local_rank": local_rank,
"master_addr": master_addr,
"master_port": master_port,
"sequence_parallel_size": self.sequence_parallel_size,
"record_memory": self.record_memory,
}
worker_actor = self.ray_actor_type.options(**actor_options).remote(**worker_actor_kwargs)
self._actor_handlers.append(worker_actor)

# Initialize process group
Expand Down Expand Up @@ -682,15 +698,31 @@ def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kw


class PolicyWorkerBase(Worker):
def __init__(self, policy_loss_registry_snapshot: Optional[Mapping[str, bytes]] = None, **kwargs):
    """Initialize the policy worker and its local policy-loss registry.

    Args:
        policy_loss_registry_snapshot: Serialized registry snapshot produced
            on the driver. When provided it is installed locally; otherwise
            only the built-in default losses are made available.
        **kwargs: Forwarded unchanged to ``Worker.__init__``.
    """
    super().__init__(**kwargs)
    # Workers resolve losses from the local registry only — install the
    # driver's snapshot, or fall back to the built-in defaults.
    if policy_loss_registry_snapshot is None:
        PolicyLossRegistry.repopulate_local_registry()
    else:
        PolicyLossRegistry.install_serialized(policy_loss_registry_snapshot)
    # Populated later by init_model(); None until then.
    self.model: Optional[nn.Module] = None
    self.scheduler: Optional[LRScheduler] = None
    self.optimizer: Optional[Optimizer] = None
    self.strategy: Optional[DistributedStrategy] = None
    self.record_memory: bool = False
    self.mesh_rank: Optional[MeshRank] = None
    # Cache of resolved loss callables keyed by loss name.
    self._policy_loss_fn_cache: Dict[str, Callable] = {}
    self.policy_loss_fn: Callable = self._get_policy_loss_fn(self.cfg.algorithm.policy_loss_type)

def _get_policy_loss_fn(self, loss_name: str) -> Callable:
    """Resolve a policy loss by name from the local registry, with caching.

    Args:
        loss_name: Registry key of the policy loss.

    Returns:
        The resolved loss callable (cached after first lookup).

    Raises:
        ValueError: If the loss is unknown locally; the message explains that
            custom losses must be registered before worker creation so they
            land in the snapshot.
    """
    cached = self._policy_loss_fn_cache.get(loss_name)
    if cached is not None:
        return cached
    # Keep the try body minimal: only the registry lookup can raise ValueError.
    try:
        loss_fn = PolicyLossRegistry.get_local(loss_name)
    except ValueError as e:
        raise ValueError(
            f"{e}. If this is a custom policy loss, register it before initialize_ray()/worker creation "
            "so it is included in the policy-loss snapshot."
        ) from e
    self._policy_loss_fn_cache[loss_name] = loss_fn
    return loss_fn

def forward_backward(
self,
Expand Down Expand Up @@ -789,7 +821,7 @@ def _forward_backward_micro(
resolved_loss_name = loss_fn if loss_fn is not None else self.cfg.algorithm.policy_loss_type
if loss_fn is not None:
# Use the provided loss function (Tinker API style)
current_loss_fn = PolicyLossRegistry.get(loss_fn)
current_loss_fn = self._get_policy_loss_fn(loss_fn)
else:
# Fall back to config default
current_loss_fn = self.policy_loss_fn
Expand Down
6 changes: 6 additions & 0 deletions skyrl/backends/skyrl_train_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
TrainingInputBatch,
pad_training_input_batch,
)
from skyrl.backends.skyrl_train.utils.ppo_utils import (
snapshot_policy_loss_registry_for_workers,
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
Expand Down Expand Up @@ -239,6 +242,9 @@ def _build_policy(self, PolicyWorker, model_id: str):
colocate_all=colocate_all,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
record_memory=cfg.trainer.policy.record_memory,
worker_init_kwargs={
"policy_loss_registry_snapshot": snapshot_policy_loss_registry_for_workers(),
},
)

# set to a large number for megatron scheduler init
Expand Down
6 changes: 6 additions & 0 deletions skyrl/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@

from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.utils.io import io
from skyrl.backends.skyrl_train.utils.ppo_utils import (
snapshot_policy_loss_registry_for_workers,
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S
Expand Down Expand Up @@ -388,6 +391,9 @@ def _init_workers(self):
colocate_all=False,
sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
record_memory=self.cfg.trainer.policy.record_memory,
worker_init_kwargs={
"policy_loss_registry_snapshot": snapshot_policy_loss_registry_for_workers(),
},
)
num_training_steps = (
self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
Expand Down
5 changes: 5 additions & 0 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
pg = None

use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
policy_worker_init_kwargs = {
"policy_loss_registry_snapshot": ppo_utils.snapshot_policy_loss_registry_for_workers(),
}

if cfg.trainer.placement.colocate_all:
num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
Expand Down Expand Up @@ -422,6 +425,7 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
colocate_all=True,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
record_memory=cfg.trainer.policy.record_memory,
worker_init_kwargs=policy_worker_init_kwargs,
)
if use_ref_model:
assert (
Expand Down Expand Up @@ -484,6 +488,7 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
num_gpus_per_actor=0.75 if pg else 1,
colocate_all=False,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
worker_init_kwargs=policy_worker_init_kwargs,
)
if use_ref_model:
ref_model = PPORayActorGroup(
Expand Down
Loading
Loading