Commit 77f92c2
[lora][tinker] Add pause and resume for multi-tenant lora (#1657)
Addresses #1647.

Today's `pause_generation()` is a **global** vLLM keep-mode pause: when one LoRA tenant syncs new weights, every other tenant's in-flight generation freezes for the duration of the swap. This blocks practical multi-tenant LoRA RL training.

This PR adds a per-LoRA `lora_name` arg to pause/resume so that a weight sync for adapter A aborts only A's requests; other adapters keep generating. Aborted requests come back with `finish_reason="abort"` and partial tokens, and a new `sample_with_retry()` client method accumulates those partial tokens, awaits resume, and resubmits with `prompt + accumulated` and the remaining `max_tokens` until completion.

**This is a transient fix.** If we can upstream a LoRA-specific pause in the future, we delete `sample_with_retry`, the per-LoRA gate, and the abort endpoint; the `lora_name` kwarg stays and routes to the new vLLM API.

## What changed

### New: `/skyrl/v1/abort_lora_requests` server endpoint

[vllm_server_actor.py](skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py) gets a small custom endpoint that iterates `engine.output_processor.request_states`, filters by `lora_name`, and calls `engine.abort(ids, internal=True)`. `internal=True` is load-bearing: the states dict is keyed by internal IDs.

### `pause_generation`/`resume_generation` gain `lora_name: Optional[str] = None`

- `lora_name=None` → unchanged: vLLM `/pause?mode=keep` global pause.
- `lora_name="X"` → clears a per-LoRA `asyncio.Event` (which gates retries client-side), sleeps for a 5 s grace period, then fans out to `/skyrl/v1/abort_lora_requests`.

The new inference path (`_SKYRL_USE_NEW_INFERENCE=1`, the default) goes through `RemoteInferenceClient` → `vllm_server_actor`, which is where the new `/skyrl/v1/abort_lora_requests` endpoint lives, so it is the only path that actually supports targeted pause.
Legacy-path classes (`InferenceEngineClient`, `RayWrappedInferenceEngine`, `RemoteInferenceEngine`, `AsyncVLLMInferenceEngine`) accept the `lora_name` kwarg as a guardrail but raise `NotImplementedError("targeted pause is HTTP-only")` when it is non-None. The legacy path keeps its existing global-pause behavior unchanged.

### `RemoteInferenceClient`: `sample()` refactor + new `sample_with_retry()`

This is the only data-plane surface that gained retry logic.

**`sample()` is split**:

- `sample()` renders the prompt and dispatches once. **Public behavior unchanged.**
- `_sample_with_rendered_tokens()` is the post-render half, parameterized on the rendered `token_ids`. Pure refactor (regression-checked by the existing `TestSample` tests).

**`sample_with_retry()` is new** and is the only retry-bearing method. It renders once, then runs a `while stop_reason == "abort"` loop:

1. `await _lora_pause_events[model].wait()` (a no-op if no event exists).
2. Build a body with `token_ids = original_prompt + accum_tokens` and `max_tokens = original_max_tokens - len(accum_tokens)`.
3. Dispatch `_sample_with_rendered_tokens`.
4. Extend `accum_tokens` and `accum_logprobs` with the returned segment.
5. Repeat until a non-abort stop reason.

It returns the same `SampleResponse` shape as `sample()` and asserts `num_samples == 1` (current Tinker callers all do this; multi-sample retry is straightforward but deferred).

`generate()`, `chat_completion()`, `completion()` are **untouched**: no retry, no per-LoRA gating. The multi-tenant Tinker path is the only production caller that flips to retry mode.

### `worker_dispatch.save_weights_for_sampler` threads `lora_name`

The [worker_dispatch.py](skyrl/backends/skyrl_train/workers/worker_dispatch.py) non-colocate branch forwards `model_id` (set by `SkyRLTrainBackend.save_sampler_checkpoint` to the LoRA name for multi-tenant, `None` for FFT) to the pause/resume calls.
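The accumulation loop above can be sketched in a few lines. This is a minimal stand-in, not the real client: `sample_once` substitutes for `_sample_with_rendered_tokens` and `pause_events` for the per-LoRA gate, and the real method also handles logprob bookkeeping inside `SampleResponse`.

```python
import asyncio

# Illustrative sketch of the sample_with_retry() loop (simplified types).
async def sample_with_retry(sample_once, pause_events, model,
                            prompt_ids, max_tokens):
    accum_tokens, accum_logprobs = [], []
    while True:
        # 1. Gate on the per-LoRA pause event (no-op if none exists).
        event = pause_events.get(model)
        if event is not None:
            await event.wait()
        # 2./3. Resubmit prompt + accumulated tokens with the remaining budget.
        tokens, logprobs, stop_reason = await sample_once(
            token_ids=prompt_ids + accum_tokens,
            max_tokens=max_tokens - len(accum_tokens),
        )
        # 4. Extend the accumulators with the returned segment.
        accum_tokens += tokens
        accum_logprobs += logprobs
        # 5. Repeat until a non-abort stop reason.
        if stop_reason != "abort":
            return accum_tokens, accum_logprobs, stop_reason

# One abort followed by a clean finish: the retry resubmits [9, 1, 2]
# with max_tokens 8 - 2 = 6, and the result concatenates both segments.
segments = iter([([1, 2], [-0.1, -0.2], "abort"), ([3], [-0.3], "stop")])
async def fake_sample_once(**kwargs):
    return next(segments)

toks, lps, reason = asyncio.run(
    sample_with_retry(fake_sample_once, {}, "lora-meow", [9], max_tokens=8))
print(toks, reason)  # [1, 2, 3] stop
```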
## API impact

| API | Change | Migration |
|---|---|---|
| `InferenceEngineInterface.pause_generation` / `resume_generation` | New optional `lora_name: Optional[str] = None` kwarg. | None; the default preserves existing behavior on every subclass. |
| `RemoteInferenceClient.sample()` | **Refactored** into render + dispatch (pure refactor; public behavior unchanged). | None. |
| `RemoteInferenceClient.sample_with_retry()` | **New.** Same signature/return as `sample()`. | Optional; auto-used by the multi-LoRA path in `SkyRLTrainBackend`. |
| `generate()`, `chat_completion()`, `completion()` | **Unchanged.** No retry, no per-LoRA gate. | None. |
| `worker_dispatch.save_weights_for_sampler` | Now forwards `lora_name=model_id` to pause/resume on the non-colocate path. | None; `model_id=None` for FFT preserves the global keep-mode pause. |

## Tests

### Unit (no GPU): `tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py`

7 new cases in `TestTargetedLoraPause`:

- `pause_generation(lora_name=X)` clears the event and fans out the abort.
- `pause_generation()` (no `lora_name`) still drives keep-mode.
- `sample_with_retry` accumulates partial tokens across an abort, `max_tokens` decrements correctly, logprobs concatenate, and the final response shape is correct.
- `sample_with_retry` no-abort path is a single shot (refactor regression).
- `sample_with_retry` blocks on the per-LoRA event until `resume_generation` is called.
- `sample_with_retry` rejects `num_samples > 1`.
- LoRAs that were never paused never block (no spurious event creation).

### GPU integration: `tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py`

4 end-to-end cases, run against a real vLLM server with two LoRA adapters (Meow + Woof) loaded:

1. **`test_pause_lora_does_not_affect_other_lora`**: while lora-meow is paused, 4 concurrent `sample_with_retry(model="lora-woof")` calls complete promptly and contain "woof" content (proving the gate doesn't spill across adapters and weights aren't mixed up).
2. **`test_sample_with_retry_recovers_from_abort`**: 4 + 4 concurrent in-flight samples on meow + woof; pausing lora-meow mid-flight aborts all 4 meow requests, retry resubmits after resume, and all 8 complete with non-abort stop reasons and correct content. **Hard-asserts** that 0/8 tasks completed pre-pause and 0/4 escaped during the pause window; without these assertions the test would silently no-op if the LoRA emitted EOS too fast.
3. **`test_pause_swap_weights_resume_mid_sample`**: a single sample call spans a real weight swap. It starts with Meow weights, mid-flight calls `pause → load_lora_adapter("lora-target", woof_path) → resume`, and the merged output literally shows `meow meow meow ... woof woof woof`. This proves both that the abort/retry boundary preserves accumulated state and that the retried request observes the newly loaded weights.
4. **`test_global_pause_still_works`**: `pause_generation()` with no `lora_name` still drives the keep-mode pause (FFT regression).

---------

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
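The pause → swap → resume pattern the GPU tests exercise can be sketched as a call sequence. `StubClient` below is a stand-in that only records calls; the method names and the `lora_name` kwarg mirror the PR description, while `load_lora_adapter`'s signature and the paths are illustrative assumptions.

```python
import asyncio

# Stand-in client that records calls; not the real inference client.
class StubClient:
    def __init__(self):
        self.log = []

    async def pause_generation(self, lora_name=None):
        self.log.append(("pause", lora_name))

    async def load_lora_adapter(self, name, path):
        self.log.append(("load", name, path))

    async def resume_generation(self, lora_name=None):
        self.log.append(("resume", lora_name))

async def swap_weights(client, lora_name, new_path):
    # Only `lora_name`'s in-flight requests are aborted; other adapters
    # keep generating throughout the swap. Gated sample_with_retry calls
    # resubmit automatically once resume_generation() sets the event.
    await client.pause_generation(lora_name=lora_name)
    await client.load_lora_adapter(lora_name, new_path)
    await client.resume_generation(lora_name=lora_name)

client = StubClient()
asyncio.run(swap_weights(client, "lora-target", "/tmp/woof"))
print(client.log[0][0], client.log[-1][0])  # pause resume
```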
1 parent 29f0fdf commit 77f92c2

12 files changed

Lines changed: 1309 additions & 32 deletions


skyrl/backends/skyrl_train/inference_engines/base.py

Lines changed: 13 additions & 4 deletions

```diff
@@ -165,13 +165,22 @@ async def reset_prefix_cache(self):
         raise NotImplementedError

     @abstractmethod
-    async def pause_generation(self) -> None:
-        """Pause generation, freezing in-flight requests so they can be resumed later."""
+    async def pause_generation(self, lora_name: Optional[str] = None) -> None:
+        """Pause generation, freezing in-flight requests so they can be resumed later.
+
+        When ``lora_name`` is None (default), pauses all generation globally
+        (vLLM keep-mode pause). When ``lora_name`` is provided, only requests
+        targeting that specific LoRA adapter are paused (HTTP path only).
+        """
         raise NotImplementedError

     @abstractmethod
-    async def resume_generation(self) -> None:
-        """Resume generation after a pause, continuing any frozen in-flight requests."""
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
+        """Resume generation after a pause, continuing any frozen in-flight requests.
+
+        ``lora_name`` must match the value used in the corresponding
+        ``pause_generation`` call.
+        """
         raise NotImplementedError

     @abstractmethod
```

skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py

Lines changed: 10 additions & 2 deletions

```diff
@@ -366,21 +366,29 @@ def dp_size(self) -> int:
     # ----------------------------
     # Generation pause and resume
     # ----------------------------
-    async def pause_generation(self) -> None:
+    async def pause_generation(self, lora_name: Optional[str] = None) -> None:
         """
         Pauses generation for all engines using vLLM's native keep mode.

         In-flight requests are frozen (not aborted) and will resume from where they left off
         when `resume_generation()` is called. New requests are blocked until resume.
+
+        ``lora_name`` is accepted for interface parity with the HTTP path but
+        targeted (per-LoRA) pause is HTTP-only; passing a non-None value
+        raises ``NotImplementedError``.
         """
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         await self._run_on_all_engines("pause_generation")

-    async def resume_generation(self) -> None:
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
         """
         Resumes generation for all engines after a keep-mode pause.

         Frozen in-flight requests continue from where they left off, and new requests are unblocked.
         """
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         await self._run_on_all_engines("resume_generation")

     # ----------------------------
```

skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import ray
 from packaging import version
@@ -76,10 +76,14 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
     async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         return await self.inference_engine_actor.completion.remote(request_payload)

-    async def pause_generation(self) -> None:
+    async def pause_generation(self, lora_name: Optional[str] = None) -> None:
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         return await self.inference_engine_actor.pause_generation.remote()

-    async def resume_generation(self) -> None:
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         return await self.inference_engine_actor.resume_generation.remote()
```

skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -282,8 +282,10 @@ async def reset_prefix_cache(self):
             "body": text,
         }

-    async def pause_generation(self) -> None:
+    async def pause_generation(self, lora_name: Optional[str] = None) -> None:
         """Pause generation using vLLM's native keep mode, freezing in-flight requests."""
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         async with aiohttp.ClientSession() as session:
             async with session.post(
                 f"{self.url}/pause",
@@ -293,8 +295,10 @@ async def pause_generation(self) -> None:
                 if resp.status != 200:
                     raise RuntimeError(f"Failed to pause generation: {result.get('error', result)}")

-    async def resume_generation(self) -> None:
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
         """Resume generation after a keep-mode pause."""
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         async with aiohttp.ClientSession() as session:
             async with session.post(f"{self.url}/resume") as resp:
                 result = await resp.json()
```

skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -216,10 +216,10 @@ def reset_prefix_cache(self):
         """Reset the prefix cache. Subclasses override for async version."""
         return self.llm.llm_engine.reset_prefix_cache()

-    async def pause_generation(self, clear_cache: bool = False) -> None:
+    async def pause_generation(self, lora_name: Optional[str] = None, clear_cache: bool = False) -> None:
         raise NotImplementedError("pause_generation is only supported for AsyncVLLMInferenceEngine.")

-    async def resume_generation(self) -> None:
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
         raise NotImplementedError("resume_generation is only supported for AsyncVLLMInferenceEngine.")

@@ -648,14 +648,18 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         return await self._handle_openai_request(request_payload, endpoint="/completions")

-    async def pause_generation(self, clear_cache: bool = False) -> None:
+    async def pause_generation(self, lora_name: Optional[str] = None, clear_cache: bool = False) -> None:
         """Pause generation using vLLM's native keep mode, freezing in-flight requests."""
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         engine = self._get_engine()
         await engine.pause_generation(mode="keep", clear_cache=clear_cache)
         logger.info("pause_generation(mode='keep') finished")

-    async def resume_generation(self) -> None:
+    async def resume_generation(self, lora_name: Optional[str] = None) -> None:
         """Resume generation after a keep-mode pause."""
+        if lora_name is not None:
+            raise NotImplementedError("targeted pause is HTTP-only")
         engine = self._get_engine()
         await engine.resume_generation()
         logger.info("resume_generation() finished")
```
