From 8ed4953213d84291e6f0347fadb003836d5b2f00 Mon Sep 17 00:00:00 2001
From: taivu1998 <46636857+taivu1998@users.noreply.github.com>
Date: Sun, 10 May 2026 03:49:22 -0700
Subject: [PATCH] Avoid Ray actor lookup for policy loss resolution

---
 skyrl/backends/skyrl_train/utils/ppo_utils.py | 86 ++++++++++++++++---
 .../megatron/megatron_model_wrapper.py        | 11 +--
 .../workers/megatron/megatron_worker.py       |  1 +
 skyrl/backends/skyrl_train/workers/worker.py  | 80 +++++++++++------
 skyrl/backends/skyrl_train_backend.py         |  6 ++
 skyrl/train/sft_trainer.py                    |  6 ++
 skyrl/train/trainer.py                        |  5 ++
 .../skyrl_train/utils/test_ppo_utils.py       | 60 +++++++++++++
 tests/train/test_trainer.py                   | 65 ++++++++++++++
 9 files changed, 277 insertions(+), 43 deletions(-)

diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py
index 8a173a13e0..a7fb7acdc3 100644
--- a/skyrl/backends/skyrl_train/utils/ppo_utils.py
+++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -19,7 +19,7 @@
 from collections import defaultdict
 from enum import StrEnum
 from functools import wraps
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterable, List, Mapping, Optional, Tuple, Union
 
 import numpy as np
 import ray
@@ -341,6 +341,57 @@ def get(cls, name: str) -> Callable:
             raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {available}")
         return cls._functions[name]
 
+    @classmethod
+    def get_local(cls, name: Union[str, StrEnum]) -> Callable:
+        """Get a function from the local process registry without syncing with Ray."""
+        if isinstance(name, StrEnum):
+            name = name.value
+
+        if name not in cls._functions:
+            available = list(cls._functions.keys())
+            raise ValueError(f"Unknown local {cls._function_type.lower()} '{name}'. Available locally: {available}")
+        return cls._functions[name]
+
+    @classmethod
+    def snapshot_serialized(
+        cls,
+        names: Optional[Iterable[Union[str, StrEnum]]] = None,
+        *,
+        sync_with_actor: bool = True,
+    ) -> dict[str, bytes]:
+        """Create a serialized registry snapshot for worker-local installation.
+
+        This is intended for driver/control paths. When requested, it may sync with
+        the named Ray actor before serializing, but the resulting snapshot can be
+        installed in workers without any Ray actor calls.
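+
+        A minimal usage sketch (illustrative; ``PolicyLossRegistry`` is one of
+        the registries built on this base class)::
+
+            snapshot = PolicyLossRegistry.snapshot_serialized(sync_with_actor=True)
+            # ship ``snapshot`` to each worker process, then install it there:
+            PolicyLossRegistry.install_serialized(snapshot)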
+ """ + if sync_with_actor and ray.is_initialized(): + cls.sync_with_actor() + + if names is None: + function_names = list(cls._functions.keys()) + else: + function_names = [name.value if isinstance(name, StrEnum) else name for name in names] + + return {name: cloudpickle.dumps(cls.get_local(name)) for name in function_names} + + @classmethod + def install_serialized(cls, snapshot: Mapping[str, bytes]) -> None: + """Install a serialized registry snapshot locally without syncing with Ray.""" + for name, func_serialized in snapshot.items(): + try: + cls._functions[name] = cloudpickle.loads(func_serialized) + except Exception as e: + logger.error(f"Error deserializing {name} into local {cls._function_type} registry: {e}") + raise e + @classmethod def list_available(cls) -> List[str]: """List all registered functions.""" @@ -471,26 +515,34 @@ class PolicyLossRegistry(BaseFunctionRegistry): _actor_name = "policy_loss_registry" _function_type = "policy loss" + @classmethod + def _default_policy_losses(cls) -> dict[str, tuple[PolicyLossType, Callable]]: + return { + "regular": (PolicyLossType.REGULAR, ppo_policy_loss), + "dual_clip": (PolicyLossType.DUAL_CLIP, ppo_policy_loss), + "gspo": (PolicyLossType.GSPO, gspo_policy_loss), + "clip_cov": (PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov), + "kl_cov": (PolicyLossType.KL_COV, compute_policy_loss_kl_cov), + "sapo": (PolicyLossType.SAPO, sapo_policy_loss), + "cross_entropy": (PolicyLossType.CROSS_ENTROPY, cross_entropy_loss), + "importance_sampling": (PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss), + "rollout_is": (PolicyLossType.ROLLOUT_IS, rollout_is_policy_loss), + } + @classmethod def repopulate_registry(cls): """Repopulate the registry with default policy loss functions.""" pl_avail = set(cls.list_available()) - pl_types = { - "regular": [PolicyLossType.REGULAR, ppo_policy_loss], - "dual_clip": [PolicyLossType.DUAL_CLIP, ppo_policy_loss], - "gspo": [PolicyLossType.GSPO, gspo_policy_loss], - "clip_cov": [PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov], - "kl_cov": [PolicyLossType.KL_COV, compute_policy_loss_kl_cov], - "sapo": [PolicyLossType.SAPO, sapo_policy_loss], - "cross_entropy": [PolicyLossType.CROSS_ENTROPY, cross_entropy_loss], - "importance_sampling": [PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss], - "rollout_is": [PolicyLossType.ROLLOUT_IS, rollout_is_policy_loss], - } - - for pl_name, (pl_type, pl_func) in pl_types.items(): + for pl_name, (pl_type, pl_func) in cls._default_policy_losses().items(): if pl_name not in pl_avail: cls.register(pl_type, pl_func) + @classmethod + def repopulate_local_registry(cls): + """Install default policy losses locally without syncing with Ray.""" + for pl_name, (_, pl_func) in cls._default_policy_losses().items(): + cls._functions.setdefault(pl_name, pl_func) + def register_advantage_estimator(name: Union[str, AdvantageEstimator]): """Decorator to register an advantage estimator function.""" @@ -529,6 +581,12 @@ def sync_registries(): logger.info("Synced registries to ray actor") +def snapshot_policy_loss_registry_for_workers() -> dict[str, bytes]: + """Snapshot policy losses so workers can look them up without actor sync.""" + PolicyLossRegistry.repopulate_local_registry() + return PolicyLossRegistry.snapshot_serialized(sync_with_actor=True) + + @register_policy_loss(PolicyLossType.REGULAR) @register_policy_loss(PolicyLossType.DUAL_CLIP) def ppo_policy_loss( diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py 
b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index f2d1aa0a44..ae8a26fa09 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -21,10 +21,7 @@ from_parallel_logits_to_logprobs, vocab_parallel_entropy, ) -from skyrl.backends.skyrl_train.utils.ppo_utils import ( - PolicyLossRegistry, - compute_approx_kl, -) +from skyrl.backends.skyrl_train.utils.ppo_utils import compute_approx_kl from skyrl.backends.skyrl_train.utils.replay_utils import ( setup_per_microbatch_replay_backward, setup_per_microbatch_replay_forward, @@ -40,11 +37,13 @@ def __init__( actor_module: List[nn.Module], actor_optimizer: Optional[torch.optim.Optimizer] = None, policy_loss_fn: Optional[Callable] = None, + policy_loss_fn_resolver: Optional[Callable[[str], Callable]] = None, ): self.cfg = config self.actor_module = actor_module self.actor_optimizer = actor_optimizer self.policy_loss_fn = policy_loss_fn + self.policy_loss_fn_resolver = policy_loss_fn_resolver self.use_sample_packing = self.cfg.use_sample_packing config = get_model_config(self.actor_module[0]) @@ -225,7 +224,9 @@ def forward_backward_mini_batch( # Resolve loss function resolved_loss_name = loss_fn if loss_fn is not None else self.cfg.algorithm.policy_loss_type if loss_fn is not None: - current_loss_fn = PolicyLossRegistry.get(loss_fn) + if self.policy_loss_fn_resolver is None: + raise ValueError("MegatronModelWrapper requires policy_loss_fn_resolver for loss overrides") + current_loss_fn = self.policy_loss_fn_resolver(loss_fn) else: current_loss_fn = self.policy_loss_fn diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 1e7761e495..e13b98f0ab 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -673,6 +673,7 @@ def init_model(self, model_path, num_training_steps: int = 1e9): actor_module=self.actor_module, actor_optimizer=self.optimizer, policy_loss_fn=self.policy_loss_fn, + policy_loss_fn_resolver=self._get_policy_loss_fn, ) self.empty_cuda_cache = self.cfg.policy.megatron_config.empty_cuda_cache diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 43a73fc28c..74c164d42b 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -5,7 +5,17 @@ from collections import defaultdict from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Type, + Union, +) import ray import torch @@ -486,6 +496,7 @@ def __init__( colocate_all: bool = False, sequence_parallel_size: int = 1, record_memory: bool = False, + worker_init_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Args: @@ -504,6 +515,7 @@ def __init__( self.colocate_all = colocate_all self.sequence_parallel_size = sequence_parallel_size self.record_memory = record_memory + self.worker_init_kwargs = dict(worker_init_kwargs or {}) self._initiate_actors(pg, num_gpus_per_actor) def _initiate_actors(self, pg: Optional[ResolvedPlacementGroup], num_gpus_per_actor: float): @@ -572,16 +584,18 @@ def _scheduling_strategy_for_rank(rank): if sched is not 
None: actor_options["scheduling_strategy"] = sched - master_actor = self.ray_actor_type.options(**actor_options).remote( - cfg=self.cfg, - world_size=world_size, - rank=0, - local_rank=0, - master_addr=None, - master_port=None, - sequence_parallel_size=self.sequence_parallel_size, - record_memory=self.record_memory, - ) + master_actor_kwargs = { + **self.worker_init_kwargs, + "cfg": self.cfg, + "world_size": world_size, + "rank": 0, + "local_rank": 0, + "master_addr": None, + "master_port": None, + "sequence_parallel_size": self.sequence_parallel_size, + "record_memory": self.record_memory, + } + master_actor = self.ray_actor_type.options(**actor_options).remote(**master_actor_kwargs) self._actor_handlers = [master_actor] if world_size > 1: @@ -598,16 +612,18 @@ def _scheduling_strategy_for_rank(rank): if sched is not None: actor_options["scheduling_strategy"] = sched - worker_actor = self.ray_actor_type.options(**actor_options).remote( - cfg=self.cfg, - world_size=world_size, - rank=rank, - local_rank=local_rank, - master_addr=master_addr, - master_port=master_port, - sequence_parallel_size=self.sequence_parallel_size, - record_memory=self.record_memory, - ) + worker_actor_kwargs = { + **self.worker_init_kwargs, + "cfg": self.cfg, + "world_size": world_size, + "rank": rank, + "local_rank": local_rank, + "master_addr": master_addr, + "master_port": master_port, + "sequence_parallel_size": self.sequence_parallel_size, + "record_memory": self.record_memory, + } + worker_actor = self.ray_actor_type.options(**actor_options).remote(**worker_actor_kwargs) self._actor_handlers.append(worker_actor) # Initialize process group @@ -682,15 +698,31 @@ def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kw class PolicyWorkerBase(Worker): - def __init__(self, **kwargs): + def __init__(self, policy_loss_registry_snapshot: Optional[Mapping[str, bytes]] = None, **kwargs): super().__init__(**kwargs) + if policy_loss_registry_snapshot is None: + PolicyLossRegistry.repopulate_local_registry() + else: + PolicyLossRegistry.install_serialized(policy_loss_registry_snapshot) self.model: nn.Module = None self.scheduler: LRScheduler = None self.optimizer: Optimizer = None self.strategy: DistributedStrategy = None self.record_memory: bool = False self.mesh_rank: MeshRank = None - self.policy_loss_fn: Callable = PolicyLossRegistry.get(self.cfg.algorithm.policy_loss_type) + self._policy_loss_fn_cache: Dict[str, Callable] = {} + self.policy_loss_fn: Callable = self._get_policy_loss_fn(self.cfg.algorithm.policy_loss_type) + + def _get_policy_loss_fn(self, loss_name: str) -> Callable: + try: + if loss_name not in self._policy_loss_fn_cache: + self._policy_loss_fn_cache[loss_name] = PolicyLossRegistry.get_local(loss_name) + return self._policy_loss_fn_cache[loss_name] + except ValueError as e: + raise ValueError( + f"{e}. If this is a custom policy loss, register it before initialize_ray()/worker creation " + "so it is included in the policy-loss snapshot." 
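+
+    # NOTE (illustrative): a custom loss must be registered on the driver before
+    # the worker actors are created, so that it lands in the serialized snapshot:
+    #   PolicyLossRegistry.register("my_loss", my_loss_fn)
+    #   snapshot = snapshot_policy_loss_registry_for_workers()
+    #   PPORayActorGroup(..., worker_init_kwargs={"policy_loss_registry_snapshot": snapshot})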
 
     def forward_backward(
         self,
@@ -789,7 +821,7 @@ def _forward_backward_micro(
         resolved_loss_name = loss_fn if loss_fn is not None else self.cfg.algorithm.policy_loss_type
         if loss_fn is not None:
             # Use the provided loss function (Tinker API style)
-            current_loss_fn = PolicyLossRegistry.get(loss_fn)
+            current_loss_fn = self._get_policy_loss_fn(loss_fn)
         else:
             # Fall back to config default
             current_loss_fn = self.policy_loss_fn
diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py
index ebf612f575..dc391bef36 100644
--- a/skyrl/backends/skyrl_train_backend.py
+++ b/skyrl/backends/skyrl_train_backend.py
@@ -26,6 +26,9 @@
     TrainingInputBatch,
     pad_training_input_batch,
 )
+from skyrl.backends.skyrl_train.utils.ppo_utils import (
+    snapshot_policy_loss_registry_for_workers,
+)
 from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
 from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
@@ -239,6 +242,9 @@ def _build_policy(self, PolicyWorker, model_id: str):
             colocate_all=colocate_all,
             sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
             record_memory=cfg.trainer.policy.record_memory,
+            worker_init_kwargs={
+                "policy_loss_registry_snapshot": snapshot_policy_loss_registry_for_workers(),
+            },
         )
 
         # set to a large number for megatron scheduler init
diff --git a/skyrl/train/sft_trainer.py b/skyrl/train/sft_trainer.py
index 7b1f324de5..213ee60eac 100644
--- a/skyrl/train/sft_trainer.py
+++ b/skyrl/train/sft_trainer.py
@@ -34,6 +34,9 @@
 
 from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
 from skyrl.backends.skyrl_train.utils.io import io
+from skyrl.backends.skyrl_train.utils.ppo_utils import (
+    snapshot_policy_loss_registry_for_workers,
+)
 from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
 from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S
@@ -388,6 +391,9 @@ def _init_workers(self):
             colocate_all=False,
             sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
             record_memory=self.cfg.trainer.policy.record_memory,
+            worker_init_kwargs={
+                "policy_loss_registry_snapshot": snapshot_policy_loss_registry_for_workers(),
+            },
         )
         num_training_steps = (
             self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index 61655f0f65..500883e581 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -395,6 +395,9 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
         pg = None
 
         use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
+        policy_worker_init_kwargs = {
+            "policy_loss_registry_snapshot": ppo_utils.snapshot_policy_loss_registry_for_workers(),
+        }
 
         if cfg.trainer.placement.colocate_all:
             num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
@@ -422,6 +425,7 @@
                 colocate_all=True,
                 sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                 record_memory=cfg.trainer.policy.record_memory,
+                worker_init_kwargs=policy_worker_init_kwargs,
             )
             if use_ref_model:
                 assert (
@@ -484,6 +488,7 @@
             num_gpus_per_actor=0.75 if pg else 1,
             colocate_all=False,
             sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
+            worker_init_kwargs=policy_worker_init_kwargs,
         )
         if use_ref_model:
             ref_model = PPORayActorGroup(
diff --git a/tests/backends/skyrl_train/utils/test_ppo_utils.py b/tests/backends/skyrl_train/utils/test_ppo_utils.py
index 81055debda..68b90ffa3a 100644
--- a/tests/backends/skyrl_train/utils/test_ppo_utils.py
+++ b/tests/backends/skyrl_train/utils/test_ppo_utils.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import pytest
+import ray
 import torch
 
 from skyrl.backends.skyrl_train.utils.ppo_utils import (
@@ -25,6 +26,7 @@
     reduce_loss,
     register_advantage_estimator,
     register_policy_loss,
+    snapshot_policy_loss_registry_for_workers,
 )
 
 from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean
@@ -447,6 +449,64 @@ def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_logprobs=None):
     PolicyLossRegistry.unregister("test_policy_decorator")
 
 
+def test_policy_loss_registry_get_local_does_not_sync(monkeypatch):
+    """Local lookup should never contact the Ray registry actor."""
+    PolicyLossRegistry.repopulate_registry()
+
+    def fail_sync():
+        pytest.fail("get_local should not sync with the registry actor")
+
+    monkeypatch.setattr(PolicyLossRegistry, "sync_with_actor", fail_sync)
+
+    assert PolicyLossRegistry.get_local("regular") is not None
+    with pytest.raises(ValueError, match="Unknown local policy loss"):
+        PolicyLossRegistry.get_local("missing_policy_loss")
+
+
+def test_policy_loss_registry_snapshot_install_round_trip():
+    """Serialized snapshots should install custom losses locally without actor sync."""
+    from skyrl.train.config import AlgorithmConfig
+
+    loss_name = "snapshot_round_trip_policy_loss"
+
+    def custom_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_logprobs=None):
+        return torch.tensor(4.0), {"clip_ratio": 0.7}
+
+    PolicyLossRegistry._functions.pop(loss_name, None)
+
+    try:
+        PolicyLossRegistry.register(loss_name, custom_policy_loss)
+        snapshot = PolicyLossRegistry.snapshot_serialized([loss_name], sync_with_actor=False)
+        PolicyLossRegistry.unregister(loss_name)
+
+        PolicyLossRegistry.install_serialized(snapshot)
+        retrieved = PolicyLossRegistry.get_local(loss_name)
+        loss, metrics = retrieved(
+            log_probs=torch.tensor([[0.1]]),
+            old_log_probs=torch.tensor([[0.2]]),
+            advantages=torch.tensor([[1.0]]),
+            config=AlgorithmConfig(policy_loss_type=loss_name),
+        )
+
+        assert loss.item() == 4.0
+        assert metrics["clip_ratio"] == 0.7
+    finally:
+        PolicyLossRegistry._functions.pop(loss_name, None)
+        if ray.is_initialized():
+            try:
+                actor = ray.get_actor("policy_loss_registry")
+                ray.get(actor.unregister.remote(loss_name))
+            except ValueError:
+                pass
+
+
+def test_snapshot_policy_loss_registry_for_workers_syncs_then_serializes():
+    snapshot = snapshot_policy_loss_registry_for_workers()
+
+    assert "regular" in snapshot
+    assert "cross_entropy" in snapshot
+
+
 def test_registry_cross_ray_process():
     """Test that registry works with Ray and that functions can be retrieved and called from different processes"""
     try:
diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py
index 20f93495fb..931876d557 100644
--- a/tests/train/test_trainer.py
+++ b/tests/train/test_trainer.py
@@ -11,6 +11,11 @@
 from pytest import approx
 
 from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
+from skyrl.backends.skyrl_train.utils.ppo_utils import (
+    PolicyLossRegistry,
+    repopulate_all_registries,
+    snapshot_policy_loss_registry_for_workers,
+)
 from skyrl.backends.skyrl_train.workers.worker import CriticWorkerBase, PolicyWorkerBase
 from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator
 from skyrl.train.config import SkyRLTrainConfig
@@ -62,6 +67,66 @@ def dummy_generator():
     return MagicMock()
 
 
+@pytest.fixture(scope="module", autouse=True)
+def _ensure_default_registries_for_trainer_tests():
+    repopulate_all_registries()
+
+
+def test_policy_worker_uses_local_policy_loss_snapshot(monkeypatch, dummy_config):
+    dummy_config.trainer.algorithm.policy_loss_type = "regular"
+    snapshot = snapshot_policy_loss_registry_for_workers()
+
+    def fail_sync():
+        pytest.fail("policy workers should resolve losses from the local snapshot")
+
+    monkeypatch.setattr(PolicyLossRegistry, "sync_with_actor", fail_sync)
+
+    worker = PolicyWorkerBase(
+        cfg=dummy_config.trainer,
+        world_size=1,
+        rank=0,
+        local_rank=0,
+        master_addr="localhost",
+        master_port=12345,
+        sequence_parallel_size=1,
+        policy_loss_registry_snapshot=snapshot,
+    )
+
+    assert worker.policy_loss_fn is worker._get_policy_loss_fn("regular")
+    assert worker._get_policy_loss_fn("cross_entropy") is not None
+
+    with pytest.raises(ValueError, match="custom policy loss"):
+        worker._get_policy_loss_fn("missing_policy_loss")
+
+
+def test_policy_worker_repopulates_builtin_losses_without_actor_sync(monkeypatch, dummy_config):
+    dummy_config.trainer.algorithm.policy_loss_type = "regular"
+    saved_functions = dict(PolicyLossRegistry._functions)
+    PolicyLossRegistry._functions.clear()
+
+    def fail_sync():
+        pytest.fail("policy worker fallback should not sync with the registry actor")
+
+    monkeypatch.setattr(PolicyLossRegistry, "sync_with_actor", fail_sync)
+
+    try:
+        worker = PolicyWorkerBase(
+            cfg=dummy_config.trainer,
+            world_size=1,
+            rank=0,
+            local_rank=0,
+            master_addr="localhost",
+            master_port=12345,
+            sequence_parallel_size=1,
+        )
+
+        assert worker.policy_loss_fn is not None
+        assert "regular" in PolicyLossRegistry._functions
+    finally:
+        PolicyLossRegistry._functions.clear()
+        PolicyLossRegistry._functions.update(saved_functions)
+
+
 def _get_test_data(trainer: RayPPOTrainer):
     trainer.critic_model = MagicMock()  # pretend we're using a critic