[tinker][megatron] Multi-LoRA Megatron + Tinker API #1617
Merged
Changes from 12 commits
24 commits
b8a55f3 erictang000: [docs] Add Multi-LoRA Megatron Tinker design doc (v1)
90c3ed4 erictang000: [multi-lora] Add AdapterStore for per-worker LoRA slot bookkeeping
7183d55 erictang000: [multi-lora] Wire AdapterStore into MegatronPolicyWorkerBase
abdadb1 erictang000: [multi-lora] Add ensure_active_adapter + model_id threading to dispatch
e4f2333 erictang000: [multi-lora] Allow multiple LoRA policy adapters in SkyRLTrainBackend
f762261 erictang000: [multi-lora] Add GPU-gated multi-LoRA integration test for Megatron
03d623a erictang000: [multi-lora] Add two-client smoke runbook
057a627 erictang000: [multi-lora] Fix _lora_signature_from to not read non-existent target…
2fcea45 erictang000: x
003d3ee erictang000: [multi-lora] Swap grad buffers along with params + optimizer state
e5309f4 erictang000: Merge remote-tracking branch 'origin/main' into multi_lora
2a3a236 erictang000: [multi-lora] Remove internal-development docs from PR
76dc375 erictang000: [multi-lora] Fix integration test: backend name, healthcheck, Tinker …
43f7d65 erictang000: [multi-lora] Restore v1 sampling guards + add SEQ-vs-ALT min repro test
aca96d0 erictang000: [multi-lora] AdapterStore: snapshot/restore optimizer.param_groups[g]…
ddb87c8 erictang000: [multi-lora] Tighten SEQ-vs-ALT test to bit-exact + add Qwen3-0.6B va…
24ca9c7 erictang000: [multi-lora] Drop Qwen3-0.6B variant; tiny-model bit-exact is sufficient
fe008ea erictang000: [multi-lora] Drop SEQ-vs-ALT comment references from in-tree code
9b374fa erictang000: x
824a840 erictang000: [multi-lora][ci] Move test to tests/tinker/skyrl_train + add GPU CI
c2d27e5 erictang000: [multi-lora][ci] Rename CI to tinker-skyrl-train-backend-gpu
e1f3c31 erictang000: [ci] Remove accidentally-tracked .claude/scheduled_tasks.lock
28775d6 erictang000: x
6a45214 erictang000: [multi-lora] Code review cleanup
363 changes: 363 additions & 0 deletions
skyrl/backends/skyrl_train/workers/megatron/adapter_store.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,363 @@ | ||
| """Per-worker store of LoRA adapter weights and optimizer state. | ||
|
|
||
| Holds one CPU-pinned snapshot per registered model_id plus a single pristine | ||
| slot used to seed newly-created adapters. At any moment exactly one adapter is | ||
| "live" in the worker's `actor_module` + `DistributedOptimizer`; swap_to() moves | ||
| LoRA bucket params and DistributedOptimizer fp32-main / Adam state between live | ||
| GPU storage and the per-adapter CPU slot via tensor.copy_(). | ||
|
|
||
| See docs/content/docs/tinker/multi_lora_design.mdx for the full design. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import Any, Iterable, List, Optional, Tuple | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from megatron.core import parallel_state as mpu | ||
| from megatron.core.distributed import DistributedDataParallel as DDP | ||
| from megatron.core.optimizer import ChainedOptimizer | ||
|
|
||
|
|
||
| def _iter_opts(opt) -> List[Any]: | ||
| """Return the underlying Megatron optimizers as a list, unwrapping ChainedOptimizer.""" | ||
| if isinstance(opt, ChainedOptimizer): | ||
| return list(opt.chained_optimizers) | ||
| return [opt] | ||
|
|
||
|
|
||
| def _iter_buffers(model_chunks) -> Iterable[Tuple[int, int, Any]]: | ||
| """Yield (mc_idx, buf_idx, buffer) for every LoRA-trainable DDP buffer.""" | ||
| for mc_idx, mc in enumerate(model_chunks): | ||
| if not isinstance(mc, DDP): | ||
| continue | ||
| bufs = list(mc.buffers) + list(mc.expert_parallel_buffers) | ||
| for buf_idx, buf in enumerate(bufs): | ||
| yield mc_idx, buf_idx, buf | ||
|
|
||
|
|
||
| def _new_pinned_like(t: torch.Tensor) -> torch.Tensor: | ||
| """Allocate a pinned-CPU tensor with the same shape/dtype as t.""" | ||
| return torch.empty_like(t, device="cpu").pin_memory() | ||
|
|
||
|
|
||
| def _expected_lora_param_check(model_chunks) -> None: | ||
| """Sanity-check: every trainable param under DDP buffers is a LoRA adapter param. | ||
|
|
||
| Megatron's DDP filters out requires_grad=False params before bucket | ||
| construction. With the LoRA pre-wrap hook freezing base params, only | ||
| LoRA A/B params should remain. If a future change breaks this invariant | ||
| (e.g. an unfrozen bias or new trainable head), we want to fail loudly | ||
| rather than silently swap the wrong tensors. | ||
| """ | ||
| for mc_idx, _buf_idx, buf in _iter_buffers(model_chunks): | ||
| for param in getattr(buf, "params", []): | ||
| mc = model_chunks[mc_idx] | ||
| name = next( | ||
| (n for n, p in mc.named_parameters() if p is param), | ||
| None, | ||
| ) | ||
| if name is None: | ||
| continue | ||
| if "adapter" not in name: | ||
| raise RuntimeError( | ||
| f"AdapterStore: trainable non-adapter param '{name}' found in " | ||
| f"DDP buffer {mc_idx}/{_buf_idx}; multi-LoRA swap would " | ||
| f"corrupt this param. Refusing to register." | ||
| ) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class LoraSignature: | ||
| """Immutable identity of a LoRA configuration. All registered adapters | ||
| must share the same signature; otherwise tensor shapes won't match across | ||
| swaps.""" | ||
|
|
||
| rank: int | ||
| alpha: int | ||
| target_modules: Tuple[str, ...] | ||
| lora_type: str | ||
| tp_size: int | ||
| pp_size: int | ||
| ep_size: int | ||
|
|
||
| @classmethod | ||
| def from_lora_config(cls, lora_config, lora_type: str = "lora") -> "LoraSignature": | ||
| targets = lora_config.target_modules | ||
| if isinstance(targets, str): | ||
| targets_tuple = (targets,) | ||
| else: | ||
| targets_tuple = tuple(targets) | ||
| return cls( | ||
| rank=int(lora_config.rank), | ||
| alpha=int(lora_config.alpha), | ||
| target_modules=targets_tuple, | ||
| lora_type=lora_type, | ||
| tp_size=mpu.get_tensor_model_parallel_world_size(), | ||
| pp_size=mpu.get_pipeline_model_parallel_world_size(), | ||
| ep_size=( | ||
| mpu.get_expert_model_parallel_world_size() | ||
| if hasattr(mpu, "get_expert_model_parallel_world_size") | ||
| else 1 | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class AdapterSlot: | ||
| """Per-adapter pinned-CPU storage mirroring the live GPU LoRA state. | ||
|
|
||
| Layout: | ||
| cpu_param_data[mc_idx] -> list[Tensor], one per buffer in | ||
| (mc.buffers + mc.expert_parallel_buffers). | ||
| cpu_grad_data[mc_idx] -> same shape as cpu_param_data; mirrors | ||
| buffer.grad_data so that grads accumulated by an interrupted | ||
| forward_backward aren't lost when another tenant runs in the | ||
| gap before this adapter's optim_step. | ||
| cpu_main_param[opt_idx][g] -> list[Tensor], shapes matching | ||
| opt.shard_fp32_from_float16_groups[g]. | ||
| cpu_opt_state[opt_idx][g][i] -> dict[str, Tensor], mirroring | ||
| opt.optimizer.state[main_param] for every tensor-valued entry | ||
| (exp_avg, exp_avg_sq, step, ...). | ||
| """ | ||
|
|
||
| cpu_param_data: List[List[torch.Tensor]] = field(default_factory=list) | ||
| cpu_grad_data: List[List[torch.Tensor]] = field(default_factory=list) | ||
| cpu_main_param: List[List[List[torch.Tensor]]] = field(default_factory=list) | ||
| cpu_opt_state: List[List[List[dict]]] = field(default_factory=list) | ||
| step_count: int = 0 | ||
|
|
||
|
|
||
| class AdapterStore: | ||
| """Per-worker registry of LoRA adapter slots. | ||
|
|
||
| One AdapterStore lives on each Megatron PolicyWorker. It owns CPU storage | ||
| for every registered adapter plus a pristine template; the live GPU model | ||
| + optimizer always reflect the slot identified by `current_id`. | ||
|
|
||
| Snapshot/restore is local: a series of tensor.copy_()s that issue no | ||
| collectives. swap_to() wraps them in dist.barrier() on the data-parallel | ||
| group, once before and once after the swap (see swap_to docs). | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self._slots: dict[str, AdapterSlot] = {} | ||
| self._pristine: Optional[AdapterSlot] = None | ||
| self._current_id: Optional[str] = None | ||
| self._signature: Optional[LoraSignature] = None | ||
|
|
||
| @property | ||
| def current_id(self) -> Optional[str]: | ||
| return self._current_id | ||
|
|
||
| @property | ||
| def signature(self) -> Optional[LoraSignature]: | ||
| return self._signature | ||
|
|
||
| def has(self, model_id: str) -> bool: | ||
| return model_id in self._slots | ||
|
|
||
| def num_adapters(self) -> int: | ||
| return len(self._slots) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Slot allocation helpers | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def _allocate_empty_slot(self, model_chunks, optimizer) -> AdapterSlot: | ||
| slot = AdapterSlot() | ||
| # Param data + grad data: one pinned tensor each per (mc, buffer), matching the buffer's dtype. | ||
| # Grads must travel with the slot — otherwise an interleaved tenant's | ||
| # forward_backward will clobber unconsumed grads via zero_grad_buffer | ||
| # at the top of forward_backward. See docs/.../multi_lora_design.mdx. | ||
| for mc_idx, _buf_idx, buf in _iter_buffers(model_chunks): | ||
| while len(slot.cpu_param_data) <= mc_idx: | ||
| slot.cpu_param_data.append([]) | ||
| slot.cpu_grad_data.append([]) | ||
| slot.cpu_param_data[mc_idx].append(_new_pinned_like(buf.param_data)) | ||
| slot.cpu_grad_data[mc_idx].append(_new_pinned_like(buf.grad_data)) | ||
| # Main params + optimizer state: per (opt_idx, group, param_idx). | ||
| for _opt in _iter_opts(optimizer): | ||
| opt_main: List[List[torch.Tensor]] = [] | ||
| opt_state: List[List[dict]] = [] | ||
| groups = getattr(_opt, "shard_fp32_from_float16_groups", None) or [] | ||
| for g, group in enumerate(groups): | ||
| main_g: List[torch.Tensor] = [] | ||
| state_g: List[dict] = [] | ||
| for main_param in group: | ||
| main_g.append(_new_pinned_like(main_param)) | ||
| state = _opt.optimizer.state.get(main_param, {}) | ||
| state_g.append({k: _new_pinned_like(v) for k, v in state.items() if isinstance(v, torch.Tensor)}) | ||
| opt_main.append(main_g) | ||
| opt_state.append(state_g) | ||
| slot.cpu_main_param.append(opt_main) | ||
| slot.cpu_opt_state.append(opt_state) | ||
| return slot | ||
|
|
||
| @torch.no_grad() | ||
| def _snapshot(self, slot: AdapterSlot, model_chunks, optimizer) -> None: | ||
| """Copy live GPU state into `slot` (CPU).""" | ||
| for mc_idx, buf_idx, buf in _iter_buffers(model_chunks): | ||
| slot.cpu_param_data[mc_idx][buf_idx].copy_(buf.param_data, non_blocking=True) | ||
| slot.cpu_grad_data[mc_idx][buf_idx].copy_(buf.grad_data, non_blocking=True) | ||
| for opt_idx, _opt in enumerate(_iter_opts(optimizer)): | ||
| groups = getattr(_opt, "shard_fp32_from_float16_groups", None) or [] | ||
| for g, group in enumerate(groups): | ||
| for i, main_param in enumerate(group): | ||
| slot.cpu_main_param[opt_idx][g][i].copy_(main_param, non_blocking=True) | ||
| state = _opt.optimizer.state.get(main_param, {}) | ||
| cpu_state = slot.cpu_opt_state[opt_idx][g][i] | ||
| for k, v in state.items(): | ||
| if isinstance(v, torch.Tensor) and k in cpu_state: | ||
| cpu_state[k].copy_(v, non_blocking=True) | ||
|
|
||
| @torch.no_grad() | ||
| def _restore(self, slot: AdapterSlot, model_chunks, optimizer) -> None: | ||
| """Copy `slot` (CPU) into live GPU state.""" | ||
| for mc_idx, buf_idx, buf in _iter_buffers(model_chunks): | ||
| buf.param_data.copy_(slot.cpu_param_data[mc_idx][buf_idx], non_blocking=True) | ||
| buf.grad_data.copy_(slot.cpu_grad_data[mc_idx][buf_idx], non_blocking=True) | ||
| for opt_idx, _opt in enumerate(_iter_opts(optimizer)): | ||
| groups = getattr(_opt, "shard_fp32_from_float16_groups", None) or [] | ||
| for g, group in enumerate(groups): | ||
| for i, main_param in enumerate(group): | ||
| main_param.copy_(slot.cpu_main_param[opt_idx][g][i], non_blocking=True) | ||
| state = _opt.optimizer.state.get(main_param, {}) | ||
| cpu_state = slot.cpu_opt_state[opt_idx][g][i] | ||
| for k, v in state.items(): | ||
| if isinstance(v, torch.Tensor) and k in cpu_state: | ||
| v.copy_(cpu_state[k], non_blocking=True) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Public API used by the worker | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def register_pristine(self, model_chunks, optimizer, signature: LoraSignature) -> None: | ||
| """Capture the freshly-initialised LoRA state as the pristine template. | ||
|
|
||
| Must be called once per worker, after the optimizer state has been | ||
| materialised (e.g. via DistributedOptimizer._init_optimizer_states_with_dummy_values). | ||
| Subsequent registrations will copy this slot to seed new adapters. | ||
| """ | ||
| if self._pristine is not None: | ||
| raise RuntimeError("AdapterStore.register_pristine called twice") | ||
| _expected_lora_param_check(model_chunks) | ||
| self._signature = signature | ||
| self._pristine = self._allocate_empty_slot(model_chunks, optimizer) | ||
| self._snapshot(self._pristine, model_chunks, optimizer) | ||
|
|
||
| @torch.no_grad() | ||
| def create(self, model_id: str, model_chunks, optimizer, signature: LoraSignature) -> None: | ||
| """Register a new adapter slot. | ||
|
|
||
| - First registration: this is also the live adapter; allocate a slot | ||
| but skip the pristine→slot copy because the live state already | ||
| equals pristine. `current_id` becomes `model_id`. | ||
| - Subsequent registrations: allocate slot and copy pristine → slot. | ||
| Live state is unchanged (no swap). The new adapter only becomes | ||
| live when the next `swap_to(model_id)` is issued. | ||
| """ | ||
| if self._signature is None: | ||
| raise RuntimeError("AdapterStore.create called before register_pristine") | ||
| if signature != self._signature: | ||
| raise ValueError( | ||
| f"AdapterStore: lora signature mismatch for '{model_id}'. " | ||
| f"Pristine={self._signature}, requested={signature}. " | ||
| f"Multi-LoRA requires identical (rank, alpha, target_modules, " | ||
| f"lora_type, tp/pp/ep sizes) across all adapters." | ||
| ) | ||
| if model_id in self._slots: | ||
| raise ValueError(f"AdapterStore: adapter '{model_id}' already registered") | ||
|
|
||
| slot = self._allocate_empty_slot(model_chunks, optimizer) | ||
| if self._current_id is None: | ||
| # First adapter: live state IS pristine; slot will be filled on | ||
| # the next snapshot (i.e. swap-away). Treat live as authoritative. | ||
| self._current_id = model_id | ||
| else: | ||
| # Seed the new slot from pristine. | ||
| self._copy_slot(self._pristine, slot) | ||
| self._slots[model_id] = slot | ||
|
|
||
| @torch.no_grad() | ||
| def _copy_slot(self, src: AdapterSlot, dst: AdapterSlot) -> None: | ||
| """CPU→CPU copy used to seed a new slot from the pristine template.""" | ||
| for mc_idx, mc_buffers in enumerate(src.cpu_param_data): | ||
| for buf_idx, t in enumerate(mc_buffers): | ||
| dst.cpu_param_data[mc_idx][buf_idx].copy_(t) | ||
| for mc_idx, mc_grads in enumerate(src.cpu_grad_data): | ||
| for buf_idx, t in enumerate(mc_grads): | ||
| dst.cpu_grad_data[mc_idx][buf_idx].copy_(t) | ||
| for opt_idx, opt_groups in enumerate(src.cpu_main_param): | ||
| for g, group in enumerate(opt_groups): | ||
| for i, t in enumerate(group): | ||
| dst.cpu_main_param[opt_idx][g][i].copy_(t) | ||
| for opt_idx, opt_groups in enumerate(src.cpu_opt_state): | ||
| for g, group in enumerate(opt_groups): | ||
| for i, state in enumerate(group): | ||
| for k, v in state.items(): | ||
| if k in dst.cpu_opt_state[opt_idx][g][i]: | ||
| dst.cpu_opt_state[opt_idx][g][i][k].copy_(v) | ||
|
|
||
| @torch.no_grad() | ||
| def delete(self, model_id: str) -> None: | ||
| """Drop the slot for `model_id`. | ||
|
|
||
| If `model_id` was the current adapter, `current_id` is cleared. The | ||
| live GPU state is left untouched (it now mirrors a deleted adapter); | ||
| the next `swap_to` will overwrite it. | ||
| """ | ||
| if model_id not in self._slots: | ||
| raise KeyError(f"AdapterStore: unknown adapter '{model_id}'") | ||
| del self._slots[model_id] | ||
| if self._current_id == model_id: | ||
| self._current_id = None | ||
|
|
||
| @torch.no_grad() | ||
| def swap_to(self, model_id: str, model_chunks, optimizer) -> None: | ||
| """Make `model_id` the live adapter on this worker. | ||
|
|
||
| Algorithm (all under torch.no_grad): | ||
| 1. dist.barrier(dp_group) | ||
| 2. snapshot live → current's slot (skipped if current_id is None) | ||
| 3. cuda stream sync (D2H done) | ||
| 4. restore target's slot → live | ||
| 5. cuda stream sync (H2D done) | ||
| 6. dist.barrier(dp_group) | ||
|
|
||
| Both barriers are issued here, not by the caller. The trailing barrier | ||
| guarantees all DP ranks agree on the live adapter before the next | ||
| collective. TP/PP/EP groups do not need barriers because the swap is | ||
| identical-shape on all ranks within those groups (LoRA signature is fixed). | ||
| """ | ||
| if model_id not in self._slots: | ||
| raise KeyError(f"AdapterStore: unknown adapter '{model_id}'") | ||
| if self._current_id == model_id: | ||
| return # no-op fast path | ||
|
|
||
| dp_group = mpu.get_data_parallel_group() | ||
| if dist.is_available() and dist.is_initialized(): | ||
| dist.barrier(group=dp_group) | ||
|
|
||
| if self._current_id is not None: | ||
| current_slot = self._slots[self._current_id] | ||
| self._snapshot(current_slot, model_chunks, optimizer) | ||
| torch.cuda.current_stream().synchronize() | ||
|
|
||
| target_slot = self._slots[model_id] | ||
| self._restore(target_slot, model_chunks, optimizer) | ||
| torch.cuda.current_stream().synchronize() | ||
|
|
||
| self._current_id = model_id | ||
|
|
||
| if dist.is_available() and dist.is_initialized(): | ||
| dist.barrier(group=dp_group) | ||
|
|
||
| def clear(self) -> None: | ||
| """Drop all slots (used at full-shutdown reset).""" | ||
| self._slots.clear() | ||
| self._pristine = None | ||
| self._current_id = None | ||
| self._signature = None | ||
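
The slot bookkeeping in this file (pristine template, first adapter adopting the live state, snapshot-then-restore on swap) can be illustrated without Megatron or CUDA. The following is a toy sketch, not the PR's implementation: `ToyStore` is a hypothetical name, plain Python lists stand in for the GPU buffers and pinned-CPU tensors, and barriers, stream syncs, optimizer state, and signature checks are all omitted.

```python
from typing import Dict, List, Optional


class ToyStore:
    """Toy stand-in for AdapterStore; Python lists play the role of tensors."""

    def __init__(self) -> None:
        self._slots: Dict[str, List[float]] = {}
        self._pristine: Optional[List[float]] = None
        self._current_id: Optional[str] = None

    def register_pristine(self, live: List[float]) -> None:
        if self._pristine is not None:
            raise RuntimeError("register_pristine called twice")
        # Capture the freshly initialised state as the template.
        self._pristine = list(live)

    def create(self, model_id: str, live: List[float]) -> None:
        if self._pristine is None:
            raise RuntimeError("create called before register_pristine")
        if model_id in self._slots:
            raise ValueError(f"adapter '{model_id}' already registered")
        if self._current_id is None:
            # First adapter adopts the live state; its slot is only a
            # placeholder until the first swap-away snapshots into it.
            self._slots[model_id] = [0.0] * len(live)
            self._current_id = model_id
        else:
            # Later adapters are seeded from pristine; live state is untouched.
            self._slots[model_id] = list(self._pristine)

    def swap_to(self, model_id: str, live: List[float]) -> None:
        if model_id not in self._slots:
            raise KeyError(f"unknown adapter '{model_id}'")
        if self._current_id == model_id:
            return  # no-op fast path
        if self._current_id is not None:
            # Snapshot: live -> outgoing adapter's slot (in place).
            self._slots[self._current_id][:] = live
        # Restore: target slot -> live (in place, like tensor.copy_()).
        live[:] = self._slots[model_id]
        self._current_id = model_id


live = [0.0, 0.0]              # stands in for the live LoRA buffers
store = ToyStore()
store.register_pristine(live)
store.create("a", live)        # "a" becomes the live adapter
store.create("b", live)        # "b" is seeded from pristine
live[0] = 1.0                  # pretend a training step for "a"
store.swap_to("b", live)
state_b_first = list(live)     # [0.0, 0.0]: "b" starts from pristine
live[1] = 2.0                  # pretend a training step for "b"
store.swap_to("a", live)
state_a = list(live)           # [1.0, 0.0]: "a"'s update survived the swap
store.swap_to("b", live)
state_b = list(live)           # [0.0, 2.0]: "b"'s update survived too
```

The in-place slice assignments mirror the diff's use of `copy_()` into preallocated storage: the live object keeps its identity across swaps, so anything holding a reference to it (in the real code, the DDP buckets and optimizer) sees the new adapter's values.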