[tinker] Support additional LoRA config parameters #1643
@@ -28,6 +28,10 @@
 )
 from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
+from skyrl.backends.skyrl_train_lora import (
+    resolve_skyrl_train_lora_config,
+    skyrl_train_lora_signature,
+)
 from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
 from skyrl.tinker import types
 from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str
@@ -98,8 +102,18 @@ def _build_skyrl_train_config(
 
     # Apply LoRA configuration
     if lora_config is not None and lora_config.rank > 0:
+        cfg.trainer.seed = int(lora_config.seed)
+        lora_type = cfg.trainer.policy.megatron_config.lora_config.lora_type
+        resolved_lora = resolve_skyrl_train_lora_config(
+            lora_config,
+            strategy=cfg.trainer.strategy,
+            lora_type=lora_type,
+            pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size,
+        )
         cfg.trainer.policy.model.lora.rank = lora_config.rank
         cfg.trainer.policy.model.lora.alpha = int(lora_config.alpha)
+        cfg.trainer.policy.model.lora.target_modules = resolved_lora.target_modules
+        cfg.trainer.policy.model.lora.exclude_modules = resolved_lora.exclude_modules
 
     logger.info("SkyRL-Train config:\n%s", get_config_as_yaml_str(cfg))
     return cfg
@@ -133,6 +147,7 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
         # Captured at first LoRA create_model; subsequent create_models must
         # match this signature exactly. None when no LoRA model is registered.
         self._base_lora_signature: tuple | None = None
+        self._base_lora_seed: int | None = None
 
         # New inference infrastructure
         self._server_groups: list = []
@@ -287,8 +302,18 @@ def _build_critic(self, CriticWorker, lora_config: types.LoraConfig) -> None:
             num_policy_gpus == num_critic_gpus
         ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
 
-        cfg.trainer.critic.model.lora.rank = lora_config.rank
-        cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)
+        if lora_config is not None and lora_config.rank > 0:
+            cfg.trainer.seed = int(lora_config.seed)
+            resolved_lora = resolve_skyrl_train_lora_config(
+                lora_config,
+                strategy=cfg.trainer.strategy,
+                lora_type=cfg.trainer.policy.megatron_config.lora_config.lora_type,
+                pipeline_parallel_size=cfg.trainer.policy.megatron_config.pipeline_model_parallel_size,
+            )
+            cfg.trainer.critic.model.lora.rank = lora_config.rank
+            cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha)
+            cfg.trainer.critic.model.lora.target_modules = resolved_lora.target_modules
+            cfg.trainer.critic.model.lora.exclude_modules = resolved_lora.exclude_modules
         critic_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.critic_num_nodes,
@@ -356,14 +381,28 @@ def _ensure_inference_engines(self):
         self._inference_engines_initialized = True
 
     def _lora_signature_from(self, lora_config: types.LoraConfig) -> tuple:
-        # Tinker's public LoraConfig only exposes rank + alpha (plus
-        # seed/train_attn/train_mlp/train_unembed) - pending support https://github.com/NovaSky-AI/SkyRL/issues/1632.
-        # Equality across adapters therefore reduces to (rank, alpha); the worker-side
-        # AdapterStore additionally verifies parallel-state equality via
-        # its own LoraSignature.
-        return (int(lora_config.rank), int(lora_config.alpha))
-
-    def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
+        if self._cfg is not None:
+            strategy = self._cfg.trainer.strategy
+            lora_type = self._cfg.trainer.policy.megatron_config.lora_config.lora_type
+            pipeline_parallel_size = self._cfg.trainer.policy.megatron_config.pipeline_model_parallel_size
+        else:
+            strategy = self.config.strategy
+            lora_type = "lora"
+            pipeline_parallel_size = 1
+        return skyrl_train_lora_signature(
+            lora_config,
+            strategy=strategy,
+            lora_type=lora_type,
+            pipeline_parallel_size=pipeline_parallel_size,
+        )
+
+    def create_model(
+        self,
+        model_id: str,
+        lora_config: types.LoraConfig,
+        model_role: str = "policy",
+        seed_was_provided: bool = True,
+    ) -> None:
         if model_id in self._model_ids_to_role:
             raise ValueError(f"Model '{model_id}' already exists")
 
@@ -389,10 +428,17 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
             if new_signature != self._base_lora_signature:
                 raise ValueError(
                     f"LoRA signature mismatch for model '{model_id}': "
-                    f"got (rank, alpha)={new_signature}, "
+                    f"got {new_signature}, "
                     f"first adapter registered with {self._base_lora_signature}. "
-                    "Multi-LoRA with the SkyRLTrainBackend requires identical (rank, alpha) across all "
-                    "adapters; target_modules is fixed server-side."
+                    "Multi-LoRA with the SkyRLTrainBackend requires identical "
+                    "(rank, alpha, target_modules, exclude_modules, lora_type) across all adapters."
                 )
+            if seed_was_provided and self._base_lora_seed is not None and int(lora_config.seed) != self._base_lora_seed:
+                raise ValueError(
+                    f"LoRA seed mismatch for model '{model_id}': got seed={lora_config.seed}, "
+                    f"first adapter registered with seed={self._base_lora_seed}. "
+                    "SkyRLTrainBackend additional adapters are initialized from the first pristine adapter, "
+                    "so explicit seeds must match."
+                )
             self._dispatch.register_adapter("policy", model_id)
             self._model_ids_to_role[model_id] = model_role
@@ -402,7 +448,11 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
 
         # First-time setup OR critic creation (existing path).
         if model_role == "policy":
-            self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)
+            self._cfg = _build_skyrl_train_config(
+                self.base_model,
+                self.config,
+                lora_config,
+            )
 
             if not ray.is_initialized():
                 logger.info("Initializing Ray with runtime environment")
@@ -425,6 +475,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
             self._build_policy(PolicyWorker, model_id=model_id)
             if is_lora:
                 self._base_lora_signature = self._lora_signature_from(lora_config)
+                self._base_lora_seed = int(lora_config.seed)
         elif model_role == "critic":
             if model_role in self._model_ids_to_role.values():
                 raise ValueError(f"SkyRLTrainBackend already has a '{model_role}' model")
@@ -496,6 +547,7 @@ def delete_model(self, model_id: str) -> None:
         self._renderer = None
         self._colocate_pg = None
         self._base_lora_signature = None
+        self._base_lora_seed = None
         logger.info(f"Successfully deleted model {model_id}")
 
     def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch, role: str) -> TrainingInputBatch:
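To make the multi-LoRA constraint above concrete, here is a minimal sketch (not part of the PR) of how the expanded signature behaves. It assumes `types.LoraConfig` can be constructed directly with the `rank`, `alpha`, `seed`, and `train_*` fields referenced in this diff:

```python
from skyrl.backends.skyrl_train_lora import skyrl_train_lora_signature
from skyrl.tinker import types

# Adapters that differ only in seed produce identical signatures; the backend
# accepts the second create_model and applies the separate seed check instead.
a = types.LoraConfig(rank=16, alpha=32, seed=0, train_attn=True, train_mlp=True, train_unembed=False)
b = types.LoraConfig(rank=16, alpha=32, seed=1, train_attn=True, train_mlp=True, train_unembed=False)
assert skyrl_train_lora_signature(a, strategy="fsdp") == skyrl_train_lora_signature(b, strategy="fsdp")

# Flipping train_unembed changes the resolved target_modules, so the signature
# differs and a second create_model would raise the mismatch ValueError above.
c = types.LoraConfig(rank=16, alpha=32, seed=0, train_attn=True, train_mlp=True, train_unembed=True)
assert skyrl_train_lora_signature(a, strategy="fsdp") != skyrl_train_lora_signature(c, strategy="fsdp")
```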
@@ -0,0 +1,131 @@
"""Translate Tinker LoRA options into SkyRL-Train LoRA target modules."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable

from skyrl.tinker import types

FSDP_ATTN_TARGET_MODULES = (
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "query_key_value",
    "attn.c_attn",
    "attn.c_proj",
)
FSDP_MLP_TARGET_MODULES = (
    "gate_proj",
    "up_proj",
    "down_proj",
    "fc1",
    "fc2",
    "c_fc",
    "mlp.c_proj",
)
FSDP_UNEMBED_TARGET_MODULES = (
    "lm_head",
    "embed_out",
    "output_projection",
)

MEGATRON_LORA_ATTN_TARGET_MODULES = ("linear_qkv", "linear_proj")
MEGATRON_LORA_MLP_TARGET_MODULES = ("linear_fc1", "linear_fc2")
MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES = ("linear_q", "linear_k", "linear_v", "linear_proj")
MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES = ("linear_fc1_up", "linear_fc1_gate", "linear_fc2")
MEGATRON_UNEMBED_TARGET_MODULES = ("output_layer",)


@dataclass(frozen=True)
class ResolvedSkyRLTrainLoraConfig:
    target_modules: str | list[str]
    exclude_modules: list[str] | None = None


def normalize_lora_targets(target_modules: str | Iterable[str]) -> tuple[str, ...]:
    if isinstance(target_modules, str):
        return (target_modules,)
    return tuple(target_modules)


def _dedupe_targets(target_modules: Iterable[str]) -> list[str]:
    return list(dict.fromkeys(target_modules))


def _validate_train_targets(lora_config: types.LoraConfig) -> None:
    if lora_config.rank > 0 and not (lora_config.train_attn or lora_config.train_mlp or lora_config.train_unembed):
        raise ValueError("At least one of train_attn, train_mlp, or train_unembed must be true for LoRA rank > 0")


def resolve_skyrl_train_lora_config(
    lora_config: types.LoraConfig,
    strategy: str,
    lora_type: str = "lora",
    pipeline_parallel_size: int = 1,
) -> ResolvedSkyRLTrainLoraConfig:
    """Resolve Tinker LoRA train flags to the target module surface SkyRL-Train expects."""

    _validate_train_targets(lora_config)
    if lora_config.rank <= 0:
        return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear")
    if lora_config.train_attn and lora_config.train_mlp and not lora_config.train_unembed:
        return ResolvedSkyRLTrainLoraConfig(target_modules="all-linear")

    if strategy in ("fsdp", "fsdp2"):
        target_modules: list[str] = []
        if lora_config.train_attn:
            target_modules.extend(FSDP_ATTN_TARGET_MODULES)
        if lora_config.train_mlp:
            target_modules.extend(FSDP_MLP_TARGET_MODULES)
        if lora_config.train_unembed:
            target_modules.extend(FSDP_UNEMBED_TARGET_MODULES)
        return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules))

    if strategy == "megatron":
        if lora_config.train_unembed and pipeline_parallel_size > 1:
            raise ValueError(
                "train_unembed=True is not supported for the Megatron SkyRL-Train backend when "
                "pipeline_model_parallel_size > 1 because output_layer only exists on the final pipeline stage"
            )
        if lora_type == "canonical_lora":
            attn_targets = MEGATRON_CANONICAL_LORA_ATTN_TARGET_MODULES
            mlp_targets = MEGATRON_CANONICAL_LORA_MLP_TARGET_MODULES
        elif lora_type == "lora":
            attn_targets = MEGATRON_LORA_ATTN_TARGET_MODULES
            mlp_targets = MEGATRON_LORA_MLP_TARGET_MODULES
        else:
            raise ValueError(f"Unsupported Megatron LoRA type: {lora_type!r}")

        target_modules = []
        if lora_config.train_attn:
            target_modules.extend(attn_targets)
        if lora_config.train_mlp:
            target_modules.extend(mlp_targets)
        if lora_config.train_unembed:
            target_modules.extend(MEGATRON_UNEMBED_TARGET_MODULES)
        return ResolvedSkyRLTrainLoraConfig(target_modules=_dedupe_targets(target_modules))

    raise ValueError(f"Unsupported SkyRL-Train strategy for Tinker LoRA config: {strategy!r}")


def skyrl_train_lora_signature(
    lora_config: types.LoraConfig,
    strategy: str,
    lora_type: str = "lora",
    pipeline_parallel_size: int = 1,
) -> tuple:
    resolved = resolve_skyrl_train_lora_config(
        lora_config,
        strategy=strategy,
        lora_type=lora_type,
        pipeline_parallel_size=pipeline_parallel_size,
    )
    return (
        int(lora_config.rank),
        int(lora_config.alpha),
        normalize_lora_targets(resolved.target_modules),
        tuple(resolved.exclude_modules or ()),
        lora_type if strategy == "megatron" else strategy,
    )
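As a usage illustration (not part of the PR), the sketch below shows how the train flags map to target modules for each strategy; the `types.LoraConfig` constructor arguments are assumed from the fields used in this diff:

```python
from skyrl.backends.skyrl_train_lora import resolve_skyrl_train_lora_config
from skyrl.tinker import types

cfg = types.LoraConfig(rank=8, alpha=16, seed=0, train_attn=True, train_mlp=False, train_unembed=True)

# FSDP strategies resolve to per-module-name targets, deduplicated.
fsdp = resolve_skyrl_train_lora_config(cfg, strategy="fsdp")
# fsdp.target_modules -> ["q_proj", "k_proj", "v_proj", "o_proj", "query_key_value",
#                         "attn.c_attn", "attn.c_proj", "lm_head", "embed_out", "output_projection"]

# Megatron with fused LoRA layers uses coarser targets plus the output layer.
mega = resolve_skyrl_train_lora_config(cfg, strategy="megatron", lora_type="lora")
# mega.target_modules -> ["linear_qkv", "linear_proj", "output_layer"]

# train_attn + train_mlp without train_unembed short-circuits to "all-linear".
full = types.LoraConfig(rank=8, alpha=16, seed=0, train_attn=True, train_mlp=True, train_unembed=False)
assert resolve_skyrl_train_lora_config(full, strategy="megatron").target_modules == "all-linear"
```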
Contributor review comment: The `int()` cast is redundant, as `lora_config.seed` is already defined as an `int` in the `LoraConfig` type definition.
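As illustration only, a hypothetical sketch of the point being made; the real `LoraConfig` lives in `skyrl.tinker.types` and its definition is not shown in this diff:

```python
from dataclasses import dataclass

@dataclass
class LoraConfig:  # hypothetical stand-in for skyrl.tinker.types.LoraConfig
    rank: int
    alpha: float
    seed: int = 0              # already an int, so int(lora_config.seed) is a no-op cast
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = False

seed = LoraConfig(rank=16, alpha=32).seed  # plain int; no cast needed
```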