Skip to content

Commit f9cf81b

Browse files
authored
[feat][inf] Multi-LoRA serving for RemoteInferenceClient (NovaSky-AI#1579)
## Summary Adds dynamic multi-LoRA serving to the new inference path: `RemoteInferenceClient` can now load/unload arbitrary LoRA adapters at runtime, route every data-plane call to a specific adapter (or the base model) via an explicit `model` parameter, and serve from multiple adapters concurrently. ## Key design decision: `model` is always explicit Every data-plane method requires the caller to identify the target model — there is no implicit fallback to a client-side default. | Method | How `model` is supplied | |---|---| | `generate(input_batch, model)` | Required keyword argument | | `sample(payload)`, `chat_completion(payload)`, `completion(payload)`, `render_chat_completion(payload)` | Required field in the request body; missing/empty raises `ValueError` | Callers know (or can resolve) which adapter they are addressing — there is no client-side guessing about which LoRA is "active". `RemoteInferenceClient.model_name` is now *only* the base model the vLLM server was started with; it is consumed internally by `tokenize`/`detokenize` (LoRA-agnostic) and is no longer auto-injected into data-plane requests. ## What's new - **Control plane** on `RemoteInferenceClient`: - `load_lora_adapter(name, path, load_inplace=False)` — fans out to all backend servers via `POST /v1/load_lora_adapter`. - `unload_lora_adapter(name)` — symmetric `POST /v1/unload_lora_adapter`. - **Config knobs** plumbed to vLLM: - `trainer.policy.model.lora.max_loras` (concurrent adapters per batch) - `trainer.policy.model.lora.max_cpu_loras` (CPU LRU cache size) - **`SKYRL_LORA_ADAPTER_NAME`** promoted to a public module-level constant; FSDP and Megatron (unmerged) workers register their trained adapter under this name. ## Single source of truth for "what does the policy resolve to?" Added `resolve_policy_model_name(cfg)` in `skyrl/backends/skyrl_train/inference_servers/utils.py`. It returns: - `SKYRL_LORA_ADAPTER_NAME` when the worker registers a LoRA adapter on the inference engine (FSDP + LoRA, or Megatron + LoRA with `merge_lora=False`). - `cfg.trainer.policy.model.path` otherwise (including Megatron + LoRA with `merge_lora=True`, where the engine receives merged base weights). This is called **once at wiring time** and threaded through: - `skyrl_train_backend._sample_with_remote_client` - `SkyRLGymGenerator` / `SkyRLVLMGymGenerator` (new required `policy_model_name: str` constructor arg) - `main_base.get_generator` - `SkyRLBackend` in `skyrl-agent` The `is_single_lora` ternary that previously decided `client.model_name` at construction is gone from both `skyrl_train_backend` and `main_base`. ## Breaking change: legacy inference path The legacy `InferenceEngineClient` (Ray-wrapped) path is **not updated** for multi-LoRA. We assume `_SKYRL_USE_NEW_INFERENCE=1` (the current default) is in use everywhere. The legacy `LoraLoadRequest` smuggling path still works for the single-LoRA legacy flow, but anything calling the legacy client's `generate`/`sample`/etc. without `model=` plumbing will need updating before that path can serve LoRAs. ## Tests - **Mock tests** (`tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py`): - `TestLoRAControlPlane` — fan-out, in-place reload, conflict on `load_inplace=False`, fan-out unload, 404 on unknown unload. - `TestExplicitModelRequired` — every body-style method raises `ValueError` when `model` is missing; `generate` raises `TypeError` without the kwarg. - All existing + new mock tests pass. - **Real-GPU multi-LoRA tests** (`tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_multi_lora_serving.py`): - `test_multi_lora_interleaved_generation` — load `lora-meow` + `lora-woof`, interleave per-call routing, verify each adapter's signature output. - `test_lora_inplace_reload_isolated` — in-place reload of one adapter must not perturb the other. --------- Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: hao-aaron <ahao@anyscale.com>
1 parent a502e48 commit f9cf81b

19 files changed

Lines changed: 883 additions & 86 deletions

File tree

skyrl-agent/skyrl_agent/integrations/skyrl_train/skyrl_train_backend.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
from typing import Any, List
2+
3+
from skyrl.backends.skyrl_train.inference_servers.utils import resolve_policy_model_name
4+
25
from ..base import AsyncInferBackend, GeneratorOutput, GeneratorInput
36

47

58
class SkyRLBackend(AsyncInferBackend):
69
def __init__(self, infer_engine, tokenizer: Any = None, cfg: Any = None):
710
self.client = infer_engine
11+
# Resolve the name the inference engine knows the policy by (base
12+
# model or registered LoRA adapter) once at construction. Threaded
13+
# into every ``client.generate`` call so the data plane never has
14+
# to guess the target adapter.
15+
self.policy_model_name = resolve_policy_model_name(cfg) if cfg is not None else self.client.model_name
816

917
async def async_generate_prompts(self, prompts: Any, sampling_params: Any, **kwargs) -> List[str]:
1018
input_obj = {
1119
"prompts": [prompts],
1220
"session_ids": [kwargs.get("request_id", None)],
1321
"sampling_params": sampling_params,
1422
}
15-
output = await self.client.generate(input_obj)
23+
output = await self.client.generate(input_obj, model=self.policy_model_name)
1624
return output["responses"][0], output["stop_reasons"][0]
1725

1826
async def async_generate_ids(self, input_ids: List[int], sampling_params: Any, **kwargs) -> List[str]:
@@ -21,7 +29,7 @@ async def async_generate_ids(self, input_ids: List[int], sampling_params: Any, *
2129
"session_ids": [kwargs.get("request_id", None)],
2230
"sampling_params": sampling_params,
2331
}
24-
output = await self.client.generate(input_obj)
32+
output = await self.client.generate(input_obj, model=self.policy_model_name)
2533
# todo(@csy) probably need to be finish_reason
2634
# https://github.com/vllm-project/vllm/blob/a0f8a7964694a6077689b242b5eca95de392d4bb/vllm/v1/engine/__init__.py#L22
2735
meta_info = {

skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py

Lines changed: 116 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@
7979

8080
_DATA_PLANE_RETRIES = 30
8181

82+
SKYRL_LORA_ADAPTER_NAME = "skyrl-lora"
83+
"""Default LoRA adapter name used for single-LoRA training inside SkyRL."""
84+
8285
_TINKER_SAMPLE_TO_VLLM_PARAM_MAP = {
8386
"temperature": "temperature",
8487
"max_tokens": "max_tokens",
@@ -192,14 +195,20 @@ class RemoteInferenceClient:
192195
reports the full DP world size per server, so we divide by num_deployments."""
193196

194197
model_name: str = "default"
195-
"""Model name for OpenAI-compatible API calls."""
198+
"""The base model identifier the inference server was started with.
199+
200+
Always the base model — never a LoRA adapter name. LoRA adapters are
201+
addressed by the names callers register them under via
202+
``load_lora_adapter(name, path)``, and per-call routing is done by
203+
passing that name as ``model`` on the data-plane methods.
204+
205+
Used internally only by ``tokenize``/``detokenize``, which are LoRA-
206+
agnostic but still require a ``model`` field per the OpenAI schema.
207+
"""
196208

197209
enable_return_routed_experts: bool = False
198210
"""Whether to return routed expert indices (R3 / rollout router replay)."""
199211

200-
active_lora_name: Optional[str] = None
201-
"""Name of the active LoRA adapter. If set, generation requests use this adapter instead of the base model."""
202-
203212
uses_lora_weight_sync: bool = False
204213
"""True when the trainer syncs LoRA adapters (rather than full/merged weights). When True,
205214
`sleep()` is forced to level=1: level=2 discards the base model from VRAM with no CPU backup,
@@ -317,9 +326,30 @@ async def _post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str
317326
# Data Plane
318327
# ---------------------------
319328

329+
def _resolve_model(self, model: Optional[str], method_name: str) -> str:
330+
"""Pick the target model name for a data-plane call.
331+
332+
- If ``model`` is non-empty, use it as-is.
333+
- Otherwise, when LoRA is in use (``uses_lora_weight_sync=True``) raise
334+
``ValueError`` — the caller must name the adapter explicitly because
335+
falling back to the base model would silently bypass LoRA.
336+
- Otherwise return ``self.model_name`` (the base model the server was
337+
started with).
338+
"""
339+
if model:
340+
return model
341+
if self.uses_lora_weight_sync:
342+
raise ValueError(
343+
f"RemoteInferenceClient.{method_name}: `model` is required when LoRA "
344+
f"is enabled (uses_lora_weight_sync=True). Pass the LoRA adapter name "
345+
f"explicitly so the request doesn't silently target the base model."
346+
)
347+
return self.model_name
348+
320349
async def generate(
321350
self,
322351
input_batch: InferenceEngineInput,
352+
model: Optional[str] = None,
323353
) -> InferenceEngineOutput:
324354
"""
325355
Generate completions via /v1/completions.
@@ -335,10 +365,14 @@ async def generate(
335365
336366
Args:
337367
input_batch: Contains prompt_token_ids, sampling_params, and optional session_ids.
368+
model: Optional model identifier — the base model name or a loaded
369+
LoRA adapter name. When omitted, defaults to ``self.model_name``
370+
if LoRA is not in use; raises ``ValueError`` if it is.
338371
339372
Returns:
340373
InferenceEngineOutput with responses, response_ids, and stop_reasons.
341374
"""
375+
model = self._resolve_model(model, "generate")
342376

343377
prompt_token_ids = input_batch.get("prompt_token_ids")
344378
if prompt_token_ids is None:
@@ -373,13 +407,15 @@ async def _throttled_generate(idx: int) -> Dict[str, Any]:
373407
sampling_params=sampling_params,
374408
session_id=session_ids[idx] if session_ids and idx < len(session_ids) else None,
375409
mm_features=mm_features[idx] if mm_features and idx < len(mm_features) else None,
410+
model=model,
376411
)
377412
async with gen_sem:
378413
return await self._generate_single(
379414
prompt_token_ids=prompt_token_ids[idx],
380415
sampling_params=sampling_params,
381416
session_id=session_ids[idx] if session_ids and idx < len(session_ids) else None,
382417
mm_features=mm_features[idx] if mm_features and idx < len(mm_features) else None,
418+
model=model,
383419
)
384420

385421
async def _throttled_detokenize(token_ids: List[int]) -> str:
@@ -407,6 +443,7 @@ async def _generate_single(
407443
prompt_token_ids: List[int],
408444
sampling_params: Dict[str, Any],
409445
session_id: Optional[Any],
446+
model: str,
410447
mm_features: Optional[MultiModalFeatures] = None,
411448
) -> Dict[str, Any]:
412449
"""
@@ -425,12 +462,9 @@ async def _generate_single(
425462
else f"{self.proxy_url}/inference/v1/generate"
426463
)
427464

428-
# Use LoRA adapter name if one is active, otherwise use base model name
429-
effective_model = self.active_lora_name if self.active_lora_name else self.model_name
430-
431465
payload: dict[str, Any] = {
432466
"sampling_params": sampling_params,
433-
"model": effective_model,
467+
"model": model,
434468
"token_ids": prompt_token_ids,
435469
}
436470
if mm_features:
@@ -466,6 +500,7 @@ async def _render_for_sample(
466500
self,
467501
prompt: Dict[str, Any],
468502
session_id: Optional[str],
503+
model: str,
469504
) -> Tuple[List[int], Optional[MultiModalFeatures]]:
470505
"""Build token_ids and optional multi-modal features from a Tinker prompt.
471506
@@ -497,10 +532,9 @@ async def _render_for_sample(
497532
url = c["location"]
498533
content_parts.append({"type": "image_url", "image_url": {"url": url}})
499534

500-
effective_model = self.active_lora_name if self.active_lora_name else self.model_name
501535
render_payload: Dict[str, Any] = {
502536
"json": {
503-
"model": effective_model,
537+
"model": model,
504538
"messages": [{"role": "user", "content": content_parts}],
505539
}
506540
}
@@ -548,7 +582,10 @@ async def _render_for_sample(
548582

549583
return final_token_ids, adjusted_features
550584

551-
async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:
585+
async def sample(
586+
self,
587+
request_payload: SampleRequestPayload,
588+
) -> SampleResponse:
552589
"""
553590
Sample completions via /inference/v1/generate (Tinker API).
554591
@@ -557,13 +594,16 @@ async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:
557594
558595
Args:
559596
request_payload: SampleRequestPayload with {"json": <request-body>}.
560-
Expected keys in json: prompt, num_samples, sampling_params, session_id,
561-
include_prompt_logprobs (bool), topk_prompt_logprobs (int).
597+
Expected keys in json: prompt, num_samples, sampling_params,
598+
session_id, include_prompt_logprobs (bool), topk_prompt_logprobs (int).
599+
``model`` is optional and resolved via ``_resolve_model``.
562600
563601
Returns:
564602
SampleResponse with type="sample", sequences list, prompt_logprobs, and topk_prompt_logprobs.
565603
"""
566604
session_id, body = _extract_session_id_and_body(request_payload)
605+
model = self._resolve_model(body.get("model"), "sample")
606+
body["model"] = model
567607

568608
prompt = body.get("prompt", {})
569609
num_samples = body.get("num_samples", 1)
@@ -581,7 +621,7 @@ async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:
581621

582622
# Render prompt: flatten text tokens and, if images are present,
583623
# call the render endpoint to get placeholder tokens + features.
584-
token_ids, mm_features = await self._render_for_sample(prompt, session_id)
624+
token_ids, mm_features = await self._render_for_sample(prompt, session_id, model=model)
585625

586626
# Map Tinker SamplingParams → vLLM format
587627
sampling_params: Dict[str, Any] = {
@@ -596,11 +636,9 @@ async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:
596636
if val is not None:
597637
sampling_params[vllm_key] = val
598638

599-
effective_model = self.active_lora_name if self.active_lora_name else self.model_name
600-
601639
payload: Dict[str, Any] = {
602640
"sampling_params": sampling_params,
603-
"model": effective_model,
641+
"model": model,
604642
"token_ids": token_ids,
605643
}
606644
if mm_features is not None:
@@ -679,13 +717,17 @@ async def chat_completion(
679717
680718
Args:
681719
request_payload: Dict with {"json": <request-body>, "headers": <headers-dict>}.
682-
The request body should be OpenAI-compatible chat completion request.
683-
session_id can be included in json for consistent routing.
720+
The request body must be an OpenAI-compatible chat completion
721+
request. ``model`` is optional and resolved via
722+
``_resolve_model``; if omitted the body is mutated to inject the
723+
resolved value before forwarding to vLLM. ``session_id`` can be
724+
included in the body for consistent routing.
684725
685726
Returns:
686727
OpenAI-compatible chat completion response.
687728
"""
688729
session_id, body = _extract_session_id_and_body(request_payload)
730+
body["model"] = self._resolve_model(body.get("model"), "chat_completion")
689731

690732
headers = {"Content-Type": "application/json"}
691733
if session_id:
@@ -708,13 +750,16 @@ async def render_chat_completion(
708750
709751
Args:
710752
request_payload: Dict with {"json": <request-body>}.
711-
The request body should be OpenAI-compatible chat completion request.
712-
session_id can be included in json for consistent routing.
753+
The request body should be OpenAI-compatible chat completion
754+
request. ``model`` is optional and resolved via
755+
``_resolve_model``. session_id can be included in json for
756+
consistent routing.
713757
714758
Returns:
715759
Rendered chat completion response (template-applied prompt and token IDs).
716760
"""
717761
session_id, body = _extract_session_id_and_body(request_payload)
762+
body["model"] = self._resolve_model(body.get("model"), "render_chat_completion")
718763

719764
headers = {"Content-Type": "application/json"}
720765
if session_id:
@@ -737,13 +782,16 @@ async def completion(
737782
738783
Args:
739784
request_payload: Dict with {"json": <request-body>, "headers": <headers-dict>}.
740-
The request body should be OpenAI-compatible completion request.
741-
session_id can be included in json for consistent routing.
785+
The request body should be OpenAI-compatible completion
786+
request. ``model`` is optional and resolved via
787+
``_resolve_model``. session_id can be included in json for
788+
consistent routing.
742789
743790
Returns:
744791
OpenAI-compatible completion response.
745792
"""
746793
session_id, body = _extract_session_id_and_body(request_payload)
794+
body["model"] = self._resolve_model(body.get("model"), "completion")
747795

748796
headers = {"Content-Type": "application/json"}
749797
if session_id:
@@ -1010,7 +1058,7 @@ async def update_named_weights(
10101058
"""
10111059
Update model weights via vLLM native /update_weights. Used for full parameter fine-tuning.
10121060
1013-
For LoRA weight sync, use update_lora_from_disk() instead.
1061+
For LoRA weight sync, use load_lora_adapter() instead.
10141062
10151063
Args:
10161064
update_info: Dict with keys expected by vLLM (names, dtype_names, shapes, packed, etc.)
@@ -1095,20 +1143,21 @@ async def finish_weight_update(self) -> Dict[str, Any]:
10951143
{"method": "finish_weight_update"},
10961144
)
10971145

1098-
async def update_lora_from_disk(
1146+
async def load_lora_adapter(
10991147
self,
1148+
lora_name: str,
11001149
lora_path: str,
11011150
) -> Dict[str, Any]:
11021151
"""
1103-
Update LoRA adapter weights by loading from disk on all backend servers
1104-
via the SkyRL custom /skyrl/v1/load_lora_adapter endpoint.
1152+
Load (or reload) a LoRA adapter on all backend servers via the SkyRL
1153+
custom /skyrl/v1/load_lora_adapter endpoint.
11051154
1106-
Always loads under self.active_lora_name so the same lora_int_id slot
1107-
is reused across weight syncs.
1155+
After loading, generation/chat/completion requests can target this LoRA
1156+
by passing ``model=lora_name``.
11081157
1109-
TODO(aaron): switch back to /v1/load_lora_adapter (with "load_inplace": True)
1110-
once the upstream fix in https://github.com/vllm-project/vllm/pull/41482
1111-
lands in a vLLM release we depend on.
1158+
TODO(aaron): switch back to vLLM's /v1/load_lora_adapter once the
1159+
upstream fix in https://github.com/vllm-project/vllm/pull/41482 lands
1160+
in a vLLM release we depend on.
11121161
11131162
The custom endpoint (defined in vllm_server_actor.py) wraps add_lora
11141163
with load_inplace=True (so the engine reloads the freshly-written
@@ -1118,19 +1167,13 @@ async def update_lora_from_disk(
11181167
num_engines>=2 — see vllm_server_actor.py:_skyrl_load_lora_adapter for
11191168
the detailed explanation.
11201169
1121-
After loading, generation requests will automatically use the LoRA
1122-
adapter by setting the model name to the LoRA adapter name.
1123-
11241170
Args:
1171+
lora_name: Name to register the adapter under on each server.
11251172
lora_path: Path to the LoRA adapter on disk (must be accessible from servers).
11261173
11271174
Returns:
11281175
Dict mapping server_url to response.
11291176
"""
1130-
if self.active_lora_name is None:
1131-
raise ValueError("active_lora_name must be set on RemoteInferenceClient before loading a LoRA adapter.")
1132-
1133-
lora_name = self.active_lora_name
11341177
session = await self._get_session()
11351178

11361179
async def _load_on_server(server_url: str):
@@ -1148,6 +1191,40 @@ async def _load_on_server(server_url: str):
11481191

11491192
return {url: resp for url, resp in results}
11501193

1194+
async def unload_lora_adapter(self, lora_name: str) -> Dict[str, Any]:
1195+
"""
1196+
Unload a previously-loaded LoRA adapter on all backend servers via /v1/unload_lora_adapter.
1197+
1198+
After unloading, ``lora_name`` is no longer accepted as a ``model``
1199+
target on any server. The underlying CPU/GPU LRU entries on vLLM age
1200+
out naturally as new adapters are loaded.
1201+
1202+
Args:
1203+
lora_name: Name of the adapter to unload.
1204+
1205+
Returns:
1206+
Dict mapping server_url to response.
1207+
"""
1208+
payload = {"lora_name": lora_name}
1209+
1210+
# Mirror load_lora_adapter: vLLM returns plain text on success and JSON
1211+
# ErrorResponse (e.g. 404) on failure.
1212+
session = await self._get_session()
1213+
1214+
async def _unload_on_server(server_url: str):
1215+
url = f"{server_url}/v1/unload_lora_adapter"
1216+
async with session.post(url, json=payload) as resp:
1217+
if resp.status >= 400:
1218+
body = await resp.json()
1219+
raise_for_status(resp, body)
1220+
return server_url, {"status": resp.status, "body": await resp.text()}
1221+
1222+
results = await asyncio.gather(*[_unload_on_server(url) for url in self.server_urls])
1223+
1224+
logger.info(f"Unloaded LoRA adapter '{lora_name}'")
1225+
1226+
return {url: resp for url, resp in results}
1227+
11511228
# ---------------------------
11521229
# Info
11531230
# ---------------------------

0 commit comments

Comments
 (0)