
[tinker][megatron] Multi-LoRA Megatron + Tinker API#1617

Merged
erictang000 merged 24 commits into
NovaSky-AI:mainfrom
erictang000:multi_lora
May 8, 2026

Conversation

@erictang000
Collaborator

@erictang000 erictang000 commented May 4, 2026

Adds multi-tenant LoRA training to the SkyRL-Train Megatron Tinker backend. Multiple Tinker clients can create_model against the same server, each with its own LoRA adapter sharing the base model.

Architecture

  • New AdapterStore (per PolicyWorker) holds a per-adapter pinned-CPU snapshot of the LoRA bucket params, DistributedOptimizer fp32-main, Adam state (exp_avg, exp_avg_sq), and grad buffer. A swap is tensor.copy_() between live GPU storage and a slot, bracketed by dp_group barriers + CUDA stream syncs (a minimal sketch follows this list).
  • WorkerDispatch.ensure_active_adapter(role, model_id) fans the swap RPC to all policy actors. Every per-model dispatch entry (forward_backward, optim_step, set_lr, save_checkpoint, load_checkpoint) takes model_id and swaps before executing.
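A minimal sketch of the swap, assuming `live` is the list of GPU bucket tensors and the two slots are matching pinned-CPU mirrors; helper names and wiring are illustrative, not the actual AdapterStore internals:

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def swap_adapter(live, out_slot, in_slot, dp_group):
    dist.barrier(group=dp_group)              # all DP ranks swap in lockstep
    for t, out_t in zip(live, out_slot):
        out_t.copy_(t, non_blocking=True)     # snapshot live -> outgoing slot
    torch.cuda.synchronize()                  # D2H copies must land before reuse
    for t, in_t in zip(live, in_slot):
        t.copy_(in_t, non_blocking=True)      # restore incoming slot -> live
    torch.cuda.synchronize()                  # don't let compute race the H2D
    dist.barrier(group=dp_group)
```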

API surface

  • create_model allows additional policy model_ids when LoRA is active. First call builds the policy + bootstraps the AdapterStore (prime_optimizer_state → register_pristine_adapter → register_adapter); subsequent calls validate (rank, alpha) against the pristine signature and register a new slot (validation sketched after this list).
  • delete_model only does ray.shutdown() on the last adapter; otherwise it's just a slot drop.
  • sample() and save_sampler_checkpoint raise when more than one adapter is registered — multi-tenant inference lands in the RL follow-up (#1621, "[tinker][megatron] Multi-LoRA Megatron + Tinker API RL Training").
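A sketch of the controller-side signature check, assuming the public config reduces to (rank, alpha) as explained in the commit notes below (class and function names here are illustrative):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LoraSignature:
    rank: int
    alpha: float

def validate_new_adapter(requested: LoraSignature, pristine: LoraSignature) -> None:
    # Adapter slots share the base model's bucket layout, so every tenant
    # must match the first adapter's signature exactly.
    if requested != pristine:
        raise ValueError(
            f"LoRA signature mismatch: expected rank={pristine.rank}, "
            f"alpha={pristine.alpha}; got rank={requested.rank}, "
            f"alpha={requested.alpha}"
        )
```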

Files

  • New: skyrl/backends/skyrl_train/workers/megatron/adapter_store.py, tests/tinker/test_multi_lora_megatron.py (GPU-gated integration test).
  • Modified: megatron_worker.py, worker_dispatch.py, skyrl_train_backend.py.

Verification

  • New integration test: two adapters trained in alternation, weights checked for cross-contamination, rank-mismatch rejection, delete-then-train continuity.
  • Manual smoke (two sl_loop clients) — runbook lives in erictang000/skyrl-tinker-dev.

Design doc: same repo, design/multi_lora_design.mdx.

Adds the design write-up for multi-tenant LoRA training on the Megatron
backend exposed via the Tinker API. v1 is training-only; sampling and
adapter-only checkpoint export are deferred. Implementation follows on
the multi_lora branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a design document for multi-tenant LoRA training on the Megatron backend, outlining a strategy to swap adapter weights and optimizer states between GPU and pinned CPU memory. Feedback focuses on technical risks and optimizations: specifically, the potential for gradient corruption during interleaved training steps, the need for a no-op check to avoid redundant synchronization overhead, and concerns regarding host memory pressure from pinned CPU storage. Additionally, it is recommended to remove specific line number references to ensure the documentation remains maintainable.

Comment thread docs/content/docs/tinker/multi_lora_design.mdx Outdated
Comment thread docs/content/docs/tinker/multi_lora_design.mdx Outdated
Comment thread docs/content/docs/tinker/multi_lora_design.mdx Outdated
Comment thread docs/content/docs/tinker/multi_lora_design.mdx Outdated
erictang000 and others added 7 commits May 4, 2026 21:10
New module holding per-adapter pinned-CPU snapshots of the LoRA bucket
params + DistributedOptimizer fp32-main + Adam state on each Megatron
PolicyWorker. swap_to() walks mc.buffers + expert_parallel_buffers and
shard_fp32_from_float16_groups, doing tensor.copy_() in both directions
under torch.no_grad with dp_group barriers + cuda stream syncs.

Also includes a sanity check that every trainable param under DDP
buffers is a LoRA adapter param (named "...adapter..."), so a future
regression that unfreezes a non-LoRA param fails loudly at registration
rather than silently corrupting state.

Wiring into PolicyWorker / WorkerDispatch / SkyRLTrainBackend follows
in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
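The registration-time sanity check described above might look roughly like this (a sketch; the real check walks the DDP buffers rather than `named_parameters`):

```python
def assert_only_lora_trainable(model) -> None:
    # Any trainable param that is not a LoRA adapter param would be mutated
    # by one tenant and silently inherited by the next, so fail loudly here.
    offenders = [
        name for name, p in model.named_parameters()
        if p.requires_grad and "adapter" not in name
    ]
    if offenders:
        raise RuntimeError(
            f"non-LoRA trainable params found at adapter registration: {offenders}"
        )
```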
Adds an `adapter_store: AdapterStore | None` attribute on the policy
worker (allocated only when LoRA is active so the FFT path is unchanged)
plus five Ray-callable methods:

- prime_optimizer_state — calls Megatron's
  DistributedOptimizer._init_optimizer_states_with_dummy_values() so
  exp_avg/exp_avg_sq exist before we snapshot the pristine slot.
- register_pristine_adapter — derives a LoraSignature from the worker's
  own lora config + parallel state, snapshots live state into pristine.
- register_adapter(model_id) — allocates a fresh slot; first call uses
  live as the slot, subsequent calls seed from pristine.
- delete_adapter(model_id) — drops a slot.
- swap_to_adapter(model_id) — local tensor.copy_() between live and slot
  storages plus dp_group barriers.

Plus an adapter_store_state() diagnostic for tests. Orchestration from
the controller follows in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
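A toy model of the registration lifecycle these methods implement (a pure-Python stand-in, not the real class: the real slots are pinned-CPU tensor mirrors, and prime_optimizer_state is omitted since the toy has no optimizer):

```python
import copy

class ToyAdapterStore:
    def __init__(self, live_state: dict):
        self.live = live_state              # stands in for GPU param/opt storage
        self.pristine: dict | None = None
        self.slots: dict[str, dict] = {}
        self.active: str | None = None

    def register_pristine_adapter(self) -> None:
        self.pristine = copy.deepcopy(self.live)   # untouched post-init state

    def register_adapter(self, model_id: str) -> None:
        if not self.slots:                          # first adapter keeps live state
            self.slots[model_id] = copy.deepcopy(self.live)
            self.active = model_id
        else:                                       # later adapters seed from pristine
            self.slots[model_id] = copy.deepcopy(self.pristine)

    def swap_to_adapter(self, model_id: str) -> None:
        if model_id == self.active:
            return                                  # no-op: skip redundant syncs
        self.slots[self.active] = copy.deepcopy(self.live)     # snapshot out
        self.live.clear()
        self.live.update(copy.deepcopy(self.slots[model_id]))  # restore in
        self.active = model_id

    def delete_adapter(self, model_id: str) -> None:
        self.slots.pop(model_id)
```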
WorkerDispatch now exposes:
  - ensure_active_adapter(role, model_id): fans swap_to_adapter to all
    actors of `role`. No-op when model_id is None or the workers don't
    own an AdapterStore (FFT path).
  - prime_adapter_store(role, model_id): one-shot bootstrap for the very
    first create_model — primes optimizer state, registers pristine slot,
    registers the first adapter in one Ray-fanout sequence.
  - register_adapter / delete_adapter: per-call slot maintenance.

forward / forward_backward / forward_backward_from_staged / optim_step /
set_lr / save_checkpoint / load_checkpoint take an optional model_id and
call ensure_active_adapter after _ensure_on_gpu. Default None preserves
single-tenant behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
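The dispatch-side fan-out reduces to something like the following (actor wiring is illustrative; `swap_to_adapter` is the Ray-callable worker method above):

```python
import ray

def ensure_active_adapter(actors, model_id):
    if model_id is None:
        return  # single-tenant default / FFT path: workers own no AdapterStore
    # Fire the swap on every policy actor in parallel, then block until each
    # DP rank has finished its copy_() + barrier sequence.
    ray.get([actor.swap_to_adapter.remote(model_id) for actor in actors])
```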
create_model now allows additional 'policy' models when LoRA is active
and the first policy model has been built. Subsequent calls validate
(rank, alpha, target_modules) match the first adapter's signature, then
register a new slot via WorkerDispatch.register_adapter. FFT (rank=0)
keeps the original single-tenant gate.

_build_policy takes the first model_id and, when LoRA is active, fires
the AdapterStore bootstrap (prime_optimizer_state +
register_pristine_adapter + register_adapter) on every worker before
the colocate_all offload while model + optimizer are still GPU-resident.

delete_model: when more than one model is registered and the role is a
LoRA policy, just drop the slot via dispatch.delete_adapter and pop the
controller-side maps. Last-adapter delete still does the full
ray.shutdown teardown so the runtime can be rebuilt cleanly.

Plumbed model_id through forward / forward_backward / optim_step /
set_lr / save_checkpoint / load_checkpoint dispatch calls so the active
adapter is swapped in on every per-model entry point.

sample() and save_sampler_checkpoint() refuse with a clear error when
more than one LoRA adapter is registered (v1 inference path is single-
tenant; per-adapter sampling is deferred).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
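The delete branch, sketched as a free function over hypothetical controller state (the real method lives on the backend class):

```python
import ray

def delete_model(registered: dict, model_id: str, dispatch) -> None:
    if len(registered) > 1:
        # Cheap path: drop the slot, keep the runtime and other tenants alive.
        dispatch.delete_adapter("policy", model_id)
        registered.pop(model_id)
    else:
        # Last adapter: full teardown so the runtime can be rebuilt cleanly.
        ray.shutdown()
        registered.clear()
```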
End-to-end test that starts a Tinker API server with the SkyRL-Train
Megatron backend and exercises:

  - two LoRA adapters training independently without weight contamination,
  - rank-mismatch on a second create_model raises a clear error,
  - sample()/save_sampler_checkpoint with two adapters raises (v1 scope),
  - delete_model on one adapter leaves the runtime alive and the other
    adapter still trainable.

Auto-skips when no CUDA device is visible. Server lifecycle uses the
same wait_for_condition pattern as test_api.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Manual smoke test (the gate before merging multi_lora): launch a Tinker
API server with the SkyRL-Train Megatron backend, run two
tinker-cookbook sl_loop clients in parallel against it with distinct
model_ids, and verify

  - the policy model is built once (no second `init policy model done`),
  - the second client triggers `Registered additional LoRA adapter`,
  - both clients converge on their respective NLLs without weight
    contamination,
  - GPU memory stays bounded as the second client connects,
  - rank-mismatch / two-adapter sample / single-adapter-delete behave per
    the v1 contract.

Plus troubleshooting notes for the common failure modes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_modules

Tinker's public LoraConfig (skyrl/tinker/types.py:66) exposes only
rank + alpha + seed + train_{attn,mlp,unembed}; it has no
target_modules attribute. The Megatron path reads target_modules from
the server-side cfg.trainer.policy.model.lora.target_modules, which is
fixed at startup, so multi-adapter signature equality reduces to
(rank, alpha). The worker-side AdapterStore still verifies parallel
state equality via its own LoraSignature.

Fixes the AttributeError on the first create_model in the smoke test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@erictang000
Collaborator Author

erictang000 commented May 4, 2026

LoRA A and LoRA B are running concurrently against the same Tinker API server (B200); LoRA Baseline was run on a separate Tinker API server on a different node (H100).

Things are looking uncontaminated, and the curves match almost identically!

[image: loss curves comparing LoRA A, LoRA B, and LoRA Baseline]

erictang000 and others added 2 commits May 4, 2026 22:13
Fixes a cross-tenant grad-corruption race surfaced in review:

  Tick N: batched fwd_bwd = [A.fb, B.fb]
    - sub-batch A: swap_to("A"), zero_grad_buffer, accumulate A's grads
    - sub-batch B: swap_to("B")  <-- only params + opt state swapped
                   zero_grad_buffer  <-- A's grads CLOBBERED here
                   accumulate B's grads
  Tick N+1: A.optim_step
    - swap_to("A") restores A's params + opt state
    - optimizer.step() reads grad_data, which holds B's grads -> B's
      gradient is applied to A's weights, A's actual gradient is lost

The fix is to snapshot/restore `mc.buffers[i].grad_data` (and
`expert_parallel_buffers`) alongside `param_data`. AdapterSlot now
carries a parallel cpu_grad_data list; _allocate_empty_slot,
_snapshot, _restore, and _copy_slot all maintain it. The fp32 grad
accumulator inside DistributedOptimizer.step() is short-lived (created
and consumed within one call) so it doesn't need slot storage.

Memory cost: ~+1x per slot for the grad mirror (bf16, same size as
param buffer). For a 7B base + rank-32 LoRA on a single DP shard this
is on the order of tens of MB, dwarfed by the existing fp32 main +
Adam moments.

Updates the design doc to reflect the four storages per LoRA param and
adds a "Why grads must travel with the slot" section walking through
the race the review caught.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
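The fixed snapshot/restore pair, sketched over Megatron's buffer objects (`buf.param_data` / `buf.grad_data` per the commit message above; the slot layout is illustrative):

```python
import torch

@torch.no_grad()
def snapshot_slot(buffers, cpu_param_data, cpu_grad_data):
    for buf, p_cpu, g_cpu in zip(buffers, cpu_param_data, cpu_grad_data):
        p_cpu.copy_(buf.param_data, non_blocking=True)
        g_cpu.copy_(buf.grad_data, non_blocking=True)  # the fix: grads travel too
    torch.cuda.synchronize()

@torch.no_grad()
def restore_slot(buffers, cpu_param_data, cpu_grad_data):
    for buf, p_cpu, g_cpu in zip(buffers, cpu_param_data, cpu_grad_data):
        buf.param_data.copy_(p_cpu, non_blocking=True)
        buf.grad_data.copy_(g_cpu, non_blocking=True)  # so an interleaved
        # tenant's zero_grad_buffer() can no longer clobber pending grads
    torch.cuda.synchronize()
```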
@erictang000 erictang000 changed the title [docs] Add Multi-LoRA Megatron Tinker design doc (v1) [tinker][megatron] Multi-LoRA Megatron + Tinker API May 4, 2026
Move the design doc and smoke-test runbook to a separate
skyrl-tinker-dev repo (internal development workspace), and drop the
.python-version pin so it doesn't leak into contributors' local repos.

Files removed:
  - docs/content/docs/tinker/multi_lora_design.mdx
  - tests/tinker/test_multi_lora_smoke_two_clients.md
  - .python-version

The integration test (tests/tinker/test_multi_lora_megatron.py) stays
in this repo since it's a real pytest test, not a runbook.
@erictang000
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces multi-LoRA support for the Megatron backend by implementing an AdapterStore to manage LoRA weights and optimizer states on CPU, enabling efficient swapping between adapters on the GPU. The changes include orchestration logic in the dispatch and backend layers, along with new end-to-end tests. Feedback highlights that the OptimizerParamScheduler state is not currently swapped, which could result in incorrect learning rate progress for interleaved tenants. Additionally, the step_count field in AdapterSlot is currently unused and should be either integrated or removed.

Comment thread skyrl/backends/skyrl_train/workers/megatron/adapter_store.py
Comment thread skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Comment thread skyrl/backends/skyrl_train/workers/megatron/adapter_store.py Outdated
erictang000 and others added 12 commits May 7, 2026 21:36
…SDK API

Three corrections to make tests/tinker/test_multi_lora_megatron.py
actually run:

1. --backend skyrl_train -> --backend megatron and
   --extra skyrl_train -> --extra megatron, matching what
   skyrl/tinker/engine.get_backend_classes accepts.

2. /api/v1/server_capabilities -> /api/v1/healthz for the
   wait-for-server probe; the former endpoint is named
   /api/v1/get_server_capabilities and used to throw 404 on
   empty-body GETs.

3. lora_rank=N -> rank=N — the public Tinker SDK uses `rank`.

Also drops the save_weights_for_sampler() calls in
test_two_adapters_train_independently. v1 multi_lora deliberately
gates save_sampler_checkpoint to single-adapter so those calls would
raise; cross-adapter isolation is now verified purely via
loss-improvement-after-swap-back, which is what we actually care about.

Locally: 4 passed in 2m15s on 1x B200.
Two fixes:

1. Restore the v1 single-tenant sampling guards in skyrl_train_backend.py
   that the merge from origin/main accidentally dropped:
     - sample() returns ErrorResponse when LoRA is active and >1 adapter
       is registered.
     - save_sampler_checkpoint raises ValueError under the same condition.
   Multi-tenant inference is the RL follow-up (NovaSky-AI#1621); SFT v1 must
   refuse it explicitly rather than silently corrupting state.
   test_sample_with_two_adapters_errors had been passing in earlier runs
   only by accident — restore the actual guarantee.

2. Add test_seq_vs_alt_per_adapter_step_isolation: min repro of the
   SEQ-vs-ALT divergence flagged in ~/skyrl-seq-vs-alt-repro (against
   Qwen3-4B + PPO on a real pod). Two fresh adapters, ALT-style
   sequence, identical data, asserts pre-update losses match within
   1e-2 at every step. With AdapterStore snapshotting state['step'] per
   slot, this passes on the tiny model — step 0 is bit-exact, step 1
   diverges by 1.7e-4 (three orders of magnitude below the user's
   Qwen3-4B observation). If a future change leaks a global step
   counter across adapters, this test will fail loudly; the assertion
   message points at the SEQ-vs-ALT diagnosis.

Local: 5/5 pass in ~2m on 1x B200.
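The isolation test's shape, with hypothetical client stand-ins (`make_adapter` and the method signatures here are illustrative, not the real Tinker SDK surface):

```python
def test_per_adapter_step_isolation(make_adapter, batch):
    a, b = make_adapter("A"), make_adapter("B")
    for step in range(2):
        loss_a = a.forward_backward(batch)  # pre-update loss, adapter A
        a.optim_step()
        loss_b = b.forward_backward(batch)  # identical data, interleaved
        b.optim_step()
        # Bit-exact: any cross-adapter leak (e.g. a shared Adam step counter)
        # breaks equality immediately.
        assert loss_a == loss_b, f"step {step}: adapter trajectories diverged"
```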
…['step']

Fixes the SEQ-vs-ALT divergence in ~/skyrl-seq-vs-alt-repro on
Qwen3-4B + PPO. Megatron's DistributedOptimizer wraps TE FusedAdam
which tracks the bias-correction step counter at the param-group
level (`optimizer.param_groups[g]['step']`), not in the per-param
state dict. The AdapterStore was only snapshotting per-param state,
so this counter advanced globally across adapters and broke Adam
bias correction every time a different adapter ran an optim_step in
between.

Two changes:

1. AdapterSlot grows a `cpu_param_group_state[opt_idx][pg_idx]` dict
   that mirrors scalar entries (int/float/bool/0-dim tensors) of
   each `optimizer.param_groups[g]`. We deliberately don't capture
   the `params` list or other non-scalar config — `step` is the
   only field that advances per call.

2. _allocate_empty_slot, _snapshot, _restore, and _copy_slot now
   round-trip these scalars alongside the per-param state. Tensor
   step counters (capturable=True path) get copy_() into existing
   storage; Python-int counters (default FusedAdam) get assigned.

Verified end-to-end against a live Qwen3-4B Megatron Tinker server
with the SEQ-vs-ALT repro:

  ALT step 0: A=B=-5.128122329711914  (bit-exact, was bit-exact)
  ALT step 1: A=B=-7.785201549530029  (bit-exact, was Δ=0.117)
  ALT step 2: A=B=-9.004300117492676  (bit-exact, was Δ=0.217)
  SEQ step 0: A=B=-5.128122329711914  (bit-exact, was bit-exact)
  SEQ step 1: A=B=-7.785201549530029  (bit-exact, was Δ=0.140)
  SEQ step 2: A=B=-9.004300117492676  (bit-exact, was Δ=0.027)

  ALT step N == SEQ step N for every (adapter, step) pair —
  per-adapter trajectory is now fully independent of cross-adapter
  scheduling, which is the v1 multi-LoRA correctness contract.

Also includes round-tripping non-tensor (Python int/float) entries
inside `optimizer.state[main_param]` for completeness — most state
is already tensor-valued for FusedAdam (exp_avg / exp_avg_sq), but
this protects against future configs (e.g. capturable=False torch
Adam) that store `state['step']` as an int.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
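A sketch of the scalar round-trip (the real code iterates every wrapped optimizer; a single one is assumed here):

```python
import torch

def snapshot_param_group_scalars(optimizer):
    snapshot = []
    for group in optimizer.param_groups:
        scalars = {}
        for key, val in group.items():
            if isinstance(val, (int, float, bool)):
                scalars[key] = val                    # e.g. FusedAdam's 'step'
            elif torch.is_tensor(val) and val.dim() == 0:
                scalars[key] = val.detach().clone()   # capturable=True tensor step
        snapshot.append(scalars)                      # 'params' list etc. skipped
    return snapshot

def restore_param_group_scalars(optimizer, snapshot):
    for group, scalars in zip(optimizer.param_groups, snapshot):
        for key, val in scalars.items():
            if torch.is_tensor(group.get(key)):
                group[key].copy_(val)                 # copy into existing storage
            else:
                group[key] = val                      # plain Python counters
```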
…riant

Two related changes to the SEQ-vs-ALT regression coverage:

1. test_seq_vs_alt_per_adapter_step_isolation (tiny model): tighten
   from |Δ| < 1e-2 to bit-exact equality. With aca96d0's per-param-
   group state snapshot fix in place, both step 0 and step 1 are
   bit-exact across A and B on the tiny test model. Pre-fix the
   delta was 1.7e-4 — small but non-zero, so the bit-exact bound
   catches the regression even at the tiny scale.

2. New test_seq_vs_alt_qwen3_0_6b_cross_scenario: spins up a separate
   server fixture on Qwen/Qwen3-0.6B (~1.2 GB bf16; fits a single L4)
   and exercises the FULL upstream repro shape from
   ~/skyrl-seq-vs-alt-repro:
     - ALT scenario: 2 fresh adapters, A.0 / B.0 / A.1 / B.1
     - SEQ scenario: 2 more fresh adapters, A.0 / A.1 / B.0 / B.1
     - Asserts within-scenario A == B at every step.
     - Asserts cross-scenario A_ALT step N == A_SEQ step N at every step.

   Why we need this in addition to the tiny-model test: pre-fix on the
   tiny model the divergence was 1.7e-4 (within FP noise); on Qwen3-4B
   it was 0.45 nats. Qwen3-0.6B + cross_entropy is the smallest setup
   that surfaces the bug at a magnitude that's clearly real signal,
   while still fitting on the cheapest single-GPU box that runs
   Megatron LoRA SFT.

Currently bit-exact on multi_lora @ aca96d0 (this branch HEAD):
    ALT 0: A=B=22.81096936017275
    ALT 1: A=B=19.674404971301556
    SEQ 0: A=B=22.81096936017275  (= ALT 0)
    SEQ 1: A=B=19.674404971301556 (= ALT 1)

Local: 6/6 pass.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The bit-exact bound on the tiny test model is enough to catch the
SEQ-vs-ALT regression. With aca96d0's fix in place, both step 0 and
step 1 are |Δ|=0 on trl-internal-testing/tiny-Qwen3ForCausalLM —
pre-fix it was 1.7e-4, so even a tightened bound on the tiny model
catches the param-group `step` counter regression. No need for the
larger Qwen3-0.6B server fixture and its associated runtime cost.

Reverts the additions from ddb87c8 (the qwen3_0_6b_server /
qwen3_0_6b_service_client fixtures, the cross_scenario test, and the
base_model parameter on _api_server) but keeps the tightened bound
on test_seq_vs_alt_per_adapter_step_isolation.

Local: 5/5 pass in ~2m on 1x B200.
The SEQ-vs-ALT framing came from a one-off external repro path; in-tree
comments shouldn't reference it. Renames the per-adapter-isolation test
to test_per_adapter_step_isolation, drops the docstring/print/assert
references to ~/skyrl-seq-vs-alt-repro, and trims the AdapterStore
comments to keep only the technical explanation of why param-group
state has to round-trip per adapter.
Move tests/tinker/test_multi_lora_megatron.py to a new
tests/tinker/skyrl_train/ subdir (with __init__.py) so the SkyRL-Train
backend tests have their own home, separate from JAX-backend tests
that already live alongside conftest.py in tests/tinker/.

Add a new GPU CI workflow that exercises the multi-LoRA Megatron path
end-to-end. Mirrors the existing gpu_skyrl_train_megatron.yaml shape:

  - .github/workflows/gpu_skyrl_train_multi_lora_megatron.yaml triggers
    on pushes that touch the AdapterStore / dispatch / backend / tinker
    plumbing, on PRs labeled run_train_multi_lora_megatron_gpu_ci, and
    on workflow_dispatch.
  - ci/anyscale_gpu_ci_multi_lora_megatron.yaml submits the job to the
    same l4_ci compute config + skyrl-train-megatron image as the
    existing Megatron CI.
  - ci/gpu_ci_run_multi_lora_megatron.sh runs pytest against the new
    test path with --timeout=600 (each test starts a real Tinker server
    and does Megatron init; ~2 min total on L4).
Frame the new CI workflow around the SkyRL-Train Tinker backend rather
than this PR's specific multi-LoRA work — same suite is the right home
for any future Tinker-API-on-Megatron correctness tests we add.

Renames:
  .github/workflows/gpu_skyrl_train_multi_lora_megatron.yaml
    -> .github/workflows/tinker_skyrl_train_backend_gpu.yaml
  ci/anyscale_gpu_ci_multi_lora_megatron.yaml
    -> ci/anyscale_tinker_skyrl_train_backend_gpu.yaml
  ci/gpu_ci_run_multi_lora_megatron.sh
    -> ci/gpu_ci_run_tinker_skyrl_train_backend.sh

And updates the inside of each: workflow name (Tinker-SkyRL-Train-
Backend-GPU), Anyscale job name (tinker-skyrl-train-backend-gpu), job
key (tinker_skyrl_train_backend_gpu_tests), PR label
(run_tinker_skyrl_train_backend_gpu_ci), and the path filters'
self-reference.

Test path stays at tests/tinker/skyrl_train/test_multi_lora_megatron.py
since that's still the file we're invoking.
- Drop dead WorkerDispatch.prime_adapter_store: never called; first-time
  bootstrap is done inline in _build_policy via async_run_ray_method
  (dispatch isn't constructed yet at that point).
- Drop dead AdapterStore.clear: never called; ray.shutdown() on teardown
  destroys workers so per-instance state goes with them.
- Add public AdapterStore.registered_ids() and use it in
  MegatronPolicyWorkerBase.adapter_store_state instead of reaching into
  the private _slots dict.
- Rename _iter_opts -> iter_opts on AdapterStore so megatron_worker can
  import it without leaking private symbols.
- Tighten the test module docstring: drop the stale 8-step "test plan"
  that didn't match the implemented assertions, drop the dead
  docs/.../multi_lora_design.mdx#verification reference, and fix the
  "skyrl_train extras" mention (we use --extra tinker --extra megatron).
  Replace with a per-test summary of what each one actually checks.
- _api_server now yields just proc instead of (proc, log_path); the log
  path was never consumed.
- Reword test_two_adapters_train_independently docstring to drop the
  "multi_lora branch" reference (won't make sense post-merge).

Local: 5/5 still pass.
@erictang000 erictang000 merged commit c680cb7 into NovaSky-AI:main May 8, 2026
4 of 5 checks passed