
[tinker] Forward sample requests directly to backend vLLM (non-colocated) #1638

Open

erictang000 wants to merge 14 commits into NovaSky-AI:main from erictang000:async_sample_routing_clean

Conversation

erictang000 (Collaborator) commented May 9, 2026

Description

When SkyRL-Train runs non-colocated (colocate_all=false), vLLM is always-on but its sample capacity is wasted: the Tinker engine subprocess serializes sample behind forward_backward / optim_step in its 100ms tick loop, so a sample submitted at the start of a 30s training step waits 30s. For multi-tenant RL the queueing compounds — every sample from any tenant queues behind every other tenant's training step.

This PR hoists sample requests off the engine queue and into the API process's asyncio loop, where they're forwarded directly to the engine-managed vLLM. The colocated and JAX paths are unchanged.

How it works

We reuse the existing RequestType.EXTERNAL plumbing (already in place for fully external vLLM URLs):

       ┌─ POST /api/v1/asample ─────────────────────────┐
client ┤  api.py:                                       │
       │    create FutureDB(type=EXTERNAL)              │
       │    asyncio.create_task(client.call_and_store)  │
       │    return future_id immediately                │
       └────────────────────────────────────────────────┘
                            │
                            ▼
       ┌─ SkyRLTrainInferenceForwardingClient ──────────┐
       │  POST {vllm_proxy_url}/v1/completions          │
       │    model=<model_id>  (LoRA registered already) │
       │  on completion: write FutureDB.result_data     │
       └────────────────────────────────────────────────┘

   (parallel)
       ┌─ engine subprocess ────────────────────────────┐
       │  process_pending_requests loop:                │
       │    forward_backward / optim_step /             │
       │    save_weights_for_sampler / ...              │
       │  (NO sample work here)                         │
       └────────────────────────────────────────────────┘

The cross-process plumbing was already wired for external vLLM URLs; the gap was that nothing was pointing it at the engine-managed vLLM. This PR closes that gap.
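
A toy, self-contained illustration of this hand-off pattern (not the real api.py code): a dict stands in for FutureDB, and a sleep stands in for the POST to the engine-managed vLLM. Names like call_and_store and asample mirror the description above but are assumptions about the real handler.

    # Toy sketch of the API-side hand-off: return a future id immediately,
    # let the API process's event loop own the vLLM round-trip.
    import asyncio
    import itertools

    futures: dict[int, dict] = {}      # stand-in for FutureDB rows
    _ids = itertools.count(1)

    async def call_and_store(request_id: int, prompt: str) -> None:
        await asyncio.sleep(0.1)       # pretend this is the POST to vLLM
        futures[request_id] = {"status": "done", "result": f"completion for {prompt!r}"}

    async def asample(prompt: str) -> int:
        request_id = next(_ids)
        futures[request_id] = {"status": "pending", "result": None}
        # Fire-and-forget: nothing queues behind forward_backward / optim_step.
        asyncio.create_task(call_and_store(request_id, prompt))
        return request_id              # returned to the client immediately

    async def main():
        rid = await asample("2+2=")
        print("immediately:", futures[rid]["status"])   # pending
        await asyncio.sleep(0.2)
        print("later:", futures[rid])                   # done

    asyncio.run(main())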

What's in the diff

8 files, +796/-3 LOC.

New EngineStateDB (skyrl/tinker/db_models.py) — singleton row carrying engine→API handoff state (currently just inference_proxy_url + updated_at). The backend writes it after building vLLM; the forwarding client reads it lazily on first sample.
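
A plausible shape for that row, sketched with sqlmodel (which the PR's db layer already uses elsewhere); the singleton primary-key convention and the defaults here are assumptions, only inference_proxy_url and updated_at come from the description above.

    # Sketch of the engine→API handoff row (not the actual db_models.py code).
    from datetime import datetime, timezone
    from typing import Optional
    from sqlmodel import Field, SQLModel

    class EngineStateDB(SQLModel, table=True):
        # Fixed primary key = singleton row: the engine side upserts id=1,
        # the forwarding client reads id=1 lazily on first sample.
        id: int = Field(default=1, primary_key=True)
        inference_proxy_url: Optional[str] = None    # cleared on delete_model teardown
        updated_at: datetime = Field(
            default_factory=lambda: datetime.now(timezone.utc)
        )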

SkyRLTrainBackend._publish_engine_state (skyrl/backends/skyrl_train_backend.py) — called inside _create_new_inference_client after build_new_inference_client returns, and on delete_model teardown to clear the row. Best-effort: a DB write failure logs and returns rather than corrupting controller state. Wired through a new set_engine_database_url(...) setter that the engine calls right after constructing the backend.

New SkyRLTrainInferenceForwardingClient (skyrl/tinker/extra/skyrl_train_inference_forwarding.py) — counterpart to ExternalInferenceClient, following the same EXTERNAL future-write contract but resolving the target URL from EngineStateDB instead of a user-supplied EngineConfig.external_inference_url. Uses model=<model_id> (the name save_weights_for_sampler registered with vLLM via load_lora_adapter).

Key design choices:

  • No API-side semaphore. Backpressure is layered: httpx connection pool → vllm-router → each vLLM server's max_num_seqs. Adding a Python semaphore on top would just serialize work above what vLLM already manages.
  • Persistent httpx.AsyncClient with forwarding_inference_max_connections defaulting to None (unlimited). The only cost of "unlimited" is file descriptors, so operators just raise ulimit -n to match peak concurrent samples. Set an int to enforce a per-process cap. Closed via aclose() from the API lifespan.
  • Polymorphic .stop dispatch matches api.py SamplingParams.to_types() (list[str] → stop, list[int] → stop_token_ids); see the sketch after this list.
  • Response parsing mirrors the in-process path: finish_reason normalized to Literal["stop", "length"]; missing token_logprobs zero-filled; non-JSON responses (e.g. proxy 502 HTML) surfaced with content-type + body excerpt for diagnosis.
  • Retry covers all transport-level errors (httpx.RequestError umbrella — ConnectError, ReadError, RemoteProtocolError, TimeoutException, PoolTimeout). HTTP 4xx/5xx from vLLM is NOT retried — it's a real upstream signal.
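
A minimal sketch of the polymorphic .stop dispatch referenced above; apply_stop and the mixed-type guard are illustrative names, not the actual forwarding-client code:

    # Dispatch on the element type of the API-level SamplingParams.stop field.
    from typing import Union

    def apply_stop(payload: dict, stop: Union[list[str], list[int], None]) -> dict:
        if not stop:
            return payload
        if all(isinstance(s, str) for s in stop):
            payload["stop"] = stop                 # /v1/completions string stops
        elif all(isinstance(s, int) for s in stop):
            payload["stop_token_ids"] = stop       # vLLM extension: token-id stops
        else:
            raise ValueError(f"mixed stop types: {stop!r}")
        return payload

    print(apply_stop({"model": "m"}, ["</s>", "\n\n"]))   # -> stop
    print(apply_stop({"model": "m"}, [151645]))           # -> stop_token_ids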

api.py lifespan — installs the forwarding client when backend in ("megatron", "fsdp") AND trainer.placement.colocate_all=False. Colocated runs and JAX keep the synchronous engine flow. Calls aclose() on shutdown.
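
A toy, self-contained sketch of the gating and shutdown contract; _Config and the flat backend_config keys are stand-ins for the real parsed config, and a bare httpx.AsyncClient stands in for the forwarding client:

    from contextlib import asynccontextmanager
    from dataclasses import dataclass, field

    import httpx
    from fastapi import FastAPI

    @dataclass
    class _Config:                      # stand-in for the parsed server config
        backend: str = "megatron"
        backend_config: dict = field(
            default_factory=lambda: {"trainer.placement.colocate_all": False}
        )
        forwarding_inference_max_connections: int | None = None   # None = unlimited

    config = _Config()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        client = None
        # Install the forwarding path only for non-colocated megatron/fsdp runs.
        if config.backend in ("megatron", "fsdp") and config.backend_config.get(
            "trainer.placement.colocate_all"
        ) is False:
            limits = httpx.Limits(max_connections=config.forwarding_inference_max_connections)
            client = httpx.AsyncClient(limits=limits)   # stands in for the forwarding client
            app.state.forwarding_client = client
        try:
            yield
        finally:
            if client is not None:
                await client.aclose()                   # released on API shutdown

    app = FastAPI(lifespan=lifespan)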

Synchronization invariants preserved:

  • (I1) Sample for checkpoint X requires save-X to have completed. Already enforced by the SDK + validate_checkpoint(...) at api.py:1037 before the future is created.
  • (I2) In-flight samples stay consistent during a re-broadcast for the same model. WorkerDispatch.save_weights_for_sampler brackets the broadcast with pause_generation / resume_generation; vLLM's KEEP-mode pause freezes in-flight requests in its scheduler.
  • (I3) Result writes don't conflict with engine writes. EXTERNAL futures are written by the API process; all other futures by the engine. SQLite WAL handles concurrent readers.
  • (I4) Colocated path is unchanged. Lifespan strictly gates on colocate_all=false && backend ∈ (megatron, fsdp).

Test plan

  • test_engine_state_published — after save_weights_for_sampler, EngineStateDB.inference_proxy_url is populated with the engine-managed vLLM proxy URL.
  • test_sample_uses_external_path — issued sample creates a FutureDB row with request_type=EXTERNAL (off the engine queue).
  • test_sample_concurrent_with_training_is_fast — central parallelism test. While 24 forward_backward+optim_step calls stream in a background thread, a concurrent sample resolves in 1.3s vs 10s training stream (ratio 0.13×). Without async routing the sample would queue behind the entire training stream and the latencies would be comparable.
  • test_concurrent_samples_per_adapter — 8 concurrent samples across two adapters all resolve via the forwarding client's connection pool and per-adapter model_id routing on vLLM.
  • ✅ All 6 existing test_multi_lora_megatron.py tests continue to pass under the new forwarding path (backwards compatible with the multi-LoRA RL workload).
  • tests/tinker/test_api.py single-tenant paths pass.

End-to-end multi-tenant RL bench following docs/tinker/multi_tenancy.mdx#quickstart---two-rl-clients (Qwen3-0.6B, GSM8K, 4 policy GPUs + 1 vLLM, 3 iters × batch 16 × group 4 × max 128 tokens per client):

Mode                     A wall-clock    B wall-clock    Total wall-clock
Sequential (A then B)    37.8s           31.7s           85s
Concurrent (A ∥ B)       36.0s           34.7s           44s

Speedup ~1.93×, close to the theoretical 2× ceiling. Per-iter cost grows ~1-2s under contention (vLLM multiplexes both adapters' samples + engine batches both fwd_bwds), but total wall-clock halves because both tenants progress simultaneously instead of waiting their turn.

Server-side instrumentation: 240 sample requests forwarded via the API-process client; 0 served by the engine (the one engine-served sample happened at startup, before the bench). process_batch_requests(forward_backward, n=2) lines in the server log confirm the engine batches concurrent A+B fwd_bwd calls together.

Operator config

To enable, set in --backend-config:

{
    "trainer.placement.colocate_all": false,
    "trainer.policy.megatron_config.lora_config.merge_lora": false,
    "trainer.policy.model.lora.max_loras": <max concurrent adapters in a batch>,
    "trainer.policy.model.lora.max_cpu_loras": <total adapter capacity>
}

For very high fan-out workloads, raise the host's ulimit -n to match peak concurrent samples (default httpx pool is unbounded). Operators who want a per-API-process FD cap can set --forwarding-inference-max-connections <N>.

See docs/content/docs/tinker/multi_tenancy.mdx (added in the multi-lora-rl PR) for the full operator contract.

Out of scope

  • Colocated mode acceleration. No parallelism to exploit there — vLLM is asleep during training and the engine's synchronous sample path is what wakes it.
  • Auto-recovery from vLLM eviction. If max_cpu_loras is too low and vLLM evicts an adapter mid-run, the next sample 404s. Re-loading from disk on demand is a follow-up; for now the operator sizes max_cpu_loras ≥ expected concurrent adapters.
  • Sample retries beyond one transport-level retry. The forwarding client refreshes its cached proxy URL and retries once on httpx.RequestError. Application-level retry is the SDK's job (already implemented in tinker/retry_handler).
  • Parallelizing training requests across model_ids. Multi-tenant forward_backward / optim_step still serialize through the engine's main loop.

erictang000 and others added 2 commits May 9, 2026 02:49
When SkyRL-Train runs non-colocated, vLLM is always-on but its sample
capacity is wasted: the engine serializes sample behind forward_backward
in its 100ms tick loop, so a sample submitted at the start of a 30s
training step waits 30s. For multi-tenant RL the queueing compounds.

This change reuses the existing EXTERNAL request-type plumbing (built
for fully external vLLMs) but points it at the engine-managed vLLM:

- New EngineStateDB singleton row carries the engine-managed vLLM proxy
  URL from the engine subprocess to the API process.
- SkyRLTrainBackend._publish_engine_state writes it on every
  _create_new_inference_client and clears it on full teardown.
- New BackendForwardingInferenceClient mirrors ExternalInferenceClient
  but reads the proxy URL lazily from EngineStateDB and uses
  model=<model_id> (the name save_weights_for_sampler registered with
  vLLM). Refreshes the cached URL once on connection error.
- api.py lifespan installs it when backend in (megatron, fsdp) and
  trainer.placement.colocate_all=False. Colocated and JAX paths
  unchanged.

Adds tests/tinker/skyrl_train/test_async_sample_routing.py covering:
  - EngineStateDB is published with proxy URL + is_colocated=False
  - Sample futures use RequestType.EXTERNAL (off the engine queue)
  - Sample latency is bounded well below a concurrent training stream's
    duration (the parallelism test)
  - Many concurrent samples across two adapters all resolve

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
- Forwarding client was sending logprobs=true to /v1/completions; vLLM
  (and the upstream vllm-router) require an int. Use 1 to get the chosen
  token's logprob, which is what Tinker's SampleOutput needs.
- The SDK's user-facing future doesn't expose the server-side request_id,
  so test_sample_uses_external_path now snapshots max(request_id) before
  the sample, then queries for any new EXTERNAL row afterward — robust
  to the module-scoped server fixture sharing state across tests.

All four async sample routing tests pass (sample latency 1.3s vs
concurrent 10s training, ratio 0.13). All six existing multi-LoRA
Megatron tests still pass under the new forwarding path.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request implements an asynchronous sample routing path, enabling the API process to forward sample requests directly to a backend-managed vLLM instance, thereby bypassing the engine's serial scheduling loop. Key changes include the introduction of an EngineStateDB singleton for state handoff, the BackendForwardingInferenceClient for request handling, and a new suite of end-to-end tests. Feedback points out that the vLLM payload is missing stop parameters, the response parsing logic needs to be more robust regarding finish reasons and logprobs, and the use of a persistent HTTP client is recommended to improve efficiency.

I am having trouble creating individual review comments, so my feedback is included below.

skyrl/tinker/extra/backend_forwarding_inference.py (121-137)

high

The payload for the vLLM /v1/completions request is missing the stop parameters (stop and stop_token_ids). These are present in the SamplingParams and should be forwarded to ensure generation terminates correctly according to the request configuration.

        payload = {
            "model": model_name,
            "prompt": prompt_tokens,
            "n": sample_req.num_samples,
            "seed": sample_req.sampling_params.seed,
            "max_tokens": sample_req.sampling_params.max_tokens,
            "temperature": sample_req.sampling_params.temperature,
            "top_p": sample_req.sampling_params.top_p,
            "top_k": sample_req.sampling_params.top_k,
            # vLLM (and the upstream vllm-router) expects an integer for the
            # OpenAI-compatible /v1/completions endpoint — the number of top
            # tokens to return logprobs for. 1 gives the chosen token's
            # logprob, which is what the Tinker SampleOutput requires.
            "logprobs": 1,
            "stream": False,
            "return_token_ids": True,
        }
        if sample_req.sampling_params.stop_strings:
            payload["stop"] = sample_req.sampling_params.stop_strings
        if sample_req.sampling_params.stop_tokens:
            payload["stop_token_ids"] = sample_req.sampling_params.stop_tokens

skyrl/tinker/extra/backend_forwarding_inference.py (151-162)

high

The parsing logic for the vLLM response should be more robust and consistent with the internal engine path. Specifically:

  1. Normalization of finish_reason: vLLM may return stop_token or other strings that should be mapped to stop to remain compatible with the SDK's expectations (which typically expects stop or length).
  2. Logprobs Safeguard: If logprobs are missing or empty (e.g., due to an upstream issue), it's critical for RL workloads to have a fallback (e.g., zeros) to avoid downstream failures.
  3. Prompt Logprobs: The prompt_logprobs field should return None rather than an empty list to match the behavior in skyrl_train_backend.py.
        sequences = []
        for choice in result["choices"]:
            tokens = choice.get("token_ids", [])
            lp = choice.get("logprobs") or {}
            logprobs = lp.get("token_logprobs") or []

            # Ensure logprobs exist (critical for RL)
            if not logprobs and tokens:
                logger.warning("No logprobs returned from vLLM - filling with zeros")
                logprobs = [0.0] * len(tokens)

            # Map vLLM finish reason to Tinker format
            finish_reason = choice.get("finish_reason")
            stop_reason = "stop" if finish_reason in ("stop", "stop_token") else "length"

            sequences.append(
                types.GeneratedSequence(
                    tokens=tokens,
                    logprobs=logprobs,
                    stop_reason=stop_reason,
                )
            )

        return types.SampleOutput(sequences=sequences, prompt_logprobs=None)

skyrl/tinker/extra/backend_forwarding_inference.py (139-142)

medium

Creating a new httpx.AsyncClient for every request is inefficient as it prevents connection pooling and incurs the overhead of a new TCP/TLS handshake for every sample. Consider initializing a single persistent client in __init__ and reusing it across requests. If you do this, ensure the client is properly closed during the API server's shutdown phase in api.py lifespan.
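
A sketch of what the recommended persistent-client pattern could look like; ForwardingClient and post_completions are illustrative names, not the actual class in this PR:

    import httpx

    class ForwardingClient:
        def __init__(self, max_connections: int | None = None) -> None:
            # One client per forwarding-client instance: connections stay pooled
            # instead of paying a fresh TCP/TLS handshake per sample.
            self._http = httpx.AsyncClient(
                limits=httpx.Limits(max_connections=max_connections),
                timeout=httpx.Timeout(600.0),
            )

        async def post_completions(self, base_url: str, payload: dict) -> dict:
            resp = await self._http.post(f"{base_url}/v1/completions", json=payload)
            resp.raise_for_status()
            return resp.json()

        async def aclose(self) -> None:
            await self._http.aclose()    # called once, from lifespan shutdown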

erictang000 and others added 4 commits May 11, 2026 18:02
… persistent httpx

Three review nits on the forwarding client:

1. Stop params (high): the vLLM /v1/completions payload was missing
   `stop` and `stop_token_ids`. SamplingParams carries both
   stop_strings and stop_tokens — forward them.

2. Response parse robustness (high): mirror the normalization in
   skyrl_train_backend._sample_with_remote_client —
     - normalize vLLM's finish_reason ({"stop","stop_token"} -> "stop",
       else "length"; Tinker's GeneratedSequence.stop_reason is a Literal
       so this matters)
     - if vLLM returns None/empty token_logprobs, zero-fill (RL needs
       these for advantages)
     - return prompt_logprobs=None instead of []

3. Persistent httpx.AsyncClient (medium): a fresh client per call kills
   connection pooling and pays a new TCP/TLS handshake per sample.
   Construct the client once in __init__ and reuse. Expose
   `aclose()` and call it from api.py lifespan shutdown.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
…kpressure

Review feedback: the asyncio.Semaphore cap (default 64) was a real
bottleneck for high fan-out workloads (e.g. 4k concurrent requests from
multiple training runs), and it stacked unnecessarily on top of vLLM's
own scheduling (vllm-router -> per-server max_num_seqs).

Drop the semaphore. The remaining backpressure chain is:
  httpx connection pool  ->  vllm-router  ->  vLLM max_num_seqs

The httpx pool is the only API-side cap. Rename the config knob to
`forwarding_inference_max_connections` to reflect what it actually
controls and bump the default from 64 to 1024. max_keepalive scales
with it (max_conn / 4, floor 32). Operators with very high fan-out can
raise this further.

Also fix a regression introduced in the last review pass: SamplingParams
on the API-level model uses a polymorphic .stop field (list[str] |
list[int]), not the normalized stop_strings/stop_tokens. Dispatch on
the element type to set "stop" vs "stop_token_ids" correctly.

Verified with the multi_tenancy.mdx RL quickstart:
  Concurrent A+B (3 iters each, batch=16, group=4, max_tokens=128):
    36s wall-clock; per-iter ~9-10s
  Sequential B (warm): 25.5s for 3 iters; per-iter ~8.5s
  Speedup ~1.4-1.9x with per-iter overhead ~1-2s under contention.
All 192 sample requests forwarded; 0 served by the engine.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
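
A sketch of the pool sizing rule this commit describes (keepalive scales as max_connections / 4 with a floor of 32); make_limits is an illustrative helper, not the actual code:

    import httpx

    def make_limits(max_connections: int) -> httpx.Limits:
        max_keepalive = max(max_connections // 4, 32)
        return httpx.Limits(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive,
        )

    print(make_limits(1024).max_keepalive_connections)   # 256
    print(make_limits(64).max_keepalive_connections)     # 32 (floor)
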
Per discussion: raising max_connections is essentially free (the only
cost is file descriptors). Make the default UX "hammer vllm-router and
let it queue" by setting forwarding_inference_max_connections to
Optional[int] with default None = unlimited.

Operators get the simplest path: just raise `ulimit -n` to match peak
concurrent samples across all tenants, and let vllm-router + each vLLM
server's max_num_seqs be the only queues. Multi-tenant SaaS-style
deployments can still set an int to bound per-API-process FD usage.

httpx maps max_connections=None to INT64_MAX internally (unlimited).
Wire the argparse path so `--forwarding-inference-max-connections None`
from the CLI parses to Python None (not the literal string).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
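
A sketch of how the CLI string "None" could be parsed to Python None; optional_int is an illustrative helper, not necessarily how the argparse path is wired in api.py:

    import argparse

    def optional_int(value: str) -> int | None:
        # "None" (or empty) means unlimited; anything else must be an int.
        return None if value.lower() in ("none", "") else int(value)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--forwarding-inference-max-connections",
        type=optional_int,
        default=None,        # None = unlimited httpx pool
    )
    args = parser.parse_args(["--forwarding-inference-max-connections", "None"])
    print(args.forwarding_inference_max_connections)   # None, not the string "None"
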
High-priority correctness:
- Double-checked lock in _resolve_proxy_url: hot path returns the cached
  URL without acquiring _cache_lock, so concurrent samples no longer
  serialize at this point under the unlimited-pool default.
- Test _server_is_up: add explicit `import urllib.request`; the previous
  `import urllib.error` alone only worked when some other module (tinker
  SDK / transformers) had already pulled urllib.request into the urllib
  namespace.
- delete_model: reorder so all local state is reset before calling
  _publish_engine_state, and make _publish_engine_state best-effort
  (log + return on any DB error). Prevents a partially-torn-down
  controller if SQLite writes fail during shutdown.

Schema cleanup:
- Drop EngineStateDB.is_colocated and inference_server_urls. Both were
  written by the backend but never read — lifespan checks colocation
  from the user's --backend-config dict, and the forwarding client only
  needs inference_proxy_url. _publish_engine_state signature simplified
  to (proxy_url,) only.

Test docstrings:
- Add test_concurrent_samples_per_adapter to module Coverage list.
- Rewrite that test's docstring (referenced the removed semaphore;
  now describes the connection pool + model_id routing it actually
  exercises).
- Drop the is_colocated/server_urls assertions from
  test_engine_state_published.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
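
A self-contained sketch of the double-checked lock pattern this commit describes; ProxyUrlCache and the sleep-backed DB read are stand-ins for the real _resolve_proxy_url / EngineStateDB lookup:

    import asyncio

    class ProxyUrlCache:
        def __init__(self) -> None:
            self._cached_url: str | None = None
            self._cache_lock = asyncio.Lock()

        async def _fetch_proxy_url_from_db(self) -> str:
            await asyncio.sleep(0.05)          # pretend this is the SQLite read
            return "http://127.0.0.1:8080"

        async def resolve(self) -> str:
            if self._cached_url is not None:   # hot path: no lock acquired
                return self._cached_url
            async with self._cache_lock:
                if self._cached_url is None:   # re-check under the lock
                    self._cached_url = await self._fetch_proxy_url_from_db()
                return self._cached_url

    async def main():
        cache = ProxyUrlCache()
        urls = await asyncio.gather(*(cache.resolve() for _ in range(8)))
        print(urls[0], "single DB read:", len(set(urls)) == 1)

    asyncio.run(main())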
erictang000 (Collaborator, Author) commented:

/gemini review

gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces an asynchronous sample routing path designed to reduce inference latency by bypassing the engine's serial scheduling loop. Key additions include the BackendForwardingInferenceClient, which forwards requests directly to a backend-managed vLLM instance, and the EngineStateDB model to facilitate the handoff of proxy URLs between the engine and the API. The changes also include lifecycle management for publishing engine state and comprehensive end-to-end tests. Review feedback focuses on enhancing the robustness of the new forwarding client by adding null checks for database retrievals, broadening exception handling for network requests, and implementing safer JSON parsing for server responses.

Comment thread skyrl/tinker/extra/skyrl_train_inference_forwarding.py
Comment thread skyrl/tinker/extra/skyrl_train_inference_forwarding.py Outdated
Comment thread skyrl/tinker/extra/skyrl_train_inference_forwarding.py Outdated
erictang000 and others added 5 commits May 11, 2026 23:43
…eForwardingClient

The previous "BackendForwardingInferenceClient" name was generic —
"backend" could mean anything. Rename to make it explicit that this
client forwards sample requests to the SkyRL-Train-managed inference
URL (parallel to ExternalInferenceClient which forwards to a
user-supplied external vLLM URL).

  Class:     BackendForwardingInferenceClient
          -> SkyRLTrainInferenceForwardingClient
  Module:    skyrl.tinker.extra.backend_forwarding_inference
          -> skyrl.tinker.extra.skyrl_train_inference_forwarding

Updated references in api.py, config.py, extra/__init__.py, and the
test file's docstring/comments. Lifespan log message updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Three review nits on SkyRLTrainInferenceForwardingClient:

1. session.get(FutureDB, request_id) -> None when the future was
   deleted between asample handler scheduling us and our arrival
   here (manual cleanup, stale-session GC). Previously we'd
   AttributeError on `.result_data = ...`. Now: null-check, log a
   warning, return cleanly. The retrieve_future poller times out on
   the caller's side as the natural failure mode.

2. response.json() raises ValueError when the upstream returns
   non-JSON (e.g. vllm-router 502 HTML even with a 2xx status from
   an intermediate proxy). Wrap in try/except and surface the raw
   body + content-type so the failure is diagnosable from FutureDB's
   error_data alone.

3. Broaden retry catch from (ConnectError, ReadError,
   RemoteProtocolError) to httpx.RequestError. Now also covers
   TimeoutException and PoolTimeout. Verified all five specific
   classes are RequestError subclasses; HTTP 4xx/5xx (RuntimeError)
   is intentionally NOT retried — real upstream errors should
   surface.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
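
A sketch of the response-parsing and retry behavior described in points 2 and 3; the function names are illustrative, and the real client also refreshes its cached proxy URL before retrying:

    import httpx

    def parse_json_or_raise(response: httpx.Response) -> dict:
        try:
            return response.json()
        except ValueError:
            # Surface content-type + a body excerpt so the failure is
            # diagnosable from the stored error data alone.
            excerpt = response.text[:500]
            raise RuntimeError(
                f"non-JSON response from upstream "
                f"(content-type={response.headers.get('content-type')!r}): {excerpt!r}"
            )

    async def post_with_one_retry(client: httpx.AsyncClient, url: str, payload: dict) -> dict:
        try:
            resp = await client.post(url, json=payload)
        except httpx.RequestError:
            # Transport-level failure (ConnectError, ReadError, timeouts, ...):
            # retry exactly once.
            resp = await client.post(url, json=payload)
        resp.raise_for_status()   # HTTP 4xx/5xx is a real upstream signal: no retry
        return parse_json_or_raise(resp)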
@erictang000 erictang000 requested a review from pcmoritz May 12, 2026 00:17
Comment thread skyrl/backends/skyrl_train_backend.py Outdated
self._cfg.generator.inference_engine,
)

def set_engine_database_url(self, database_url: str) -> None:
pcmoritz (Collaborator) commented May 12, 2026:

Let's keep the database out of the backends, and even better would be to not write the proxy_url into the database (since it is ephemeral). How about we just make it an (optional) attribute of the backend?

If it really needs to be in the database, I think the better way would be to re-use the DB connection in engine.py.

It is better to keep the database out of the backend to avoid propagating it everywhere in the future.

erictang000 (Collaborator, Author) replied:

hmm good point. i think because the proxy URL doesn't exist until _create_new_inference_client is run in the backend it's tricky to configure as an attribute up front

but I can move it up to reuse the engine.py db connection so it's cleaner

Address review: SkyRLTrainBackend should not know about the database.

Backend changes:
- Drop _engine_database_url, set_engine_database_url, and the entire
  inline DB write body (sqlmodel imports, create_engine/dispose, table
  schema create_all).
- Add set_inference_state_publisher(fn: Callable[[str | None], None])
  and a thin _publish_inference_state helper that invokes the callback
  with the current proxy URL. Failures from the callback are logged but
  do not propagate — local state reset on delete_model must complete.

Engine changes:
- Wire the publisher at construction time:
    backend.set_inference_state_publisher(self._write_inference_state_to_db)
- New _write_inference_state_to_db reuses self.db_engine (already
  created at line 241) — no per-write create_engine churn.

Result:
- Backend has zero DB symbols (verified: no sqlmodel/EngineStateDB
  references in the module).
- Engine owns the connection; one Session per write.
- The contract surface between the two is a single Optional callable.
- Easy to swap DB for a different sink (Unix socket, HTTP push) by
  replacing the engine's _write_inference_state_to_db.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
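
A self-contained sketch of the publisher-callback contract this commit describes; the class is a stand-in, and the lambda plays the role of the engine's _write_inference_state_to_db:

    from typing import Callable, Optional

    class SkyRLTrainBackendSketch:
        def __init__(self) -> None:
            self._inference_state_publisher: Optional[Callable[[Optional[str]], None]] = None
            self._proxy_url: Optional[str] = None

        def set_inference_state_publisher(self, fn: Callable[[Optional[str]], None]) -> None:
            self._inference_state_publisher = fn

        def _publish_inference_state(self) -> None:
            if self._inference_state_publisher is None:
                return
            try:
                self._inference_state_publisher(self._proxy_url)
            except Exception as exc:
                # Best-effort: failures are logged, local state stays consistent.
                print(f"publishing inference state failed (ignored): {exc}")

        def create_inference_client(self, proxy_url: str) -> None:
            self._proxy_url = proxy_url
            self._publish_inference_state()          # after vLLM is up

        def delete_model(self) -> None:
            self._proxy_url = None                   # local reset happens first
            self._publish_inference_state()          # best-effort clear

    # Engine side: wire the publisher at construction time.
    backend = SkyRLTrainBackendSketch()
    backend.set_inference_state_publisher(lambda url: print("engine writes EngineStateDB:", url))
    backend.create_inference_client("http://127.0.0.1:8080")
    backend.delete_model()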