
[tinker] Separate out SkyRLTrainBackend sample requests into EXTERNAL request path #1634

Closed
erictang000 wants to merge 31 commits into NovaSky-AI:main from erictang000:async_sample_routing

Conversation

@erictang000
Collaborator

Allows for speedups with multi-LoRA!

hao-aaron and others added 30 commits April 28, 2026 01:21
x
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Made-with: Cursor

# Conflicts:
#	tests/backends/skyrl_train/gpu/utils.py
Adds the design write-up for multi-tenant LoRA training on the Megatron
backend exposed via the Tinker API. v1 is training-only; sampling and
adapter-only checkpoint export are deferred. Implementation follows on
the multi_lora branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
New module holding per-adapter pinned-CPU snapshots of the LoRA bucket
params + DistributedOptimizer fp32-main + Adam state on each Megatron
PolicyWorker. swap_to() walks mc.buffers + expert_parallel_buffers and
shard_fp32_from_float16_groups, doing tensor.copy_() in both directions
under torch.no_grad with dp_group barriers + cuda stream syncs.
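
A minimal sketch of that snapshot/restore pattern, assuming plain tensor lists; the class and method names below are illustrative, and the real AdapterStore walks the Megatron buffer structures named above and coordinates dp_group barriers rather than a bare torch.cuda.synchronize():

```python
import torch

class AdapterSlotSketch:
    def __init__(self, live_tensors):
        # One pinned-CPU mirror per live GPU tensor, so swaps are plain host<->device copies.
        self.cpu_copies = [
            torch.empty(t.shape, dtype=t.dtype, device="cpu",
                        pin_memory=torch.cuda.is_available())
            for t in live_tensors
        ]

    @torch.no_grad()
    def snapshot(self, live_tensors):
        for cpu, live in zip(self.cpu_copies, live_tensors):
            cpu.copy_(live, non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # stand-in for the CUDA stream syncs mentioned above

    @torch.no_grad()
    def restore(self, live_tensors):
        for cpu, live in zip(self.cpu_copies, live_tensors):
            live.copy_(cpu, non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
```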

Also includes a sanity check that every trainable param under DDP
buffers is a LoRA adapter param (named "...adapter..."), so a future
regression that unfreezes a non-LoRA param fails loudly at registration
rather than silently corrupting state.

Wiring into PolicyWorker / WorkerDispatch / SkyRLTrainBackend follows
in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an `adapter_store: AdapterStore | None` attribute on the policy
worker (allocated only when LoRA is active so the FFT path is unchanged)
plus five Ray-callable methods:

- prime_optimizer_state — calls Megatron's
  DistributedOptimizer._init_optimizer_states_with_dummy_values() so
  exp_avg/exp_avg_sq exist before we snapshot the pristine slot.
- register_pristine_adapter — derives a LoraSignature from the worker's
  own lora config + parallel state, snapshots live state into pristine.
- register_adapter(model_id) — allocates a fresh slot; first call uses
  live as the slot, subsequent calls seed from pristine.
- delete_adapter(model_id) — drops a slot.
- swap_to_adapter(model_id) — local tensor.copy_() between live and slot
  storages plus dp_group barriers.

Plus an adapter_store_state() diagnostic for tests. Orchestration from
the controller follows in subsequent commits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
WorkerDispatch now exposes:
  - ensure_active_adapter(role, model_id): fans swap_to_adapter to all
    actors of `role`. No-op when model_id is None or the workers don't
    own an AdapterStore (FFT path).
  - prime_adapter_store(role, model_id): one-shot bootstrap for the very
    first create_model — primes optimizer state, registers pristine slot,
    registers the first adapter in one Ray-fanout sequence.
  - register_adapter / delete_adapter: per-call slot maintenance.

forward / forward_backward / forward_backward_from_staged / optim_step /
set_lr / save_checkpoint / load_checkpoint take an optional model_id and
call ensure_active_adapter after _ensure_on_gpu. Default None preserves
single-tenant behavior.
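
A hedged sketch of the dispatch-side fan-out; the actor-handle list is illustrative (the real WorkerDispatch tracks actors per role), but swap_to_adapter is the worker method described above:

```python
import ray

def ensure_active_adapter(actors, model_id):
    """Make `model_id` the live adapter on every worker actor; no-op for FFT / None."""
    if model_id is None or not actors:
        return
    # Fan out to every DP rank and block until all have the adapter live.
    ray.get([actor.swap_to_adapter.remote(model_id) for actor in actors])
```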

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
create_model now allows additional 'policy' models when LoRA is active
and the first policy model has been built. Subsequent calls validate
(rank, alpha, target_modules) match the first adapter's signature, then
register a new slot via WorkerDispatch.register_adapter. FFT (rank=0)
keeps the original single-tenant gate.

_build_policy takes the first model_id and, when LoRA is active, fires
the AdapterStore bootstrap (prime_optimizer_state +
register_pristine_adapter + register_adapter) on every worker before
the colocate_all offload while model + optimizer are still GPU-resident.

delete_model: when more than one model is registered and the role is a
LoRA policy, just drop the slot via dispatch.delete_adapter and pop the
controller-side maps. Last-adapter delete still does the full
ray.shutdown teardown so the runtime can be rebuilt cleanly.

Plumbed model_id through forward / forward_backward / optim_step /
set_lr / save_checkpoint / load_checkpoint dispatch calls so the active
adapter is swapped in on every per-model entry point.

sample() and save_sampler_checkpoint() refuse with a clear error when
more than one LoRA adapter is registered (v1 inference path is single-
tenant; per-adapter sampling is deferred).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
End-to-end test that starts a Tinker API server with the SkyRL-Train
Megatron backend and exercises:

  - two LoRA adapters training independently without weight contamination,
  - rank-mismatch on a second create_model raises a clear error,
  - sample()/save_sampler_checkpoint with two adapters raises (v1 scope),
  - delete_model on one adapter leaves the runtime alive and the other
    adapter still trainable.

Auto-skips when no CUDA device is visible. Server lifecycle uses the
same wait_for_condition pattern as test_api.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Manual smoke test (the gate before merging multi_lora): launch a Tinker
API server with the SkyRL-Train Megatron backend, run two
tinker-cookbook sl_loop clients in parallel against it with distinct
model_ids, and verify

  - the policy model is built once (no second `init policy model done`),
  - the second client triggers `Registered additional LoRA adapter`,
  - both clients converge on their respective NLLs without weight
    contamination,
  - GPU memory stays bounded as the second client connects,
  - rank-mismatch / two-adapter sample / single-adapter-delete behave per
    the v1 contract.

Plus troubleshooting notes for the common failure modes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_modules

Tinker's public LoraConfig (skyrl/tinker/types.py:66) exposes only
rank + alpha + seed + train_{attn,mlp,unembed}; it has no
target_modules attribute. The Megatron path reads target_modules from
the server-side cfg.trainer.policy.model.lora.target_modules, which is
fixed at startup, so multi-adapter signature equality reduces to
(rank, alpha). The worker-side AdapterStore still verifies parallel
state equality via its own LoraSignature.
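
A sketch of the reduced controller-side comparison; the helper names are invented here, only the (rank, alpha) reduction comes from this commit:

```python
def controller_side_signature(lora_config):
    # target_modules is fixed server-side at startup, so it drops out of the
    # per-adapter equality check.
    return (lora_config.rank, lora_config.alpha)

def check_new_adapter(first_config, new_config):
    if controller_side_signature(new_config) != controller_side_signature(first_config):
        raise ValueError(
            "all LoRA adapters must share (rank, alpha) with the first policy model"
        )
```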

Fixes the AttributeError on the first create_model in the smoke test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Fixes a cross-tenant grad-corruption race surfaced in review:

  Tick N: batched fwd_bwd = [A.fb, B.fb]
    - sub-batch A: swap_to("A"), zero_grad_buffer, accumulate A's grads
    - sub-batch B: swap_to("B")  <-- only params + opt state swapped
                   zero_grad_buffer  <-- A's grads CLOBBERED here
                   accumulate B's grads
  Tick N+1: A.optim_step
    - swap_to("A") restores A's params + opt state
    - optimizer.step() reads grad_data, which holds B's grads -> B's
      gradient is applied to A's weights, A's actual gradient is lost

The fix is to snapshot/restore `mc.buffers[i].grad_data` (and
`expert_parallel_buffers`) alongside `param_data`. AdapterSlot now
carries a parallel cpu_grad_data list; _allocate_empty_slot,
_snapshot, _restore, and _copy_slot all maintain it. The fp32 grad
accumulator inside DistributedOptimizer.step() is short-lived (created
and consumed within one call) so it doesn't need slot storage.
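
Sketch only: the grad buffers travel with the slot exactly like the param buffers. `buffers` stands in for mc.buffers + expert_parallel_buffers, and the slot attribute names mirror the description above:

```python
import torch

@torch.no_grad()
def snapshot_slot(slot, buffers):
    for cpu_param, cpu_grad, buf in zip(slot.cpu_param_data, slot.cpu_grad_data, buffers):
        cpu_param.copy_(buf.param_data)
        cpu_grad.copy_(buf.grad_data)   # without this, tick N+1 steps A's weights with B's grads

@torch.no_grad()
def restore_slot(slot, buffers):
    for cpu_param, cpu_grad, buf in zip(slot.cpu_param_data, slot.cpu_grad_data, buffers):
        buf.param_data.copy_(cpu_param)
        buf.grad_data.copy_(cpu_grad)
```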

Memory cost: ~+1x per slot for the grad mirror (bf16, same size as
param buffer). For a 7B base + rank-32 LoRA on a single DP shard this
is on the order of tens of MB, dwarfed by the existing fp32 main +
Adam moments.

Updates the design doc to reflect the four storages per LoRA param and
adds a "Why grads must travel with the slot" section walking through
the race the review caught.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per-tenant adapter routing for the merge_lora=False Megatron + vLLM
path (and FSDP for parity). The Tinker model_id IS the vLLM adapter
name end-to-end.

Worker side (megatron_worker, fsdp_worker):
  - broadcast_to_inference_engines accepts model_id (Optional[str]).
  - When LoRA is active, save the adapter into a per-tenant subdir
    os.path.join(lora_sync_path, model_id) so concurrent saves don't
    collide, and call load_lora_adapter(model_id, path, load_inplace=True)
    on vLLM. model_id=None preserves the legacy single-tenant
    SKYRL_LORA_ADAPTER_NAME path.
  - _save_lora_adapters_and_sync takes a lora_name parameter (default
    SKYRL_LORA_ADAPTER_NAME) instead of hardcoding the singleton.

Dispatch side (worker_dispatch):
  - save_weights_for_sampler(model_id=None) calls
    ensure_active_adapter(policy, model_id) before broadcasting so the
    correct adapter is live, and forwards model_id to
    broadcast_to_inference_engines.

Backend side (skyrl_train_backend):
  - save_sampler_checkpoint passes model_id (when LoRA is active).
  - sample() per-request `model` field is now the request's model_id
    when it's a registered LoRA adapter, falling back to
    resolve_policy_model_name(cfg) for FFT / single-tenant.
  - Drop the v1 'raise if >1 adapter' guards on sample / save_sampler_
    checkpoint — multi-tenant sampling is the goal of this branch.
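
A hedged end-to-end sketch of the per-tenant sync path; only load_lora_adapter(…, load_inplace=True) and the per-tenant subdir layout come from this commit, the other names are invented for illustration:

```python
import os

async def sync_adapter_for_sampling(inference_client, lora_sync_path, model_id, save_adapter):
    # Per-tenant subdir so concurrent saves from different adapters don't collide.
    adapter_dir = os.path.join(lora_sync_path, model_id)
    os.makedirs(adapter_dir, exist_ok=True)
    save_adapter(adapter_dir)  # worker writes the current adapter's weights here
    # The Tinker model_id is the vLLM adapter name end-to-end.
    await inference_client.load_lora_adapter(model_id, adapter_dir, load_inplace=True)
```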

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…apter

Both _load_on_server and _unload_on_server called await resp.json()
without try/except, so a non-JSON error body (e.g. a plain-text 5xx
from a proxy in front of vLLM) would raise a generic JSON-parse error
and lose the original status. Mirror the robust pattern from _post:
try resp.json(content_type=None), fall back to resp.text() on parse
failure, then raise_for_status with whichever body we got.
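
The robust pattern in aiohttp terms, sketched with an invented helper name; the real code lives on the client's _load_on_server / _unload_on_server paths:

```python
import aiohttp

async def _read_body_then_raise(resp: aiohttp.ClientResponse):
    try:
        body = await resp.json(content_type=None)  # tolerate non-JSON content types
    except Exception:
        body = await resp.text()                   # e.g. a plain-text 5xx from a proxy
    if resp.status >= 400:
        # Keep the original status *and* whatever body the server actually sent.
        raise RuntimeError(f"adapter load/unload failed: HTTP {resp.status}: {body}")
    return body
```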

Addresses the gemini-code-assist review note on PR NovaSky-AI#1579 (see
NovaSky-AI#1579).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Updates docs/content/docs/tinker/multi_lora_design.mdx scope from
"training only" to "training + per-adapter sampling", documents the
Tinker model_id == vLLM adapter name contract, the per-tenant
lora_sync_path layout, the merge_lora=False requirement on Megatron,
and the operator's max_cpu_loras sizing contract. Adds a "PR NovaSky-AI#1579
foundation" section pointing at the upstream PR.

Adds tests/tinker/test_multi_lora_rl_two_clients.md as the manual gate:
two rl_loop clients training and sampling on independent adapters
against one server, plus contamination check, negative checks, and
troubleshooting.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts:
#	skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
#	skyrl/backends/skyrl_train/workers/worker_dispatch.py
#	skyrl/backends/skyrl_train_backend.py
#	skyrl/train/entrypoints/main_base.py
#	tests/backends/skyrl_train/gpu/utils.py
The Tinker engine batches sample requests across model_ids in
find_batchable_sample, then dispatches one prepared_batch to
backend.sample(). Our previous "exactly one model_id per batch" guard
short-circuited multi-tenant RL — when both rl_loop clients had
sample() requests pending in the same engine tick, the batched call
hit the guard and returned 400 to both.

Replaces the unique-model check with a per-request validation: every
model_id must be a known policy, but multiple distinct policy
model_ids in one batch are fine. Routing per request is already
handled by _sample_with_remote_client via the per-request `model`
field on the data plane.
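
A sketch of the per-request validation; `model_ids_to_role` stands in for the controller-side registry, and the real code returns per-request ErrorResponse objects rather than raising:

```python
def validate_sample_model_ids(model_ids, model_ids_to_role):
    unknown = sorted({m for m in model_ids if m not in model_ids_to_role})
    if unknown:
        raise ValueError(f"Sampling requested for unknown model_id(s): {unknown}")
    non_policy = sorted({m for m in model_ids if model_ids_to_role[m] != "policy"})
    if non_policy:
        raise ValueError(f"Sampling is only supported for policy models, got: {non_policy}")
    # Multiple distinct policy model_ids in one batch are fine; per-request routing
    # happens downstream via the `model` field on the data plane.
```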

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The merged unload-then-load workaround from main wasn't sufficient for
re-syncing an adapter — vLLM still returned 400 "adapter already
loaded" on the second sync of the same name (e.g. when a tenant calls
save_sampler_checkpoint a second time, or when the unload step
returns 200 but the cached LoRARequest hasn't been evicted yet).

vLLM's own error message instructs the caller to set load_inplace=True
in that case, which is what PR NovaSky-AI#1579 originally did. Restore that
behavior: thread the load_inplace parameter (default True, exposed on
the public API) into the /v1/load_lora_adapter payload, drop the
separate _unload_on_server pre-step. The standalone unload_lora_adapter
method still exists for callers that explicitly want eviction.

Fixes the rl_loop runtime error:
  ClientResponseError: 400, message="The lora adapter '<id>' has
  already been loaded. If you want to load the adapter in place, set
  'load_inplace' to True."
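
A payload-level sketch of the restored behavior; lora_name/lora_path follow vLLM's dynamic-LoRA loading endpoint, and load_inplace is the flag the quoted error message asks for:

```python
import aiohttp

async def load_lora_adapter(session: aiohttp.ClientSession, base_url: str,
                            name: str, path: str, load_inplace: bool = True):
    payload = {"lora_name": name, "lora_path": path, "load_inplace": load_inplace}
    async with session.post(f"{base_url}/v1/load_lora_adapter", json=payload) as resp:
        resp.raise_for_status()
```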

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds docs/content/docs/tinker/async_sample_routing.mdx describing the
plan to route SkyRL-Train sample requests through the existing
EXTERNAL fan-out path (api.py:1039-1064) instead of the engine's
synchronous loop. The engine already excludes EXTERNAL futures from
its scheduler; we just need to point a new
BackendForwardingInferenceClient at the engine-managed vLLM.

Covers: synchronization invariants (I1-I4) that already hold via the
SDK + checkpoint validation + vLLM pause/resume, files to add/modify
(EngineStateDB row, BackendForwardingInferenceClient,
SkyRLTrainBackend._publish_engine_state, api.py lifespan wiring),
trade-offs vs. dual-loop-in-engine and full async refactors, failure
modes, testing plan, and explicit non-goals (training-side parallelism,
auto-recovery from vLLM eviction).

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Captures /tmp/rl_loop_{a..i}.log and /tmp/sl_loop_{a..d}.log from the
multi-LoRA RL + SFT smoke runs into tests/tinker/smoke_logs/. Each run
contributes code.diff (the working-tree diff at launch), config.json,
logs.log (full stdout/stderr), and metrics.jsonl.

Force-added because *.log is in the project .gitignore. ~1.2 MB total;
useful as reference output for the runbooks at
tests/tinker/test_multi_lora_{rl_two_clients,smoke_two_clients}.md.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts:
#	skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
#	skyrl/backends/skyrl_train/inference_servers/setup.py
#	skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py
#	skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
#	skyrl/backends/skyrl_train_backend.py
#	skyrl/train/generators/skyrl_gym_generator.py
#	skyrl/train/generators/skyrl_vlm_generator.py
#	tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_multi_lora_serving.py
#	tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
…version

PR NovaSky-AI#1579's test files evolved between when our branch cherry-picked
them and when the PR was merged to main (model field is now optional
via _resolve_model). Reset them to main's version to keep the
multi_lora_rl diff strictly to RL functionality.
The 13 run dirs (~1.2 MB) added in 57a474a are still recoverable from
git history; remove them from the working tree so the PR diff stays
focused on multi-LoRA RL code.
Move the design docs and smoke-test runbooks to a separate
skyrl-tinker-dev repo (internal development workspace), and drop the
.python-version pin so it doesn't leak into contributors' local repos.

Files removed:
  - docs/content/docs/tinker/multi_lora_design.mdx
  - docs/content/docs/tinker/async_sample_routing.mdx
  - tests/tinker/test_multi_lora_rl_two_clients.md
  - tests/tinker/test_multi_lora_smoke_two_clients.md
  - .python-version

The integration test (tests/tinker/test_multi_lora_megatron.py) stays
in this repo since it's a real pytest test, not a runbook.
# Conflicts:
#	skyrl/backends/skyrl_train/workers/megatron/adapter_store.py
#	skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
#	skyrl/backends/skyrl_train/workers/worker_dispatch.py
#	skyrl/backends/skyrl_train_backend.py
Two new GPU-gated tests that exercise the RL-side weight-sync path
(save_weights_for_sampler -> RemoteInferenceClient.load_lora_adapter
under model_id -> sample(model=model_id) on vLLM):

- test_per_adapter_sample_isolation: analog of
  test_per_adapter_step_isolation. Two pristine adapters following an
  identical training trajectory must produce bit-exact greedy samples
  at every step. Catches per-tenant routing failures (vLLM serves the
  wrong adapter), load_lora_adapter slot collisions, and per-param-
  group `step` snapshot regressions.

- test_two_adapters_sample_independently: analog of
  test_two_adapters_train_independently. After A trains, B trains, A
  trains again — A's continued sample must differ from A's earlier
  sample (A continues to learn) and from B's sample (B's sync did not
  clobber A's adapter on vLLM).

Required server config changes (folded into BACKEND_CONFIG, harmless
for the existing train-only tests; sketched after this list):
  - merge_lora=False so vLLM serves per-tenant LoRA adapters by name
    instead of pushing merged weights as a base-model update;
  - max_loras=4 / max_cpu_loras=4 so vLLM's CPU LRU holds all
    adapters the suite registers concurrently.
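
A hedged sketch of the added knobs only; the real keys live inside BACKEND_CONFIG's nested trainer/generator config, which is not reproduced here:

```python
MULTI_LORA_SAMPLING_OVERRIDES = {
    "merge_lora": False,   # vLLM serves per-tenant adapters by name, not merged base weights
    "max_loras": 4,        # GPU adapter slots
    "max_cpu_loras": 4,    # CPU LRU must hold every adapter the suite registers concurrently
}
```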

Local: 6/6 pass on 2x B200 (~5 min wall clock for the full suite).
When SkyRL-Train runs non-colocated, vLLM is always-on but its sample
capacity is wasted: the engine serializes sample behind forward_backward
in its 100ms tick loop, so a sample submitted at the start of a 30s
training step waits 30s. For multi-tenant RL the queueing compounds.

This change reuses the existing EXTERNAL request-type plumbing (built
for fully external vLLMs) but points it at the engine-managed vLLM:

- New EngineStateDB singleton row carries the engine-managed vLLM proxy
  URL from the engine subprocess to the API process.
- SkyRLTrainBackend._publish_engine_state writes it on every
  _create_new_inference_client and clears it on full teardown.
- New BackendForwardingInferenceClient mirrors ExternalInferenceClient
  but reads the proxy URL lazily from EngineStateDB and uses
  model=<model_id> (the name save_weights_for_sampler registered with
  vLLM). Refreshes the cached URL once on connection error.
- api.py lifespan installs it when backend in (megatron, fsdp) and
  trainer.placement.colocate_all=False. Colocated and JAX paths
  unchanged.
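
A sketch of the lazy-URL plus refresh-once idea; EngineStateDB access is reduced to a callable here and the class and method names are illustrative stand-ins for the real BackendForwardingInferenceClient:

```python
import httpx

class ForwardingClientSketch:
    def __init__(self, read_proxy_url):
        self._read_proxy_url = read_proxy_url  # e.g. reads the EngineStateDB row
        self._url = None

    async def _post_completions(self, payload):
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{self._url}/v1/completions", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def completions(self, payload):
        if self._url is None:
            self._url = self._read_proxy_url()
        try:
            return await self._post_completions(payload)
        except httpx.ConnectError:
            # The engine may have rebuilt vLLM and published a new proxy URL; refresh once.
            self._url = self._read_proxy_url()
            return await self._post_completions(payload)
```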

Adds tests/tinker/skyrl_train/test_async_sample_routing.py covering:
  - EngineStateDB is published with proxy URL + is_colocated=False
  - Sample futures use RequestType.EXTERNAL (off the engine queue)
  - Sample latency is bounded well below a concurrent training stream's
    duration (the parallelism test)
  - Many concurrent samples across two adapters all resolve

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Forwarding client was sending logprobs=true to /v1/completions; vLLM
  (and the upstream vllm-router) requires an int. Use 1 to get the chosen
  token's logprob, which is what Tinker's SampleOutput needs (see the
  payload sketch after this list).
- The SDK's user-facing future doesn't expose the server-side request_id,
  so test_sample_uses_external_path now snapshots max(request_id) before
  the sample, then queries for any new EXTERNAL row afterward — robust
  to the module-scoped server fixture sharing state across tests.
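
A payload sketch for the fixed /v1/completions request; every value other than `logprobs` is illustrative:

```python
payload = {
    "model": "tenant-a-adapter",   # the name save_weights_for_sampler registered with vLLM
    "prompt": [1, 2, 3],           # prompt token ids
    "max_tokens": 128,
    "temperature": 0.0,
    "logprobs": 1,                 # an int (1 = the chosen token's logprob), not true/false
}
```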

All four async sample routing tests pass (sample latency 1.3s vs
concurrent 10s training, ratio 0.13). All six existing multi-LoRA
Megatron tests still pass under the new forwarding path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request introduces multi-tenant LoRA support and an asynchronous sample routing path to improve inference performance. Key changes include updating the FSDP and Megatron workers to handle multiple named LoRA adapters, implementing a BackendForwardingInferenceClient to bypass the engine's serial loop for non-colocated backends, and adding an EngineStateDB to manage inference endpoint state. Feedback highlights critical logic errors in model name resolution for base model sampling and Full Fine-Tuning (FFT) scenarios, as well as a potential resource leak where the inference engine client is not explicitly closed during model deletion.

Comment on lines +947 to 959
+        unknown = [mid for mid in unique_models if mid not in self._model_ids_to_role]
+        if unknown:
             error = types.ErrorResponse(
-                error=f"Expected exactly one model_id for sampling, got {unique_models}", status="error"
+                error=f"Sampling requested for unknown model_id(s): {sorted(unknown)}", status="error"
             )
             return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}
-        model_id = next(iter(unique_models))
-        role = self._model_ids_to_role.get(model_id)
-        if role != "policy":
+        non_policy = [mid for mid in unique_models if self._model_ids_to_role.get(mid) != "policy"]
+        if non_policy:
             error = types.ErrorResponse(
-                error=f"Sampling is only supported for policy models, got '{model_id}'", status="error"
+                error=f"Sampling is only supported for policy models, got non-policy: {sorted(non_policy)}",
+                status="error",
             )
             return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

Severity: high

The validation logic in sample incorrectly flags base model sampling requests (where model_id is an empty string) as unknown or non-policy models. This will cause base model sampling to fail in colocated mode where the EXTERNAL path is not used.

Comment on lines +1019 to +1022
per_request_models = [
    mid if (self._base_lora_signature is not None and mid in self._model_ids_to_role) else fallback_model_name
    for mid in prepared_batch.all_model_ids
]

Severity: high

The model name resolution logic for vLLM requests is incorrect for base model sampling when LoRA is active. If mid is empty (base model), it currently falls back to fallback_model_name, which is the LoRA adapter name when LoRA is enabled. This results in sampling the adapter instead of the base model.

# vLLM identifies the LoRA adapter by the name passed to load_lora_adapter,
# which was set to model_id in save_weights_for_sampler. For base-model
# sampling we point at the underlying HF model name directly.
model_name = base_model if base_model else model_id

Severity: high

The resolution of model_name does not account for Full Fine-Tuning (FFT) models. For FFT, the vLLM instance serves the base model, but model_id is passed as the model name, which will result in a 404 from vLLM. The client needs to distinguish between LoRA adapters (where model_name == model_id) and FFT models (where model_name == base_model).

self._publish_engine_state(proxy_url=None, server_urls=[], is_colocated=False)
self._cfg = None
self._dispatch = None
self._inference_engine_client = None

Severity: medium

The _inference_engine_client should be explicitly closed during model deletion to ensure that underlying network resources (like the httpx client) are properly released. Since delete_model is a synchronous method, you can use asyncio.run to call the asynchronous aclose method.

References
  1. Ensure that resources like network clients are properly closed when they are no longer needed to avoid resource leaks.
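
A sketch of the reviewer's suggestion; the attribute and method names mirror the snippet above, and the function name is invented:

```python
import asyncio

def close_inference_client(backend):
    client = backend._inference_engine_client
    if client is not None:
        # delete_model is synchronous, so drive the async aclose() to completion here.
        asyncio.run(client.aclose())
    backend._inference_engine_client = None
```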

@erictang000
Collaborator Author

closing in favor of #1638

@erictang000 closed this May 9, 2026