[tinker] Forward sample requests directly to backend vLLM (non-colocated) #1638
erictang000 wants to merge 14 commits into
Conversation
When SkyRL-Train runs non-colocated, vLLM is always-on but its sample
capacity is wasted: the engine serializes sample behind forward_backward
in its 100ms tick loop, so a sample submitted at the start of a 30s
training step waits 30s. For multi-tenant RL the queueing compounds.
This change reuses the existing EXTERNAL request-type plumbing (built
for fully external vLLMs) but points it at the engine-managed vLLM:
- New EngineStateDB singleton row carries the engine-managed vLLM proxy
URL from the engine subprocess to the API process.
- SkyRLTrainBackend._publish_engine_state writes it on every
_create_new_inference_client and clears it on full teardown.
- New BackendForwardingInferenceClient mirrors ExternalInferenceClient
but reads the proxy URL lazily from EngineStateDB and uses
model=<model_id> (the name save_weights_for_sampler registered with
vLLM). Refreshes the cached URL once on connection error.
- api.py lifespan installs it when backend in (megatron, fsdp) and
trainer.placement.colocate_all=False. Colocated and JAX paths
unchanged.
Adds tests/tinker/skyrl_train/test_async_sample_routing.py covering:
- EngineStateDB is published with proxy URL + is_colocated=False
- Sample futures use RequestType.EXTERNAL (off the engine queue)
- Sample latency is bounded well below a concurrent training stream's
duration (the parallelism test)
- Many concurrent samples across two adapters all resolve
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
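For orientation, a minimal sketch of what the EngineStateDB handoff row described above could look like, assuming SQLModel (which later commits confirm was used for this write); the fixed id=1 singleton convention and the exact field defaults are assumptions:

```python
from datetime import datetime, timezone
from typing import Optional

from sqlmodel import Field, SQLModel


class EngineStateDB(SQLModel, table=True):
    """Single-row engine -> API handoff table (sketch).

    The engine subprocess upserts the row after building vLLM; the API
    process reads it lazily when the first forwarded sample arrives.
    """

    # Assumption: a fixed primary key makes this a singleton row.
    id: int = Field(default=1, primary_key=True)
    inference_proxy_url: Optional[str] = None
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
```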
- Forwarding client was sending logprobs=true to /v1/completions; vLLM (and the upstream vllm-router) require an int. Use 1 to get the chosen token's logprob, which is what Tinker's SampleOutput needs.
- The SDK's user-facing future doesn't expose the server-side request_id, so test_sample_uses_external_path now snapshots max(request_id) before the sample, then queries for any new EXTERNAL row afterward — robust to the module-scoped server fixture sharing state across tests.

All four async sample routing tests pass (sample latency 1.3s vs concurrent 10s training, ratio 0.13). All six existing multi-LoRA Megatron tests still pass under the new forwarding path.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Code Review
This pull request implements an asynchronous sample routing path, enabling the API process to forward sample requests directly to a backend-managed vLLM instance, thereby bypassing the engine's serial scheduling loop. Key changes include the introduction of an EngineStateDB singleton for state handoff, the BackendForwardingInferenceClient for request handling, and a new suite of end-to-end tests. Feedback points out that the vLLM payload is missing stop parameters, the response parsing logic needs to be more robust regarding finish reasons and logprobs, and the use of a persistent HTTP client is recommended to improve efficiency.
skyrl/tinker/extra/backend_forwarding_inference.py (121-137)
The payload for the vLLM /v1/completions request is missing the stop parameters (stop and stop_token_ids). These are present in the SamplingParams and should be forwarded to ensure generation terminates correctly according to the request configuration.
payload = {
    "model": model_name,
    "prompt": prompt_tokens,
    "n": sample_req.num_samples,
    "seed": sample_req.sampling_params.seed,
    "max_tokens": sample_req.sampling_params.max_tokens,
    "temperature": sample_req.sampling_params.temperature,
    "top_p": sample_req.sampling_params.top_p,
    "top_k": sample_req.sampling_params.top_k,
    # vLLM (and the upstream vllm-router) expects an integer for the
    # OpenAI-compatible /v1/completions endpoint — the number of top
    # tokens to return logprobs for. 1 gives the chosen token's
    # logprob, which is what the Tinker SampleOutput requires.
    "logprobs": 1,
    "stream": False,
    "return_token_ids": True,
}
if sample_req.sampling_params.stop_strings:
    payload["stop"] = sample_req.sampling_params.stop_strings
if sample_req.sampling_params.stop_tokens:
    payload["stop_token_ids"] = sample_req.sampling_params.stop_tokens

skyrl/tinker/extra/backend_forwarding_inference.py (151-162)
The parsing logic for the vLLM response should be more robust and consistent with the internal engine path. Specifically:
- Normalization of `finish_reason`: vLLM may return `stop_token` or other strings that should be mapped to `stop` to remain compatible with the SDK's expectations (which typically expects `stop` or `length`).
- Logprobs safeguard: if `logprobs` are missing or empty (e.g., due to an upstream issue), it's critical for RL workloads to have a fallback (e.g., zeros) to avoid downstream failures.
- Prompt logprobs: the `prompt_logprobs` field should return `None` rather than an empty list to match the behavior in `skyrl_train_backend.py`.
sequences = []
for choice in result["choices"]:
    tokens = choice.get("token_ids", [])
    lp = choice.get("logprobs") or {}
    logprobs = lp.get("token_logprobs") or []
    # Ensure logprobs exist (critical for RL)
    if not logprobs and tokens:
        logger.warning("No logprobs returned from vLLM - filling with zeros")
        logprobs = [0.0] * len(tokens)
    # Map vLLM finish reason to Tinker format
    finish_reason = choice.get("finish_reason")
    stop_reason = "stop" if finish_reason in ("stop", "stop_token") else "length"
    sequences.append(
        types.GeneratedSequence(
            tokens=tokens,
            logprobs=logprobs,
            stop_reason=stop_reason,
        )
    )
return types.SampleOutput(sequences=sequences, prompt_logprobs=None)

skyrl/tinker/extra/backend_forwarding_inference.py (139-142)
Creating a new httpx.AsyncClient for every request is inefficient as it prevents connection pooling and incurs the overhead of a new TCP/TLS handshake for every sample. Consider initializing a single persistent client in __init__ and reusing it across requests. If you do this, ensure the client is properly closed during the API server's shutdown phase in api.py lifespan.
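A minimal sketch of the suggested change, assuming httpx and illustrative attribute names (`_http`, `_post_completions`); the real client has more state (URL cache, retries) that is elided here:

```python
import httpx


class BackendForwardingInferenceClient:
    def __init__(self, max_connections: int | None = None) -> None:
        # One client per API process: connections are pooled, so each sample
        # reuses an existing TCP/TLS connection instead of opening a new one.
        self._http = httpx.AsyncClient(
            limits=httpx.Limits(max_connections=max_connections),
            timeout=httpx.Timeout(None),
        )

    async def _post_completions(self, base_url: str, payload: dict) -> dict:
        resp = await self._http.post(f"{base_url}/v1/completions", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def aclose(self) -> None:
        # Called from api.py's lifespan shutdown.
        await self._http.aclose()
```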
… persistent httpx
Three review nits on the forwarding client:
1. Stop params (high): the vLLM /v1/completions payload was missing
`stop` and `stop_token_ids`. SamplingParams carries both
stop_strings and stop_tokens — forward them.
2. Response parse robustness (high): mirror the normalization in
skyrl_train_backend._sample_with_remote_client —
- normalize vLLM's finish_reason ({"stop","stop_token"} -> "stop",
else "length"; Tinker's GeneratedSequence.stop_reason is a Literal
so this matters)
- if vLLM returns None/empty token_logprobs, zero-fill (RL needs
these for advantages)
- return prompt_logprobs=None instead of []
3. Persistent httpx.AsyncClient (medium): a fresh client per call kills
connection pooling and pays a new TCP/TLS handshake per sample.
Construct the client once in __init__ and reuse. Expose
`aclose()` and call it from api.py lifespan shutdown.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
…kpressure
Review feedback: the asyncio.Semaphore cap (default 64) was a real
bottleneck for high fan-out workloads (e.g. 4k concurrent requests from
multiple training runs), and it stacked unnecessarily on top of vLLM's
own scheduling (vllm-router -> per-server max_num_seqs).
Drop the semaphore. The remaining backpressure chain is:
httpx connection pool -> vllm-router -> vLLM max_num_seqs
The httpx pool is the only API-side cap. Rename the config knob to
`forwarding_inference_max_connections` to reflect what it actually
controls and bump the default from 64 to 1024. max_keepalive scales
with it (max_conn / 4, floor 32). Operators with very high fan-out can
raise this further.
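In code terms, the pool described above is roughly (a sketch; the helper name is illustrative):

```python
import httpx


def _build_pool_limits(max_connections: int = 1024) -> httpx.Limits:
    # Keepalive scales with the cap: max_connections // 4, floor 32.
    return httpx.Limits(
        max_connections=max_connections,
        max_keepalive_connections=max(32, max_connections // 4),
    )
```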
Also fix a regression introduced in the last review pass: SamplingParams
on the API-level model uses a polymorphic .stop field (list[str] |
list[int]), not the normalized stop_strings/stop_tokens. Dispatch on
the element type to set "stop" vs "stop_token_ids" correctly.
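The dispatch described in the paragraph above, roughly (a sketch; `payload` and `sampling_params` stand in for the request objects in the surrounding code):

```python
# SamplingParams.stop is polymorphic: list[str] (stop strings) or
# list[int] (stop token ids). Dispatch on the element type.
stop = sampling_params.stop
if stop:
    if isinstance(stop[0], str):
        payload["stop"] = stop
    else:
        payload["stop_token_ids"] = stop
```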
Verified with the multi_tenancy.mdx RL quickstart:
Concurrent A+B (3 iters each, batch=16, group=4, max_tokens=128):
36s wall-clock; per-iter ~9-10s
Sequential B (warm): 25.5s for 3 iters; per-iter ~8.5s
Speedup ~1.4-1.9x with per-iter overhead ~1-2s under contention.
All 192 sample requests forwarded; 0 served by the engine.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Per discussion: raising max_connections is essentially free (the only cost is file descriptors). Make the default UX "hammer vllm-router and let it queue" by setting forwarding_inference_max_connections to Optional[int] with default None = unlimited.

Operators get the simplest path: just raise `ulimit -n` to match peak concurrent samples across all tenants, and let vllm-router + each vLLM server's max_num_seqs be the only queues. Multi-tenant SaaS-style deployments can still set an int to bound per-API-process FD usage. httpx maps max_connections=None to INT64_MAX internally (unlimited).

Wire the argparse path so `--forwarding-inference-max-connections None` from the CLI parses to Python None (not the literal string).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
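One way to make the CLI accept the literal "None" (a sketch; only the flag name comes from the commit, the converter name `optional_int` is an assumption):

```python
import argparse


def optional_int(value: str) -> int | None:
    # "--forwarding-inference-max-connections None" -> Python None (unlimited);
    # anything else must parse as an int.
    if value.strip().lower() == "none":
        return None
    return int(value)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--forwarding-inference-max-connections",
    type=optional_int,
    default=None,
    help="Max concurrent forwarding connections; None means unlimited.",
)
```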
High-priority correctness:
- Double-checked lock in _resolve_proxy_url: the hot path returns the cached URL without acquiring _cache_lock, so concurrent samples no longer serialize at this point under the unlimited-pool default.
- Test _server_is_up: add an explicit `import urllib.request`; the previous `import urllib.error` alone only worked when some other module (tinker SDK / transformers) had already pulled urllib.request into the urllib namespace.
- delete_model: reorder so all local state is reset before calling _publish_engine_state, and make _publish_engine_state best-effort (log + return on any DB error). Prevents a partially-torn-down controller if SQLite writes fail during shutdown.

Schema cleanup:
- Drop EngineStateDB.is_colocated and inference_server_urls. Both were written by the backend but never read — lifespan checks colocation from the user's --backend-config dict, and the forwarding client only needs inference_proxy_url. The _publish_engine_state signature is simplified to (proxy_url,) only.

Test docstrings:
- Add test_concurrent_samples_per_adapter to the module Coverage list.
- Rewrite that test's docstring (it referenced the removed semaphore; it now describes the connection pool + model_id routing the test actually exercises).
- Drop the is_colocated/server_urls assertions from test_engine_state_published.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
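The double-checked lock on the cached proxy URL looks roughly like this (a sketch; attribute names follow the commit message, the EngineStateDB lookup is stubbed out):

```python
import asyncio


class _ProxyUrlResolver:
    def __init__(self) -> None:
        self._proxy_url: str | None = None
        self._cache_lock = asyncio.Lock()

    async def _resolve_proxy_url(self) -> str:
        # Hot path: no lock once the URL is cached, so concurrent samples
        # do not serialize here.
        if self._proxy_url is not None:
            return self._proxy_url
        async with self._cache_lock:
            # Re-check under the lock: another task may have filled it.
            if self._proxy_url is None:
                self._proxy_url = await self._read_from_engine_state_db()
            return self._proxy_url

    async def _read_from_engine_state_db(self) -> str:
        raise NotImplementedError  # EngineStateDB query elided in this sketch
```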
/gemini review
Code Review
This pull request introduces an asynchronous sample routing path designed to reduce inference latency by bypassing the engine's serial scheduling loop. Key additions include the BackendForwardingInferenceClient, which forwards requests directly to a backend-managed vLLM instance, and the EngineStateDB model to facilitate the handoff of proxy URLs between the engine and the API. The changes also include lifecycle management for publishing engine state and comprehensive end-to-end tests. Review feedback focuses on enhancing the robustness of the new forwarding client by adding null checks for database retrievals, broadening exception handling for network requests, and implementing safer JSON parsing for server responses.
…eForwardingClient
The previous "BackendForwardingInferenceClient" name was generic —
"backend" could mean anything. Rename to make it explicit that this
client forwards sample requests to the SkyRL-Train-managed inference
URL (parallel to ExternalInferenceClient which forwards to a
user-supplied external vLLM URL).
Class: BackendForwardingInferenceClient
-> SkyRLTrainInferenceForwardingClient
Module: skyrl.tinker.extra.backend_forwarding_inference
-> skyrl.tinker.extra.skyrl_train_inference_forwarding
Updated references in api.py, config.py, extra/__init__.py, and the
test file's docstring/comments. Lifespan log message updated.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Three review nits on SkyRLTrainInferenceForwardingClient:

1. session.get(FutureDB, request_id) -> None when the future was deleted between the asample handler scheduling us and our arrival here (manual cleanup, stale-session GC). Previously we'd AttributeError on `.result_data = ...`. Now: null-check, log a warning, return cleanly. The retrieve_future poller times out on the caller's side as the natural failure mode.
2. response.json() raises ValueError when the upstream returns non-JSON (e.g. vllm-router 502 HTML even with a 2xx status from an intermediate proxy). Wrap in try/except and surface the raw body + content-type so the failure is diagnosable from FutureDB's error_data alone.
3. Broaden the retry catch from (ConnectError, ReadError, RemoteProtocolError) to httpx.RequestError. Now also covers TimeoutException and PoolTimeout. Verified all five specific classes are RequestError subclasses; HTTP 4xx/5xx (RuntimeError) is intentionally NOT retried — real upstream errors should surface.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
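A sketch of nits 2 and 3 combined, assuming httpx; the single-retry-with-URL-refresh shape mirrors the earlier "refresh the cached URL once on connection error" behavior, and the helper names are illustrative:

```python
import httpx


async def _post_and_parse(client: httpx.AsyncClient, url: str, payload: dict) -> dict:
    resp = await client.post(f"{url}/v1/completions", json=payload)
    resp.raise_for_status()  # HTTP 4xx/5xx is a real upstream error: never retried
    try:
        return resp.json()
    except ValueError as exc:
        # Non-JSON body (e.g. a proxy's 502 HTML behind a 2xx): keep enough
        # context so the failure is diagnosable from the stored error alone.
        raise RuntimeError(
            f"Non-JSON response from {url} "
            f"(content-type={resp.headers.get('content-type')!r}): {resp.text[:500]}"
        ) from exc


async def forward_sample(client, resolve_url, payload):
    try:
        return await _post_and_parse(client, await resolve_url(), payload)
    except httpx.RequestError:
        # Transport-level failure (connect/read/timeout/pool): refresh the
        # cached proxy URL once and retry; other exceptions propagate.
        return await _post_and_parse(client, await resolve_url(refresh=True), payload)
```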
        self._cfg.generator.inference_engine,
    )

    def set_engine_database_url(self, database_url: str) -> None:
Let's keep the database out of the backends, and even better would be to not write the proxy_url into the database (since it is ephemeral), how about we just make it an (optional) attribute of the backend?
If it really needs to be in the database, I think the better way would be to re-use the DB connection in engine.py.
It is better to keep the database out of the backend to avoid propagating it everywhere in the future.
hmm good point. i think because the proxy URL doesn't exist until _create_new_inference_client is run in the backend it's tricky to configure as an attribute up front
but I can move it up to reuse the engine.py db connection so it's cleaner
Address review: SkyRLTrainBackend should not know about the database.
Backend changes:
- Drop _engine_database_url, set_engine_database_url, and the entire
inline DB write body (sqlmodel imports, create_engine/dispose, table
schema create_all).
- Add set_inference_state_publisher(fn: Callable[[str | None], None])
and a thin _publish_inference_state helper that invokes the callback
with the current proxy URL. Failures from the callback are logged but
do not propagate — local state reset on delete_model must complete.
Engine changes:
- Wire the publisher at construction time:
backend.set_inference_state_publisher(self._write_inference_state_to_db)
- New _write_inference_state_to_db reuses self.db_engine (already
created at line 241) — no per-write create_engine churn.
Result:
- Backend has zero DB symbols (verified: no sqlmodel/EngineStateDB
references in the module).
- Engine owns the connection; one Session per write.
- The contract surface between the two is a single Optional callable.
- Easy to swap DB for a different sink (Unix socket, HTTP push) by
replacing the engine's _write_inference_state_to_db.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
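The resulting seam between backend and engine, roughly (a sketch; names follow the commit message, the SQLModel write body is omitted):

```python
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


class SkyRLTrainBackend:
    """Backend side: no DB symbols, just an optional publish callback."""

    def __init__(self) -> None:
        self._publish_fn: Optional[Callable[[Optional[str]], None]] = None
        self._proxy_url: Optional[str] = None

    def set_inference_state_publisher(self, fn: Callable[[Optional[str]], None]) -> None:
        self._publish_fn = fn

    def _publish_inference_state(self) -> None:
        if self._publish_fn is None:
            return
        try:
            self._publish_fn(self._proxy_url)
        except Exception:
            # Best-effort: publishing must not block local teardown.
            logger.warning("Failed to publish inference state", exc_info=True)


# Engine side (engine.py), at construction time:
#   backend.set_inference_state_publisher(self._write_inference_state_to_db)
# where _write_inference_state_to_db opens one Session on self.db_engine per write.
```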
Description
When SkyRL-Train runs non-colocated (`colocate_all=false`), vLLM is always-on but its sample capacity is wasted: the Tinker engine subprocess serializes `sample` behind `forward_backward`/`optim_step` in its 100ms tick loop, so a sample submitted at the start of a 30s training step waits 30s. For multi-tenant RL the queueing compounds — every `sample` from any tenant queues behind every other tenant's training step.

This PR hoists sample requests off the engine queue and into the API process's asyncio loop, where they're forwarded directly to the engine-managed vLLM. The colocated and JAX paths are unchanged.
How it works
We reuse the existing `RequestType.EXTERNAL` plumbing: the cross-process machinery was already wired for fully external vLLM URLs; the gap was that nothing was pointing it at the engine-managed vLLM. This PR closes that gap.
What's in the diff
8 files, +796/-3 LOC.
- New `EngineStateDB` (skyrl/tinker/db_models.py) — singleton row carrying engine→API handoff state (currently just `inference_proxy_url` + `updated_at`). The backend writes it after building vLLM; the forwarding client reads it lazily on first sample.
- `SkyRLTrainBackend._publish_engine_state` (skyrl/backends/skyrl_train_backend.py) — called inside `_create_new_inference_client` after `build_new_inference_client` returns, and on `delete_model` teardown to clear the row. Best-effort: a DB write failure logs and returns rather than corrupting controller state. Wired through a new `set_engine_database_url(...)` setter that the engine calls right after constructing the backend.
- New `SkyRLTrainInferenceForwardingClient` (skyrl/tinker/extra/skyrl_train_inference_forwarding.py) — pair to `ExternalInferenceClient`, with the same EXTERNAL future-write contract but resolves the target URL from `EngineStateDB` instead of from a user-supplied `EngineConfig.external_inference_url`. Uses `model=<model_id>` (the name `save_weights_for_sampler` registered with vLLM via `load_lora_adapter`). Key design choices:
  - No API-side semaphore: backpressure is vLLM's own `max_num_seqs`. Adding a Python semaphore on top would just serialize work above what vLLM already manages.
  - Persistent `httpx.AsyncClient` with `forwarding_inference_max_connections` defaulting to `None` (unlimited). The only cost of "unlimited" is file descriptors, so operators just raise `ulimit -n` to match peak concurrent samples. Set an int to enforce a per-process cap. Closed via `aclose()` from the API lifespan.
  - `.stop` dispatch matches `api.py SamplingParams.to_types()` (list[str] → `stop`, list[int] → `stop_token_ids`).
  - `finish_reason` normalized to `Literal["stop", "length"]`; missing `token_logprobs` zero-filled; non-JSON responses (e.g. proxy 502 HTML) surfaced with content-type + body excerpt for diagnosis.
  - Retries on transport errors (the `httpx.RequestError` umbrella — `ConnectError`, `ReadError`, `RemoteProtocolError`, `TimeoutException`, `PoolTimeout`). HTTP 4xx/5xx from vLLM is NOT retried — it's a real upstream signal.
- `api.py` lifespan — installs the forwarding client when `backend in ("megatron", "fsdp")` AND `trainer.placement.colocate_all=False`. Colocated runs and JAX keep the synchronous engine flow. Calls `aclose()` on shutdown (see the sketch after this list).
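A minimal sketch of that lifespan decision; `backend`, `backend_config`, `args`, and the `app.state` slot are assumed names standing in for the real api.py wiring:

```python
from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app):
    # Assumption: backend, backend_config, and args are already parsed here.
    colocate_all = backend_config.get("trainer.placement.colocate_all", True)
    if backend in ("megatron", "fsdp") and not colocate_all:
        client = SkyRLTrainInferenceForwardingClient(
            max_connections=args.forwarding_inference_max_connections,  # None = unlimited
        )
        app.state.forwarding_client = client
        try:
            yield  # samples route through the forwarding client
        finally:
            await client.aclose()
    else:
        # Colocated runs and JAX keep the synchronous engine flow.
        yield
```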
Synchronization invariants preserved:
- `validate_checkpoint(...)` runs at api.py:1037 before the future is created.
- `WorkerDispatch.save_weights_for_sampler` brackets the broadcast with `pause_generation`/`resume_generation`; vLLM's KEEP-mode pause freezes in-flight requests in its scheduler.
- The forwarding path is only active when `colocate_all=false && backend ∈ (megatron, fsdp)`.

Test plan
- `test_engine_state_published` — after `save_weights_for_sampler`, `EngineStateDB.inference_proxy_url` is populated with the engine-managed vLLM proxy URL.
- `test_sample_uses_external_path` — an issued sample creates a `FutureDB` row with `request_type=EXTERNAL` (off the engine queue).
- `test_sample_concurrent_with_training_is_fast` — the central parallelism test. While 24 `forward_backward` + `optim_step` calls stream in a background thread, a concurrent sample resolves in 1.3s vs the 10s training stream (ratio 0.13×). Without async routing the sample would queue behind the entire training stream and the latencies would be comparable.
- `test_concurrent_samples_per_adapter` — 8 concurrent samples across two adapters all resolve via the forwarding client's connection pool and per-adapter `model_id` routing on vLLM.
- Existing `test_multi_lora_megatron.py` tests continue to pass under the new forwarding path (backwards compatible with the multi-LoRA RL workload).
- `tests/tinker/test_api.py` single-tenant paths pass.

End-to-end multi-tenant RL bench following docs/tinker/multi_tenancy.mdx#quickstart---two-rl-clients (Qwen3-0.6B, GSM8K, 4 policy GPUs + 1 vLLM, 3 iters × batch 16 × group 4 × max 128 tokens per client): speedup ~1.93×, close to the theoretical 2× ceiling. Per-iter cost grows ~1-2s under contention (vLLM multiplexes both adapters' samples + engine batches both fwd_bwds), but total wall-clock halves because both tenants progress simultaneously instead of waiting their turn.
Server-side instrumentation: 240 sample requests forwarded via the API-process client; 0 served by the engine (single startup-time engine sample is pre-bench).
`process_batch_requests(forward_backward, n=2)` lines in the server log confirm the engine batches concurrent A+B fwd_bwd calls together.

Operator config
To enable, set in `--backend-config`:

{
  "trainer.placement.colocate_all": false,
  "trainer.policy.megatron_config.lora_config.merge_lora": false,
  "trainer.policy.model.lora.max_loras": <max concurrent adapters in a batch>,
  "trainer.policy.model.lora.max_cpu_loras": <total adapter capacity>
}

For very high fan-out workloads, raise the host's `ulimit -n` to match peak concurrent samples (default httpx pool is unbounded). Operators who want a per-API-process FD cap can set `--forwarding-inference-max-connections <N>`.

See docs/content/docs/tinker/multi_tenancy.mdx (added in the multi-lora-rl PR) for the full operator contract.
Out of scope
- Adapter eviction: if `max_cpu_loras` is too low and vLLM evicts an adapter mid-run, the next sample 404s. Re-loading from disk on demand is a follow-up; for now the operator sizes `max_cpu_loras` ≥ expected concurrent adapters.
- Retries beyond the transport-level `httpx.RequestError` catch. Application-level retry is the SDK's job (already implemented in tinker/retry_handler).
- Async training: `forward_backward`/`optim_step` still serialize through the engine's main loop.