9 changes: 8 additions & 1 deletion skyrl/backends/backend.py
@@ -43,7 +43,13 @@ def __init__(self, base_model: str, config: BaseModel):
pass

@abstractmethod
def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
def create_model(
self,
model_id: str,
lora_config: types.LoraConfig,
model_role: str = "policy",
seed_was_provided: bool = True,
) -> None:
"""Create a new model in the backend.

Creates optimizer and configures LoRA adapter.
@@ -52,6 +58,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
model_id: The model identifier
lora_config: LoRA configuration with rank and alpha
model_role: Logical role for the model (e.g. policy or critic)
seed_was_provided: Whether the client explicitly set the seed.
"""
pass

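The new seed_was_provided flag lets a backend distinguish a client-supplied seed from the default. A minimal illustrative sketch (not part of this diff) of calling the updated signature; MyBackend, my_config, the model path, and the LoraConfig keyword fields are assumptions for the example:

from skyrl.tinker import types

# Hypothetical concrete backend subclass and config, for illustration only.
backend = MyBackend(base_model="Qwen/Qwen2.5-7B-Instruct", config=my_config)
backend.create_model(
    "policy-0",
    lora_config=types.LoraConfig(rank=16, alpha=32),  # assumed keyword construction
    model_role="policy",
    seed_was_provided=False,  # client left the seed at its default
)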
24 changes: 21 additions & 3 deletions skyrl/backends/jax.py
@@ -546,7 +546,13 @@ def has_model(self, model_id: str) -> bool:
"""Check if a model is registered with the backend."""
return model_id in self.models

def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
def create_model(
self,
model_id: str,
lora_config: types.LoraConfig,
model_role: str = "policy",
seed_was_provided: bool = True,
) -> None:
"""Create a new model in the backend.

Creates optimizer and configures LoRA adapter. Allocates adapter_index internally.
@@ -1109,8 +1115,20 @@ def serialize(k, v):
)
return getattr(super(), method)(**kwargs)

def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config, model_role=model_role)
def create_model(
self,
model_id: str,
lora_config: types.LoraConfig,
model_role: str = "policy",
seed_was_provided: bool = True,
) -> None:
self._broadcast_and_call(
"create_model",
model_id=model_id,
lora_config=lora_config,
model_role=model_role,
seed_was_provided=seed_was_provided,
)

def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)
80 changes: 66 additions & 14 deletions skyrl/backends/skyrl_train_backend.py
@@ -28,6 +28,10 @@
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.backends.skyrl_train_lora import (
resolve_skyrl_train_lora_config,
skyrl_train_lora_signature,
)
from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
from skyrl.tinker import types
from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str
@@ -98,8 +102,18 @@ def _build_skyrl_train_config(

# Apply LoRA configuration
if lora_config is not None and lora_config.rank > 0:
cfg.trainer.seed = int(lora_config.seed)
Contributor comment (medium): The int() cast is redundant as lora_config.seed is already defined as an int in the LoraConfig type definition.

Suggested change:
cfg.trainer.seed = int(lora_config.seed)
cfg.trainer.seed = lora_config.seed

lora_type = cfg.trainer.policy.megatron_config.lora_config.lora_type
resolved_lora = resolve_skyrl_train_lora_config(
lora_config,
strategy=cfg.trainer.strategy,
lora_type=lora_type,
pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size,
)
cfg.trainer.policy.model.lora.rank = lora_config.rank
cfg.trainer.policy.model.lora.alpha = int(lora_config.alpha)
cfg.trainer.policy.model.lora.target_modules = resolved_lora.target_modules
cfg.trainer.policy.model.lora.exclude_modules = resolved_lora.exclude_modules

logger.info("SkyRL-Train config:\n%s", get_config_as_yaml_str(cfg))
return cfg
@@ -133,6 +147,7 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
# Captured at first LoRA create_model; subsequent create_models must
# match this signature exactly. None when no LoRA model is registered.
self._base_lora_signature: tuple | None = None
self._base_lora_seed: int | None = None

# New inference infrastructure
self._server_groups: list = []
@@ -287,8 +302,18 @@ def _build_critic(self, CriticWorker, lora_config: types.LoraConfig) -> None:
num_policy_gpus == num_critic_gpus
), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"

cfg.trainer.critic.model.lora.rank = lora_config.rank
cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)
if lora_config is not None and lora_config.rank > 0:
cfg.trainer.seed = int(lora_config.seed)
Contributor comment (medium): The int() cast is redundant here as well.

Suggested change:
cfg.trainer.seed = int(lora_config.seed)
cfg.trainer.seed = lora_config.seed

resolved_lora = resolve_skyrl_train_lora_config(
lora_config,
strategy=cfg.trainer.strategy,
lora_type=cfg.trainer.policy.megatron_config.lora_config.lora_type,
pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size,
)
cfg.trainer.critic.model.lora.rank = lora_config.rank
cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)
cfg.trainer.critic.model.lora.target_modules = resolved_lora.target_modules
cfg.trainer.critic.model.lora.exclude_modules = resolved_lora.exclude_modules
critic_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.critic_num_nodes,
@@ -356,14 +381,28 @@ def _ensure_inference_engines(self):
self._inference_engines_initialized = True

def _lora_signature_from(self, lora_config: types.LoraConfig) -> tuple:
# Tinker's public LoraConfig only exposes rank + alpha (plus
# seed/train_attn/train_mlp/train_unembed) - pending support https://github.com/NovaSky-AI/SkyRL/issues/1632.
# Equality across adapters therefore reduces to (rank, alpha); the worker-side
# AdapterStore additionally verifies parallel-state equality via
# its own LoraSignature.
return (int(lora_config.rank), int(lora_config.alpha))

def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
if self._cfg is not None:
strategy = self._cfg.trainer.strategy
lora_type = self._cfg.trainer.policy.megatron_config.lora_config.lora_type
pipeline_parallel_size = self._cfg.trainer.policy.megatron_config.pipeline_model_parallel_size
else:
strategy = self.config.strategy
lora_type = "lora"
pipeline_parallel_size = 1
return skyrl_train_lora_signature(
lora_config,
strategy=strategy,
lora_type=lora_type,
pipeline_parallel_size=pipeline_parallel_size,
)

def create_model(
self,
model_id: str,
lora_config: types.LoraConfig,
model_role: str = "policy",
seed_was_provided: bool = True,
) -> None:
if model_id in self._model_ids_to_role:
raise ValueError(f"Model '{model_id}' already exists")

@@ -389,10 +428,17 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
if new_signature != self._base_lora_signature:
raise ValueError(
f"LoRA signature mismatch for model '{model_id}': "
f"got (rank, alpha)={new_signature}, "
f"got {new_signature}, "
f"first adapter registered with {self._base_lora_signature}. "
"Multi-LoRA with the SkyRLTrainBackend requires identical (rank, alpha) across all "
"adapters; target_modules is fixed server-side."
"Multi-LoRA with the SkyRLTrainBackend requires identical "
"(rank, alpha, target_modules, exclude_modules, lora_type) across all adapters."
)
if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed:
Contributor comment (medium): The int() cast is redundant here as well.

Suggested change:
if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed:
if seed_was_provided and self._base_lora_seed is not None and lora_config.seed != self._base_lora_seed:

raise ValueError(
f"LoRA seed mismatch for model '{model_id}': got seed={lora_config.seed}, "
f"first adapter registered with seed={self._base_lora_seed}. "
"SkyRLTrainBackend additional adapters are initialized from the first pristine adapter, "
"so explicit seeds must match."
)
self._dispatch.register_adapter("policy", model_id)
self._model_ids_to_role[model_id] = model_role
Expand All @@ -402,7 +448,11 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:

# First-time setup OR critic creation (existing path).
if model_role == "policy":
self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)
self._cfg = _build_skyrl_train_config(
self.base_model,
self.config,
lora_config,
)

if not ray.is_initialized():
logger.info("Initializing Ray with runtime environment")
Expand All @@ -425,6 +475,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
self._build_policy(PolicyWorker, model_id=model_id)
if is_lora:
self._base_lora_signature = self._lora_signature_from(lora_config)
self._base_lora_seed = int(lora_config.seed)
Contributor comment (medium): The int() cast is redundant here as well.

Suggested change:
self._base_lora_seed = int(lora_config.seed)
self._base_lora_seed = lora_config.seed

elif model_role == "critic":
if model_role in self._model_ids_to_role.values():
raise ValueError(f"SkyRLTrainBackend already has a '{model_role}' model")
@@ -496,6 +547,7 @@ def delete_model(self, model_id: str) -> None:
self._renderer = None
self._colocate_pg = None
self._base_lora_signature = None
self._base_lora_seed = None
logger.info(f"Successfully deleted model {model_id}")

def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch, role: str) -> TrainingInputBatch:
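Taken together, the guards above mean every additional policy adapter must match the first adapter's resolved LoRA signature, and any explicitly provided seed must equal the first adapter's seed. A hedged sketch of the resulting behavior (the backend instance and the LoraConfig keyword fields are illustrative, not taken from this diff):

# First adapter sets the baseline signature and seed.
backend.create_model("adapter-a", types.LoraConfig(rank=16, alpha=32), model_role="policy")

# A second adapter with a different rank is rejected by the signature check.
try:
    backend.create_model("adapter-b", types.LoraConfig(rank=8, alpha=32), model_role="policy")
except ValueError as err:
    print(err)  # LoRA signature mismatch for model 'adapter-b': ...

# An explicit seed differing from the first adapter's seed is also rejected,
# because extra adapters are initialized from the first pristine adapter.
try:
    backend.create_model(
        "adapter-c",
        types.LoraConfig(rank=16, alpha=32, seed=123),
        model_role="policy",
        seed_was_provided=True,
    )
except ValueError as err:
    print(err)  # LoRA seed mismatch for model 'adapter-c': ...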
131 changes: 131 additions & 0 deletions skyrl/backends/skyrl_train_lora.py
@@ -0,0 +1,131 @@
"""Translate Tinker LoRA options into SkyRL-Train LoRA target modules."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable

from skyrl.tinker import types

FSDP_ATTN_TARGET_MODULES = (
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"query_key_value",
"attn.c_attn",
"attn.c_proj",
)
FSDP_MLP_TARGET_MODULES = (
"gate_proj",
"up_proj",
"down_proj",
"fc1",
"fc2",
"c_fc",
"mlp.c_proj",
)
FSDP_UNEMBED_TARGET_MODULES = (
"lm_head",
"embed_out",
"output_projection",
)

MEGATRON_LORA_ATTN_TARGET_MODULES = ("linear_qkv", "linear_proj")
MEGATRON_LORA_MLP_TARGET_MODULES = ("linear_fc1", "linear_fc2")
MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES = ("linear_q", "linear_k", "linear_v", "linear_proj")
MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES = ("linear_fc1_up", "linear_fc1_gate", "linear_fc2")
MEGATRON_UNEMBED_TARGET_MODULES = ("output_layer",)


@dataclass(frozen=True)
class ResolvedSkyRLTrainLoraConfig:
target_modules: str | list[str]
exclude_modules: list[str] | None = None


def normalize_lora_targets(target_modules: str | Iterable[str]) -> tuple[str, ...]:
if isinstance(target_modules, str):
return (target_modules,)
return tuple(target_modules)


def _dedupe_targets(target_modules: Iterable[str]) -> list[str]:
return list(dict.fromkeys(target_modules))


def _validate_train_targets(lora_config: types.LoraConfig) -> None:
if lora_config.rank > 0 and not (lora_config.train_attn or lora_config.train_mlp or lora_config.train_unembed):
raise ValueError("At least one of train_attn, train_mlp, or train_unembed must be true for LoRA rank > 0")


def resolve_skyrl_train_lora_config(
lora_config: types.LoraConfig,
strategy: str,
lora_type: str = "lora",
pipeline_parallel_size: int = 1,
) -> ResolvedSkyRLTrainLoraConfig:
"""Resolve Tinker LoRA train flags to the target module surface SkyRL-Train expects."""

_validate_train_targets(lora_config)
if lora_config.rank <= 0:
return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear")
if lora_config.train_attn and lora_config.train_mlp and not lora_config.train_unembed:
return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear")

if strategy in ("fsdp", "fsdp2"):
target_modules: list[str] = []
if lora_config.train_attn:
target_modules.extend(FSDP_ATTN_TARGET_MODULES)
if lora_config.train_mlp:
target_modules.extend(FSDP_MLP_TARGET_MODULES)
if lora_config.train_unembed:
target_modules.extend(FSDP_UNEMBED_TARGET_MODULES)
return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules))

if strategy == "megatron":
if lora_config.train_unembed and pipeline_parallel_size > 1:
raise ValueError(
"train_unembed=True is not supported for the Megatron SkyRL-Train backend when "
"pipeline_model_parallel_size > 1 because output_layer only exists on the final pipeline stage"
)
if lora_type == "canonical_lora":
attn_targets = MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES
mlp_targets = MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES
elif lora_type == "lora":
attn_targets = MEGATRON_LORA_ATTN_TARGET_MODULES
mlp_targets = MEGATRON_LORA_MLP_TARGET_MODULES
else:
raise ValueError(f"Unsupported Megatron LoRA type: {lora_type!r}")

target_modules = []
if lora_config.train_attn:
target_modules.extend(attn_targets)
if lora_config.train_mlp:
target_modules.extend(mlp_targets)
if lora_config.train_unembed:
target_modules.extend(MEGATRON_UNEMBED_TARGET_MODULES)
return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules))

raise ValueError(f"Unsupported SkyRL-Train strategy for Tinker LoRA config: {strategy!r}")


def skyrl_train_lora_signature(
lora_config: types.LoraConfig,
strategy: str,
lora_type: str = "lora",
pipeline_parallel_size: int = 1,
) -> tuple:
resolved = resolve_skyrl_train_lora_config(
lora_config,
strategy=strategy,
lora_type=lora_type,
pipeline_parallel_size=pipeline_parallel_size,
)
return (
int(lora_config.rank),
Contributor comment (medium): The int() cast is redundant as lora_config.rank is already an int.

Suggested change:
int(lora_config.rank),
lora_config.rank,

int(lora_config.alpha),
normalize_lora_targets(resolved.target_modules),
tuple(resolved.exclude_modules or ()),
lora_type if strategy == "megatron" else strategy,
)
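The new helper maps Tinker's train_attn/train_mlp/train_unembed flags onto the per-strategy target-module lists defined at the top of this file. A small usage sketch (the LoraConfig keyword construction is assumed for illustration; the expected outputs follow directly from the module tables above):

from skyrl.tinker import types
from skyrl.backends.skyrl_train_lora import (
    resolve_skyrl_train_lora_config,
    skyrl_train_lora_signature,
)

# Attention-only LoRA (assumed keyword construction of types.LoraConfig).
lora = types.LoraConfig(rank=16, alpha=32, train_attn=True, train_mlp=False, train_unembed=False)

# FSDP: attention projections only.
resolve_skyrl_train_lora_config(lora, strategy="fsdp2").target_modules
# ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'query_key_value', 'attn.c_attn', 'attn.c_proj']

# Megatron with the default fused-LoRA layout.
resolve_skyrl_train_lora_config(lora, strategy="megatron", lora_type="lora").target_modules
# ['linear_qkv', 'linear_proj']

# Signature used for the multi-LoRA equality check in SkyRLTrainBackend.
skyrl_train_lora_signature(lora, strategy="megatron", lora_type="lora")
# (16, 32, ('linear_qkv', 'linear_proj'), (), 'lora')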