[tinker][megatron] Multi-LoRA Megatron + Tinker API#1617
Conversation
Adds the design write-up for multi-tenant LoRA training on the Megatron backend exposed via the Tinker API. v1 is training-only; sampling and adapter-only checkpoint export are deferred. Implementation follows on the multi_lora branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Code Review
This pull request introduces a design document for multi-tenant LoRA training on the Megatron backend, outlining a strategy to swap adapter weights and optimizer states between GPU and pinned CPU memory. Feedback focuses on technical risks and optimizations: specifically, the potential for gradient corruption during interleaved training steps, the need for a no-op check to avoid redundant synchronization overhead, and concerns regarding host memory pressure from pinned CPU storage. Additionally, it is recommended to remove specific line number references to ensure the documentation remains maintainable.
New module holding per-adapter pinned-CPU snapshots of the LoRA bucket params + DistributedOptimizer fp32-main + Adam state on each Megatron PolicyWorker. swap_to() walks mc.buffers + expert_parallel_buffers and shard_fp32_from_float16_groups, doing tensor.copy_() in both directions under torch.no_grad with dp_group barriers + cuda stream syncs. Also includes a sanity check that every trainable param under DDP buffers is a LoRA adapter param (named "...adapter..."), so a future regression that unfreezes a non-LoRA param fails loudly at registration rather than silently corrupting state. Wiring into PolicyWorker / WorkerDispatch / SkyRLTrainBackend follows in subsequent commits. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
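The snapshot/restore mechanics described above can be sketched in miniature. This is a hypothetical pure-Python model, not the actual module: plain lists stand in for the GPU and pinned-CPU tensor storages, and in-place slice assignment plays the role of `tensor.copy_()` under `torch.no_grad` (the real implementation also brackets the swap with `dp_group` barriers and CUDA stream syncs, omitted here).

```python
# Hypothetical miniature of the AdapterStore swap described above.
# Lists stand in for tensor storages; dst[:] = src stands in for tensor.copy_().

class AdapterSlot:
    """Pinned-CPU mirror of one adapter's live storages."""
    def __init__(self, size):
        self.cpu_param_data = [0.0] * size   # LoRA bucket params
        self.cpu_fp32_main = [0.0] * size    # DistributedOptimizer fp32 main
        self.cpu_exp_avg = [0.0] * size      # Adam first moment
        self.cpu_exp_avg_sq = [0.0] * size   # Adam second moment

class AdapterStore:
    def __init__(self, live):
        self.live = live          # dict: name -> live "GPU" storage
        self._slots = {}
        self.active = None

    def register(self, model_id):
        size = len(self.live["param_data"])
        self._slots[model_id] = AdapterSlot(size)

    def swap_to(self, model_id):
        if model_id == self.active:
            return                # no-op: skip redundant copies and syncs
        if self.active is not None:
            self._snapshot(self._slots[self.active])   # live -> slot
        self._restore(self._slots[model_id])           # slot -> live
        self.active = model_id
        # real impl: dp_group barrier + CUDA stream sync here

    def _snapshot(self, slot):
        slot.cpu_param_data[:] = self.live["param_data"]
        slot.cpu_fp32_main[:] = self.live["fp32_main"]
        slot.cpu_exp_avg[:] = self.live["exp_avg"]
        slot.cpu_exp_avg_sq[:] = self.live["exp_avg_sq"]

    def _restore(self, slot):
        self.live["param_data"][:] = slot.cpu_param_data
        self.live["fp32_main"][:] = slot.cpu_fp32_main
        self.live["exp_avg"][:] = slot.cpu_exp_avg
        self.live["exp_avg_sq"][:] = slot.cpu_exp_avg_sq
```

The key invariant: every storage that mutates during a training step must travel with the slot, or state leaks between adapters.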
Adds an `adapter_store: AdapterStore | None` attribute on the policy worker (allocated only when LoRA is active, so the FFT path is unchanged) plus five Ray-callable methods:

- prime_optimizer_state — calls Megatron's DistributedOptimizer._init_optimizer_states_with_dummy_values() so exp_avg/exp_avg_sq exist before we snapshot the pristine slot.
- register_pristine_adapter — derives a LoraSignature from the worker's own lora config + parallel state, snapshots live state into pristine.
- register_adapter(model_id) — allocates a fresh slot; the first call uses live as the slot, subsequent calls seed from pristine.
- delete_adapter(model_id) — drops a slot.
- swap_to_adapter(model_id) — local tensor.copy_() between live and slot storages plus dp_group barriers.

Plus an adapter_store_state() diagnostic for tests. Orchestration from the controller follows in subsequent commits. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
WorkerDispatch now exposes:
- ensure_active_adapter(role, model_id): fans swap_to_adapter to all
actors of `role`. No-op when model_id is None or the workers don't
own an AdapterStore (FFT path).
- prime_adapter_store(role, model_id): one-shot bootstrap for the very
first create_model — primes optimizer state, registers pristine slot,
registers the first adapter in one Ray-fanout sequence.
- register_adapter / delete_adapter: per-call slot maintenance.
forward / forward_backward / forward_backward_from_staged / optim_step /
set_lr / save_checkpoint / load_checkpoint take an optional model_id and
call ensure_active_adapter after _ensure_on_gpu. Default None preserves
single-tenant behavior.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
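The fan-out and no-op behavior described above can be sketched as follows. This is a hypothetical illustration: the method names follow the commit message, but a plain loop over worker stubs stands in for the Ray fan-out, and the stub attributes are invented for the sketch.

```python
# Hypothetical sketch of WorkerDispatch.ensure_active_adapter: fan the swap
# RPC to every actor of the role; no-op on the single-tenant / FFT path.

class WorkerDispatch:
    def __init__(self, actors_by_role):
        self.actors_by_role = actors_by_role   # role -> list of worker stubs

    def ensure_active_adapter(self, role, model_id):
        if model_id is None:
            return                             # default: single-tenant, no-op
        workers = self.actors_by_role.get(role, [])
        if not workers or workers[0].adapter_store is None:
            return                             # FFT path: no AdapterStore
        for w in workers:                      # real impl: Ray fan-out + gather
            w.swap_to_adapter(model_id)

    def forward_backward(self, role, batch, model_id=None):
        # per-model entry points swap the active adapter in before executing
        self.ensure_active_adapter(role, model_id)
        return [w.forward_backward(batch) for w in self.actors_by_role[role]]
```

The `model_id=None` default is what preserves the existing single-tenant behavior without touching FFT callers.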
create_model now allows additional 'policy' models when LoRA is active and the first policy model has been built. Subsequent calls validate that (rank, alpha, target_modules) match the first adapter's signature, then register a new slot via WorkerDispatch.register_adapter. FFT (rank=0) keeps the original single-tenant gate.

_build_policy takes the first model_id and, when LoRA is active, fires the AdapterStore bootstrap (prime_optimizer_state + register_pristine_adapter + register_adapter) on every worker before the colocate_all offload, while model + optimizer are still GPU-resident.

delete_model: when more than one model is registered and the role is a LoRA policy, just drop the slot via dispatch.delete_adapter and pop the controller-side maps. A last-adapter delete still does the full ray.shutdown teardown so the runtime can be rebuilt cleanly.

Plumbed model_id through the forward / forward_backward / optim_step / set_lr / save_checkpoint / load_checkpoint dispatch calls so the active adapter is swapped in on every per-model entry point.

sample() and save_sampler_checkpoint() refuse with a clear error when more than one LoRA adapter is registered (the v1 inference path is single-tenant; per-adapter sampling is deferred). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
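The signature gate on additional create_model calls can be sketched as a small helper. This is a hypothetical illustration (the function name and dict shape are invented for the sketch); the point is that the first adapter fixes the signature and later requests must match it or fail loudly.

```python
# Hypothetical sketch of the multi-adapter create_model gate: the first
# LoRA policy fixes the signature; later calls must match (rank, alpha).

def validate_additional_adapter(first_sig, new_cfg):
    """first_sig: (rank, alpha) of the first adapter; new_cfg: request dict."""
    if first_sig[0] == 0:
        # FFT (rank=0) keeps the original single-tenant gate
        raise ValueError("FFT (rank=0) runtime is single-tenant")
    new_sig = (new_cfg["rank"], new_cfg["alpha"])
    if new_sig != first_sig:
        raise ValueError(
            f"adapter signature mismatch: existing {first_sig}, "
            f"requested {new_sig}; all adapters must share (rank, alpha)"
        )
    return new_sig
```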
End-to-end test that starts a Tinker API server with the SkyRL-Train
Megatron backend and exercises:
- two LoRA adapters training independently without weight contamination,
- rank-mismatch on a second create_model raises a clear error,
- sample()/save_sampler_checkpoint with two adapters raises (v1 scope),
- delete_model on one adapter leaves the runtime alive and the other
adapter still trainable.
Auto-skips when no CUDA device is visible. Server lifecycle uses the
same wait_for_condition pattern as test_api.py.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
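The wait_for_condition server-lifecycle pattern mentioned above is a plain polling loop. A minimal sketch follows; the helper name matches the commit message, but the exact signature in test_api.py is an assumption.

```python
import time

# Minimal sketch of the wait_for_condition polling pattern used for the
# server lifecycle; the signature in test_api.py may differ.

def wait_for_condition(predicate, timeout=60.0, interval=0.5):
    """Poll `predicate` until it returns truthy or `timeout` elapses."""
    deadline = time.monotonic() + timeout
    last_err = None
    while time.monotonic() < deadline:
        try:
            if predicate():
                return True
        except Exception as err:      # server not up yet; keep polling
            last_err = err
        time.sleep(interval)
    raise TimeoutError(
        f"condition not met within {timeout}s (last error: {last_err!r})"
    )
```

In the test this would wrap an HTTP health probe against the freshly launched Tinker server process.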
Manual smoke test (the gate before merging multi_lora): launch a Tinker
API server with the SkyRL-Train Megatron backend, run two
tinker-cookbook sl_loop clients in parallel against it with distinct
model_ids, and verify
- the policy model is built once (no second `init policy model done`),
- the second client triggers `Registered additional LoRA adapter`,
- both clients converge on their respective NLLs without weight
contamination,
- GPU memory stays bounded as the second client connects,
- rank-mismatch / two-adapter sample / single-adapter-delete behave per
the v1 contract.
Plus troubleshooting notes for the common failure modes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_modules
Tinker's public LoraConfig (skyrl/tinker/types.py:66) exposes only
rank + alpha + seed + train_{attn,mlp,unembed}; it has no
target_modules attribute. The Megatron path reads target_modules from
the server-side cfg.trainer.policy.model.lora.target_modules, which is
fixed at startup, so multi-adapter signature equality reduces to
(rank, alpha). The worker-side AdapterStore still verifies parallel
state equality via its own LoraSignature.
Fixes the AttributeError on the first create_model in the smoke test.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
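The two signature levels described above can be sketched with frozen dataclasses. This is a hypothetical illustration (field names are assumptions): the controller compares only what the public LoraConfig exposes, while the worker-side LoraSignature also pins target_modules (fixed server-side at startup) and the parallel state it was built with.

```python
from dataclasses import dataclass

# Hypothetical sketch: controller-visible signature vs. worker-side signature.

@dataclass(frozen=True)
class ControllerSignature:
    # the public Tinker LoraConfig exposes no target_modules, so
    # controller-side equality reduces to (rank, alpha)
    rank: int
    alpha: float

@dataclass(frozen=True)
class LoraSignature:
    rank: int
    alpha: float
    target_modules: tuple   # fixed at server startup, same for all adapters
    tp_size: int            # parallel state must match across adapters
    pp_size: int
```

Frozen dataclasses give value-based `==`, so signature checks are one comparison.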
Fixes a cross-tenant grad-corruption race surfaced in review:
Tick N: batched fwd_bwd = [A.fb, B.fb]
- sub-batch A: swap_to("A"), zero_grad_buffer, accumulate A's grads
- sub-batch B: swap_to("B") <-- only params + opt state swapped
zero_grad_buffer <-- A's grads CLOBBERED here
accumulate B's grads
Tick N+1: A.optim_step
- swap_to("A") restores A's params + opt state
- optimizer.step() reads grad_data, which holds B's grads -> B's
gradient is applied to A's weights, A's actual gradient is lost
The fix is to snapshot/restore `mc.buffers[i].grad_data` (and
`expert_parallel_buffers`) alongside `param_data`. AdapterSlot now
carries a parallel cpu_grad_data list; _allocate_empty_slot,
_snapshot, _restore, and _copy_slot all maintain it. The fp32 grad
accumulator inside DistributedOptimizer.step() is short-lived (created
and consumed within one call) so it doesn't need slot storage.
Memory cost: ~+1x per slot for the grad mirror (bf16, same size as
param buffer). For a 7B base + rank-32 LoRA on a single DP shard this
is on the order of tens of MB, dwarfed by the existing fp32 main +
Adam moments.
Updates the design doc to reflect the four storages per LoRA param and
adds a "Why grads must travel with the slot" section walking through
the race the review caught.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
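The "grads must travel with the slot" fix can be demonstrated on a miniature model of the swap. This is a hypothetical sketch (lists stand in for `grad_data` / `param_data` buffers): because each slot mirrors the grad buffer, sub-batch B's zero_grad_buffer can no longer clobber A's accumulated gradient, and A's optim_step on the next tick sees A's own gradient.

```python
# Hypothetical demo of the interleaving race fixed above: the grad buffer
# travels with the slot, so it survives another adapter's zero_grad_buffer.

class Slot:
    def __init__(self, size):
        self.cpu_param_data = [0.0] * size
        self.cpu_grad_data = [0.0] * size    # the fix: per-slot grad mirror

def swap_to(live, slots, active, target):
    if active is not None:                   # snapshot out, grads included
        slots[active].cpu_param_data[:] = live["param_data"]
        slots[active].cpu_grad_data[:] = live["grad_data"]
    live["param_data"][:] = slots[target].cpu_param_data
    live["grad_data"][:] = slots[target].cpu_grad_data
    return target

live = {"param_data": [0.0], "grad_data": [0.0]}
slots = {"A": Slot(1), "B": Slot(1)}

# Tick N: batched fwd_bwd = [A.fb, B.fb]
active = swap_to(live, slots, None, "A")
live["grad_data"][0] = 0.0                   # zero_grad_buffer
live["grad_data"][0] += 1.5                  # accumulate A's grads
active = swap_to(live, slots, active, "B")   # A's grads leave with A's slot
live["grad_data"][0] = 0.0                   # zero_grad_buffer: B only
live["grad_data"][0] += -2.0                 # accumulate B's grads

# Tick N+1: A.optim_step reads A's own gradient, not B's
active = swap_to(live, slots, active, "A")
assert live["grad_data"][0] == 1.5
```

Without the `cpu_grad_data` mirror, the final assert would see -2.0: B's gradient applied to A's weights, exactly the race the review caught.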
Move the design doc and smoke-test runbook to a separate skyrl-tinker-dev repo (internal development workspace), and drop the .python-version pin so it doesn't leak into contributors' local repos.

Files removed:
- docs/content/docs/tinker/multi_lora_design.mdx
- tests/tinker/test_multi_lora_smoke_two_clients.md
- .python-version

The integration test (tests/tinker/test_multi_lora_megatron.py) stays in this repo since it's a real pytest test, not a runbook.
/gemini review
Code Review
This pull request introduces multi-LoRA support for the Megatron backend by implementing an AdapterStore to manage LoRA weights and optimizer states on CPU, enabling efficient swapping between adapters on the GPU. The changes include orchestration logic in the dispatch and backend layers, along with new end-to-end tests. Feedback highlights that the OptimizerParamScheduler state is not currently swapped, which could result in incorrect learning rate progress for interleaved tenants. Additionally, the step_count field in AdapterSlot is currently unused and should be either integrated or removed.
…SDK API

Three corrections to make tests/tinker/test_multi_lora_megatron.py actually run:

1. --backend skyrl_train -> --backend megatron and --extra skyrl_train -> --extra megatron, matching what skyrl/tinker/engine.get_backend_classes accepts.
2. /api/v1/server_capabilities -> /api/v1/healthz for the wait-for-server probe; the former endpoint is named /api/v1/get_server_capabilities and used to throw 404 on empty-body GETs.
3. lora_rank=N -> rank=N — the public Tinker SDK uses `rank`.

Also drops the save_weights_for_sampler() calls in test_two_adapters_train_independently. v1 multi_lora deliberately gates save_sampler_checkpoint to single-adapter, so those calls would raise; cross-adapter isolation is now verified purely via loss improvement after swap-back, which is what we actually care about.

Locally: 4 passed in 2m15s on 1x B200.
Two fixes:
1. Restore the v1 single-tenant sampling guards in skyrl_train_backend.py
that the merge from origin/main accidentally dropped:
- sample() returns ErrorResponse when LoRA is active and >1 adapter
is registered.
- save_sampler_checkpoint raises ValueError under the same condition.
Multi-tenant inference is the RL follow-up (NovaSky-AI#1621); SFT v1 must
refuse it explicitly rather than silently corrupting state.
test_sample_with_two_adapters_errors had been passing in earlier runs
only by accident — restore the actual guarantee.
2. Add test_seq_vs_alt_per_adapter_step_isolation: min repro of the
SEQ-vs-ALT divergence flagged in ~/skyrl-seq-vs-alt-repro (against
Qwen3-4B + PPO on a real pod). Two fresh adapters, ALT-style
sequence, identical data, asserts pre-update losses match within
1e-2 at every step. With AdapterStore snapshotting state['step'] per
slot, this passes on the tiny model — step 0 is bit-exact, step 1
diverges by 1.7e-4 (three orders of magnitude below the user's
Qwen3-4B observation). If a future change leaks a global step
counter across adapters, this test will fail loudly; the assertion
message points at the SEQ-vs-ALT diagnosis.
Local: 5/5 pass in ~2m on 1x B200.
…['step']

Fixes the SEQ-vs-ALT divergence in ~/skyrl-seq-vs-alt-repro on Qwen3-4B + PPO.

Megatron's DistributedOptimizer wraps TE FusedAdam, which tracks the bias-correction step counter at the param-group level (`optimizer.param_groups[g]['step']`), not in the per-param state dict. The AdapterStore was only snapshotting per-param state, so this counter advanced globally across adapters and broke Adam bias correction every time a different adapter ran an optim_step in between.

Two changes:

1. AdapterSlot grows a `cpu_param_group_state[opt_idx][pg_idx]` dict that mirrors scalar entries (int/float/bool/0-dim tensors) of each `optimizer.param_groups[g]`. We deliberately don't capture the `params` list or other non-scalar config — `step` is the only field that advances per call.
2. _allocate_empty_slot, _snapshot, _restore, and _copy_slot now round-trip these scalars alongside the per-param state. Tensor step counters (capturable=True path) get copy_() into existing storage; Python-int counters (default FusedAdam) get assigned.

Verified end-to-end against a live Qwen3-4B Megatron Tinker server with the SEQ-vs-ALT repro:

ALT step 0: A=B=-5.128122329711914 (bit-exact, was bit-exact)
ALT step 1: A=B=-7.785201549530029 (bit-exact, was Δ=0.117)
ALT step 2: A=B=-9.004300117492676 (bit-exact, was Δ=0.217)
SEQ step 0: A=B=-5.128122329711914 (bit-exact, was bit-exact)
SEQ step 1: A=B=-7.785201549530029 (bit-exact, was Δ=0.140)
SEQ step 2: A=B=-9.004300117492676 (bit-exact, was Δ=0.027)

ALT step N == SEQ step N for every (adapter, step) pair — the per-adapter trajectory is now fully independent of cross-adapter scheduling, which is the v1 multi-LoRA correctness contract.

Also round-trips non-tensor (Python int/float) entries inside `optimizer.state[main_param]` for completeness — most state is already tensor-valued for FusedAdam (exp_avg / exp_avg_sq), but this protects against future configs (e.g. capturable=False torch Adam) that store `state['step']` as an int.
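The scalar mirroring in change 1 can be sketched as a pair of helpers. This is a hypothetical illustration (function names invented; the 0-dim tensor `copy_()` path for capturable optimizers is omitted): only plain int/float/bool entries are captured per adapter, so `params` and other non-scalar config are left alone.

```python
# Hypothetical sketch of the param-group scalar mirror described above:
# round-trip only int/float/bool entries (notably FusedAdam's per-group
# 'step' counter); the 'params' list and other non-scalar config stay put.
# The real fix also copy_()s 0-dim tensor counters into existing storage.

SCALAR_TYPES = (int, float, bool)

def snapshot_param_group_scalars(param_groups):
    return [
        {k: v for k, v in pg.items() if isinstance(v, SCALAR_TYPES)}
        for pg in param_groups
    ]

def restore_param_group_scalars(param_groups, snapshot):
    for pg, saved in zip(param_groups, snapshot):
        pg.update(saved)   # reassign Python-int counters in place
```

With this round-trip in each slot, another adapter's optim_step can no longer advance the bias-correction counter that this adapter will read on its next step.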
Co-Authored-By: Eric Tang <erictang000@gmail.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…riant

Two related changes to the SEQ-vs-ALT regression coverage:

1. test_seq_vs_alt_per_adapter_step_isolation (tiny model): tighten from |Δ| < 1e-2 to bit-exact equality. With aca96d0's per-param-group state snapshot fix in place, both step 0 and step 1 are bit-exact across A and B on the tiny test model. Pre-fix the delta was 1.7e-4 — small but non-zero, so the bit-exact bound catches the regression even at the tiny scale.
2. New test_seq_vs_alt_qwen3_0_6b_cross_scenario: spins up a separate server fixture on Qwen/Qwen3-0.6B (~1.2 GB bf16; fits a single L4) and exercises the FULL upstream repro shape from ~/skyrl-seq-vs-alt-repro:
   - ALT scenario: 2 fresh adapters, A.0 / B.0 / A.1 / B.1
   - SEQ scenario: 2 more fresh adapters, A.0 / A.1 / B.0 / B.1
   - Asserts within-scenario A == B at every step.
   - Asserts cross-scenario A_ALT step N == A_SEQ step N at every step.

Why we need this in addition to the tiny-model test: pre-fix on the tiny model the divergence was 1.7e-4 (within FP noise); on Qwen3-4B it was 0.45 nats. Qwen3-0.6B + cross_entropy is the smallest setup that surfaces the bug at a magnitude that's clearly real signal, while still fitting on the cheapest single-GPU box that runs Megatron LoRA SFT.

Currently bit-exact on multi_lora @ aca96d0 (this branch HEAD):

ALT 0: A=B=22.81096936017275
ALT 1: A=B=19.674404971301556
SEQ 0: A=B=22.81096936017275 (= ALT 0)
SEQ 1: A=B=19.674404971301556 (= ALT 1)

Local: 6/6 pass. Co-Authored-By: Eric Tang <erictang000@gmail.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The bit-exact bound on the tiny test model is enough to catch the SEQ-vs-ALT regression. With aca96d0's fix in place, both step 0 and step 1 are |Δ|=0 on trl-internal-testing/tiny-Qwen3ForCausalLM — pre-fix it was 1.7e-4, so even a tightened bound on the tiny model catches the param-group `step` counter regression. No need for the larger Qwen3-0.6B server fixture and its associated runtime cost. Reverts the additions from ddb87c8 (the qwen3_0_6b_server / qwen3_0_6b_service_client fixtures, the cross_scenario test, and the base_model parameter on _api_server) but keeps the tightened bound on test_seq_vs_alt_per_adapter_step_isolation. Local: 5/5 pass in ~2m on 1x B200.
The SEQ-vs-ALT framing came from a one-off external repro path; in-tree comments shouldn't reference it. Renames the per-adapter-isolation test to test_per_adapter_step_isolation, drops the docstring/print/assert references to ~/skyrl-seq-vs-alt-repro, and trims the AdapterStore comments to keep only the technical explanation of why param-group state has to round-trip per adapter.
Move tests/tinker/test_multi_lora_megatron.py to a new
tests/tinker/skyrl_train/ subdir (with __init__.py) so the SkyRL-Train
backend tests have their own home, separate from JAX-backend tests
that already live alongside conftest.py in tests/tinker/.
Add a new GPU CI workflow that exercises the multi-LoRA Megatron path
end-to-end. Mirrors the existing gpu_skyrl_train_megatron.yaml shape:
- .github/workflows/gpu_skyrl_train_multi_lora_megatron.yaml triggers
on pushes that touch the AdapterStore / dispatch / backend / tinker
plumbing, on PRs labeled run_train_multi_lora_megatron_gpu_ci, and
on workflow_dispatch.
- ci/anyscale_gpu_ci_multi_lora_megatron.yaml submits the job to the
same l4_ci compute config + skyrl-train-megatron image as the
existing Megatron CI.
- ci/gpu_ci_run_multi_lora_megatron.sh runs pytest against the new
test path with --timeout=600 (each test starts a real Tinker server
and does Megatron init; ~2 min total on L4).
Frame the new CI workflow around the SkyRL-Train Tinker backend rather
than this PR's specific multi-LoRA work — same suite is the right home
for any future Tinker-API-on-Megatron correctness tests we add.
Renames:
.github/workflows/gpu_skyrl_train_multi_lora_megatron.yaml
-> .github/workflows/tinker_skyrl_train_backend_gpu.yaml
ci/anyscale_gpu_ci_multi_lora_megatron.yaml
-> ci/anyscale_tinker_skyrl_train_backend_gpu.yaml
ci/gpu_ci_run_multi_lora_megatron.sh
-> ci/gpu_ci_run_tinker_skyrl_train_backend.sh
And updates the inside of each: workflow name (Tinker-SkyRL-Train-
Backend-GPU), Anyscale job name (tinker-skyrl-train-backend-gpu), job
key (tinker_skyrl_train_backend_gpu_tests), PR label
(run_tinker_skyrl_train_backend_gpu_ci), and the path filters'
self-reference.
Test path stays at tests/tinker/skyrl_train/test_multi_lora_megatron.py
since that's still the file we're invoking.
- Drop dead WorkerDispatch.prime_adapter_store: never called; first-time bootstrap is done inline in _build_policy via async_run_ray_method (dispatch isn't constructed yet at that point).
- Drop dead AdapterStore.clear: never called; ray.shutdown() on teardown destroys workers, so per-instance state goes with them.
- Add public AdapterStore.registered_ids() and use it in MegatronPolicyWorkerBase.adapter_store_state instead of reaching into the private _slots dict.
- Rename _iter_opts -> iter_opts on AdapterStore so megatron_worker can import it without leaking private symbols.
- Tighten the test module docstring: drop the stale 8-step "test plan" that didn't match the implemented assertions, drop the dead docs/.../multi_lora_design.mdx#verification reference, and fix the "skyrl_train extras" mention (we use --extra tinker --extra megatron). Replace with a per-test summary of what each one actually checks.
- _api_server now yields just proc instead of (proc, log_path); the log path was never consumed.
- Reword the test_two_adapters_train_independently docstring to drop the "multi_lora branch" reference (it won't make sense post-merge).

Local: 5/5 still pass.

Adds multi-tenant LoRA training to the SkyRL-Train Megatron Tinker backend. Multiple Tinker clients can `create_model` against the same server, each with its own LoRA adapter sharing the base model.

**Architecture**

`AdapterStore` (per `PolicyWorker`) holds a per-adapter pinned-CPU snapshot of the LoRA bucket params, the `DistributedOptimizer` fp32-main, Adam state (`exp_avg`, `exp_avg_sq`), and the grad buffer. A swap is `tensor.copy_()` between live GPU storage and a slot, bracketed by `dp_group` barriers + CUDA stream syncs. `WorkerDispatch.ensure_active_adapter(role, model_id)` fans the swap RPC to all policy actors. Every per-model dispatch entry (`forward_backward`, `optim_step`, `set_lr`, `save_checkpoint`, `load_checkpoint`) takes `model_id` and swaps before executing.

**API surface**

`create_model` allows additional policy `model_id`s when LoRA is active. The first call builds the policy + bootstraps the AdapterStore (`prime_optimizer_state` → `register_pristine_adapter` → `register_adapter`); subsequent calls validate `(rank, alpha)` against the pristine signature and register a new slot. `delete_model` only does `ray.shutdown()` on the last adapter; otherwise it's just a slot drop. `sample()` and `save_sampler_checkpoint` raise when more than one adapter is registered — multi-tenant inference lands in the RL follow-up ([tinker][megatron] Multi-LoRA Megatron + Tinker API RL Training #1621).

**Files**

`skyrl/backends/skyrl_train/workers/megatron/adapter_store.py`, `tests/tinker/test_multi_lora_megatron.py` (GPU-gated integration test), plus changes in `megatron_worker.py`, `worker_dispatch.py`, `skyrl_train_backend.py`.

**Verification**

Covers delete-then-train continuity, plus the two-client smoke run (parallel `sl_loop` clients) — the runbook lives in `erictang000/skyrl-tinker-dev`.

Design doc: same repo, `design/multi_lora_design.mdx`.