2 changes: 1 addition & 1 deletion ci/gpu_ci_run_tinker_skyrl_train_backend.sh
@@ -8,4 +8,4 @@ export CI=true
# v1 single-tenant sample guard, per-adapter Adam step isolation, and
# delete-then-train continuity.
uv run --directory . --isolated --extra tinker --extra megatron --with pytest --with pytest-timeout \
-    pytest -s --timeout=600 tests/tinker/skyrl_train/test_multi_lora_megatron.py
+    pytest -s --timeout=600 tests/tinker/skyrl_train/
2 changes: 1 addition & 1 deletion docs/content/docs/tinker/multi_tenancy.mdx
@@ -123,7 +123,7 @@ uv run --extra tinker --extra megatron -m skyrl.tinker.api \
"generator.inference_engine.tensor_parallel_size": 1,
"trainer.policy.model.lora.max_loras": 2,
"trainer.policy.model.lora.max_cpu_loras": 2,
"trainer.logprobs_chunk_size": null,
"trainer.logprobs_chunk_size": null
}'
```

39 changes: 39 additions & 0 deletions skyrl/backends/skyrl_train_backend.py
@@ -5,6 +5,7 @@
import os
import tarfile
import tempfile
from typing import Callable

import ray
import torch
@@ -138,6 +139,13 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
self._server_groups: list = []
self._inference_router = None

# Optional hook invoked on inference-engine state changes (after
# _create_new_inference_client, on delete_model teardown). The host
# (e.g. the Tinker engine subprocess) wires the persistence side via
# set_inference_state_publisher. None when running outside a host
# that needs to be notified (unit tests, non-Tinker uses).
self._inference_state_publisher: Callable[[str | None], None] | None = None

def has_model(self, model_id: str) -> bool:
return model_id in self._model_ids_to_role

@@ -325,6 +333,29 @@ def _create_legacy_inference_client(self):
self._cfg.generator.inference_engine,
)

def set_inference_state_publisher(self, publisher: Callable[[str | None], None]) -> None:
"""Wire a callback invoked when the inference proxy URL changes.

Called by the host (e.g. the Tinker engine subprocess) after backend
construction. The callback receives the current proxy URL after a
new inference engine is brought up, or ``None`` on teardown. The
backend has no opinion on what the callback does — typical use is
to persist the URL somewhere the API process can read.
"""
self._inference_state_publisher = publisher

def _publish_inference_state(self, proxy_url: str | None) -> None:
"""Invoke the publisher if set; best-effort (failure must not raise).

Callers rely on local state being reset regardless of publish outcome.
"""
if self._inference_state_publisher is None:
return
try:
self._inference_state_publisher(proxy_url)
except Exception as e:
logger.warning(f"Inference-state publisher failed (proxy_url={proxy_url!r}): {e}")

def _create_new_inference_client(self):
"""Create new HTTP-based inference client."""
from skyrl.backends.skyrl_train.inference_servers.setup import (
@@ -341,6 +372,10 @@ def _create_new_inference_client(self):
self._server_groups = server_setup.server_groups
self._inference_engine_client = client

# Publish inference endpoint so the API can forward samples directly
# (only meaningful in non-colocated mode; the API gates on this).
self._publish_inference_state(server_setup.proxy_url)

def _ensure_inference_engines(self):
"""Lazily create inference engines and init weight sync on first sampling-related call."""
if self._inference_engines_initialized:
@@ -496,6 +531,10 @@ def delete_model(self, model_id: str) -> None:
self._renderer = None
self._colocate_pg = None
self._base_lora_signature = None
# Local state is fully reset above. Notify the host last so a
# publisher failure can't leave the controller half-torn-down.
# Next _create_new_inference_client repopulates.
self._publish_inference_state(None)
logger.info(f"Successfully deleted model {model_id}")

def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch, role: str) -> TrainingInputBatch:
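
The publisher contract above is small enough to exercise standalone. A minimal sketch (not part of this PR; `_PublisherDemoBackend` is a hypothetical stand-in for `SkyRLTrainBackend`) of the wire-then-publish flow, including the best-effort failure handling:

```python
from typing import Callable


class _PublisherDemoBackend:
    """Hypothetical stand-in mirroring SkyRLTrainBackend's publisher hook."""

    def __init__(self) -> None:
        self._inference_state_publisher: Callable[[str | None], None] | None = None

    def set_inference_state_publisher(self, publisher: Callable[[str | None], None]) -> None:
        self._inference_state_publisher = publisher

    def _publish_inference_state(self, proxy_url: str | None) -> None:
        # Best-effort: a publisher failure must never abort setup or teardown.
        if self._inference_state_publisher is None:
            return
        try:
            self._inference_state_publisher(proxy_url)
        except Exception:
            pass


published: list[str | None] = []
backend = _PublisherDemoBackend()
backend.set_inference_state_publisher(published.append)
backend._publish_inference_state("http://127.0.0.1:8080")  # after engine bring-up
backend._publish_inference_state(None)                     # after delete_model teardown
assert published == ["http://127.0.0.1:8080", None]
```
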
42 changes: 40 additions & 2 deletions skyrl/tinker/api.py
@@ -40,7 +40,10 @@
enable_sqlite_wal,
get_async_database_url,
)
-from skyrl.tinker.extra import ExternalInferenceClient
+from skyrl.tinker.extra import (
+    ExternalInferenceClient,
+    SkyRLTrainInferenceForwardingClient,
+)
from skyrl.utils.log import get_uvicorn_log_config, logger
from skyrl.utils.storage import download_file

@@ -111,10 +114,36 @@ async def lifespan(app: FastAPI):
async with app.state.db_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

-# Setup external inference client if configured
+# Setup external inference client if configured.
#
# Three cases:
# 1. external_inference_url set: forward sample requests to a fully
# external vLLM (existing behavior).
# 2. backend in (megatron, fsdp) and colocate_all=False: install
# SkyRLTrainInferenceForwardingClient so sample requests go directly
# to the SkyRL-Train-managed vLLM, bypassing the engine's serial loop.
# 3. otherwise (JAX, colocated SkyRL-Train, etc.): route everything
# through the engine subprocess.
#
# The colocated path stays on the engine because vLLM is asleep during
# training and only the engine's synchronous sample path knows how to
# wake it (save_weights_for_sampler → broadcast → sample).
backend_name = app.state.engine_config.backend
backend_cfg = app.state.engine_config.backend_config or {}
# SkyRL-Train default is colocate_all=True; only opt into forwarding
# when the operator explicitly sets it to False.
is_colocated = bool(backend_cfg.get("trainer.placement.colocate_all", True))
if app.state.engine_config.external_inference_url:
app.state.external_inference_client = ExternalInferenceClient(app.state.engine_config, app.state.db_engine)
logger.info(f"External engine configured: {app.state.engine_config.external_inference_url}")
elif backend_name in ("megatron", "fsdp") and not is_colocated:
app.state.external_inference_client = SkyRLTrainInferenceForwardingClient(
app.state.engine_config, app.state.db_engine
)
logger.info(
"SkyRL-Train inference forwarding client enabled for non-colocated backend=%s",
backend_name,
)
else:
app.state.external_inference_client = None
logger.info("Using internal engine for inference")
@@ -158,6 +187,15 @@ def force_exit():
shutting_down = True
monitor_task.cancel()

# Close the forwarding client's persistent httpx connection pool if we
# installed one. Cheap no-op when external_inference_client doesn't own
# an httpx client (ExternalInferenceClient creates one per call).
inference_client = getattr(app.state, "external_inference_client", None)
aclose = getattr(inference_client, "aclose", None)
if aclose is not None:
with suppress(Exception):
await aclose()

logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})")
with suppress(ProcessLookupError):
background_engine.terminate()
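
The shutdown hook above only probes for `aclose` because `ExternalInferenceClient` opens a connection per call while the forwarding client owns a persistent pool. A hedged sketch of the shape that probe expects; the real `SkyRLTrainInferenceForwardingClient` is not shown in this diff, and the class name and endpoint path below are illustrative:

```python
import httpx


class _PooledForwardingClientSketch:
    """Illustrative only: a sample-forwarding client that owns an httpx pool."""

    def __init__(self, max_connections: int | None = None) -> None:
        # None means an unbounded pool; an int caps API-side connections.
        self._http = httpx.AsyncClient(limits=httpx.Limits(max_connections=max_connections))

    async def forward_sample(self, proxy_url: str, payload: dict) -> httpx.Response:
        # Forward one sample request to the engine-managed vLLM proxy
        # (the endpoint path is an assumption, not taken from this PR).
        return await self._http.post(f"{proxy_url}/v1/completions", json=payload)

    async def aclose(self) -> None:
        # This is what the lifespan shutdown awaits via getattr(..., "aclose").
        await self._http.aclose()
```
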
15 changes: 15 additions & 0 deletions skyrl/tinker/config.py
@@ -43,6 +43,21 @@ class EngineConfig(BaseModel):
default=Path("/tmp/lora_models"),
description="Directory where LoRA models will be extracted for external inference engines",
)
forwarding_inference_max_connections: int | None = Field(
default=None,
description=(
"Optional cap on the httpx connection pool used by "
"SkyRLTrainInferenceForwardingClient to forward sample requests to "
"the engine-managed vLLM. The natural backpressure chain is "
"httpx pool -> vllm-router -> vLLM's max_num_seqs; this knob "
"only sets the API-side connection ceiling. Default `None` is "
"unlimited — vllm-router/vLLM are the only queues — which is "
"usually what you want. Raise your host's `ulimit -n` for very "
"high fan-out (the only hard cost of unlimited connections is "
"file descriptors). Set an int to enforce a per-API-process cap."
),
json_schema_extra={"argparse_type": lambda v: None if v == "None" else int(v)},
)
session_cleanup_interval_sec: int = Field(
default=60,
description="How often to check for stale sessions (seconds). Set to -1 to disable cleanup.",
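
The `argparse_type` hook above is what lets the string "None" round-trip from a CLI flag into Python `None`. Its behavior in isolation (same lambda as the Field definition):

```python
# "None" on the CLI becomes None (unlimited); anything else must parse as an int.
parse = lambda v: None if v == "None" else int(v)

assert parse("None") is None  # unlimited: vllm-router/vLLM are the only queues
assert parse("128") == 128    # explicit per-API-process connection cap
```
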
19 changes: 19 additions & 0 deletions skyrl/tinker/db_models.py
@@ -134,3 +134,22 @@ class SamplingSessionDB(SQLModel, table=True):
base_model: str | None = None
model_path: str | None = None
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))


class EngineStateDB(SQLModel, table=True):
"""Engine→API handoff for the inference engine the backend stands up.

Singleton row (``singleton_id=1``). Written by the backend when a new
inference client is built (or torn down) and read by the API's
forwarding client to resolve the vLLM proxy URL.
"""

__tablename__ = "engine_state"

singleton_id: int = Field(default=1, primary_key=True)

    # Proxy URL of the engine-managed vLLM. None when no vLLM has been
    # stood up yet (no create_model yet, the FFT-only path, or the last
    # delete tore it down).
inference_proxy_url: str | None = None

updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
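
A hedged sketch of the read side: how the API's forwarding client could resolve the URL. The actual reader is not part of this hunk, and a sync session is used here for brevity, whereas the API itself runs on the async engine:

```python
from sqlmodel import Session, create_engine

from skyrl.tinker.db_models import EngineStateDB

# Must point at the same database the Tinker engine subprocess writes to;
# the URL here is illustrative.
engine = create_engine("sqlite:///tinker.db")


def resolve_inference_proxy_url() -> str | None:
    with Session(engine) as session:
        row = session.get(EngineStateDB, 1)  # the singleton row
        return row.inference_proxy_url if row is not None else None
```
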
22 changes: 22 additions & 0 deletions skyrl/tinker/engine.py
@@ -17,6 +17,7 @@
from skyrl.tinker.db_models import (
CheckpointDB,
CheckpointStatus,
EngineStateDB,
FutureDB,
ModelDB,
RequestStatus,
@@ -246,6 +247,13 @@ def __init__(
backend_config = backend_config_class(**config.backend_config)
self.backend = backend_class(config.base_model, backend_config)

# Backends that support async sample routing notify us when their
# inference endpoint changes; we persist it to EngineStateDB so the
# API process can forward sample requests directly. Backends stay
# DB-free; only the engine owns the connection.
if hasattr(self.backend, "set_inference_state_publisher"):
self.backend.set_inference_state_publisher(self._write_inference_state_to_db)

# Track last cleanup time for periodic stale session cleanup
self._last_cleanup_time: float = time.time()

@@ -256,6 +264,20 @@ def metrics(self) -> types.EngineMetrics:
"""Pass-through to backend metrics for backwards compatibility."""
return self.backend.metrics

def _write_inference_state_to_db(self, proxy_url: str | None) -> None:
"""Upsert the singleton EngineStateDB row.

Wired into the backend via set_inference_state_publisher so the API
process can resolve the engine-managed vLLM URL on the async sample
routing path. ``proxy_url=None`` clears the row (post-teardown).
"""
with Session(self.db_engine) as session:
row = session.get(EngineStateDB, 1) or EngineStateDB(singleton_id=1)
row.inference_proxy_url = proxy_url
row.updated_at = datetime.now(timezone.utc)
session.add(row)
session.commit()

@contextmanager
def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType):
"""Context manager to handle checkpoint DB status updates.
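
The upsert above is idempotent on the singleton row. A quick standalone harness (in-memory SQLite; nothing assumed beyond the models in this PR) showing that repeated publishes never grow the table:

```python
from datetime import datetime, timezone

from sqlmodel import Session, SQLModel, create_engine, select

from skyrl.tinker.db_models import EngineStateDB

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)


def write_state(proxy_url: str | None) -> None:
    # Mirrors _write_inference_state_to_db: get-or-create, then overwrite.
    with Session(engine) as session:
        row = session.get(EngineStateDB, 1) or EngineStateDB(singleton_id=1)
        row.inference_proxy_url = proxy_url
        row.updated_at = datetime.now(timezone.utc)
        session.add(row)
        session.commit()


write_state("http://127.0.0.1:8080")
write_state(None)  # teardown clears the URL but keeps the row

with Session(engine) as session:
    rows = session.exec(select(EngineStateDB)).all()
assert len(rows) == 1 and rows[0].inference_proxy_url is None
```
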
5 changes: 4 additions & 1 deletion skyrl/tinker/extra/__init__.py
@@ -1,3 +1,6 @@
from skyrl.tinker.extra.external_inference import ExternalInferenceClient
from skyrl.tinker.extra.skyrl_train_inference_forwarding import (
SkyRLTrainInferenceForwardingClient,
)

__all__ = ["ExternalInferenceClient"]
__all__ = ["ExternalInferenceClient", "SkyRLTrainInferenceForwardingClient"]