Commit 77f92c2
[lora][tinker] Add pause and resume for multi-tenant lora (#1657)
Addresses #1647
Today's `pause_generation()` is a **global** vLLM keep-mode pause: when
one LoRA tenant syncs new weights, every other tenant's in-flight
generation freezes for the duration of the swap. This blocks practical
multi-tenant LoRA RL training.
This PR adds a per-LoRA `lora_name` arg to pause/resume so a weight sync
for adapter A only aborts A's requests; other adapters keep generating.
The aborted requests come back with `finish_reason="abort"` and partial
tokens, and a new `sample_with_retry()` client method accumulates those
partial tokens, awaits resume, and resubmits with `prompt + accumulated`
and remaining `max_tokens` until completion.
**This is a transient fix.** If we can upstream a LoRA-specific pause to
vLLM in the future, we can delete `sample_with_retry`, the per-LoRA gate,
and the abort endpoint; the `lora_name` kwarg stays and routes to the
new vLLM API.
## What changed
### New: `/skyrl/v1/abort_lora_requests` server endpoint
[vllm_server_actor.py](skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py)
gets a small custom endpoint that iterates
`engine.output_processor.request_states`, filters by `lora_name`, and
calls `engine.abort(ids, internal=True)`. `internal=True` is
load-bearing — the states dict is keyed by internal IDs.
### `pause_generation`/`resume_generation` gain `lora_name:
Optional[str] = None`
- `lora_name=None` → unchanged: vLLM `/pause?mode=keep` global pause.
- `lora_name="X"` → clears a per-LoRA `asyncio.Event` (which gates
retries client-side), waits a 5 s grace period, then fans out to
`/skyrl/v1/abort_lora_requests`.
The new inference path (`_SKYRL_USE_NEW_INFERENCE=1`, the default) goes
through `RemoteInferenceClient` → `vllm_server_actor`, which is where
the new `/skyrl/v1/abort_lora_requests` endpoint lives — so that's the
only path that actually supports targeted pause. Legacy-path classes
(`InferenceEngineClient`, `RayWrappedInferenceEngine`,
`RemoteInferenceEngine`, `AsyncVLLMInferenceEngine`) accept the
`lora_name` kwarg as a guardrail but raise
`NotImplementedError("targeted pause is HTTP-only")` when it's non-None.
The legacy path keeps its existing global-pause behavior unchanged.
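The guardrail amounts to a one-line check at the top of each legacy `pause_generation`. A minimal sketch, with the class and helper names simplified stand-ins for the four legacy engine classes:

```python
from typing import Optional

class LegacyEngine:
    """Simplified stand-in for a legacy-path engine class."""

    def _global_keep_mode_pause(self) -> str:
        return "paused"  # placeholder for the existing keep-mode pause

    def pause_generation(self, lora_name: Optional[str] = None) -> str:
        if lora_name is not None:
            # Legacy engines have no per-adapter abort endpoint, so a
            # targeted pause cannot be honored; fail loudly instead of
            # silently degrading to a global pause.
            raise NotImplementedError("targeted pause is HTTP-only")
        return self._global_keep_mode_pause()
```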
### `RemoteInferenceClient`: `sample()` refactor + new
`sample_with_retry()`
This is the only data-plane surface that gained retry logic.
**`sample()` is split**:
- `sample()` — renders the prompt and dispatches once. **Public behavior
unchanged.**
- `_sample_with_rendered_tokens()` — the post-render half, parameterized
on rendered `token_ids`. Pure refactor (regression-checked by the
existing `TestSample` tests).
**`sample_with_retry()` is new** — the only retry-bearing method.
Renders once, then runs a `while stop_reason == "abort"` loop:
1. `await _lora_pause_events[model].wait()` (no-op if no event exists).
2. Build a body with `token_ids = original_prompt + accum_tokens` and
`max_tokens = original_max_tokens - len(accum_tokens)`.
3. Dispatch `_sample_with_rendered_tokens`.
4. Extend `accum_tokens` + `accum_logprobs` with the returned segment.
5. Repeat until non-abort.
Returns the same `SampleResponse` shape as `sample()`. Asserts
`num_samples == 1` (current Tinker callers all do this; multi-sample
retry is straightforward but deferred).
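The five steps above can be sketched as the following loop. This is an illustration under stated assumptions: `dispatch` plays the role of `_sample_with_rendered_tokens` and returns `(tokens, logprobs, stop_reason)`, and `gate` is the per-LoRA pause event (`None` if the adapter was never paused):

```python
import asyncio

async def sample_with_retry(prompt_ids, max_tokens, dispatch, gate=None):
    """Accumulate partial output across aborts until a non-abort finish."""
    accum_tokens, accum_logprobs = [], []
    stop_reason = "abort"
    while stop_reason == "abort":
        if gate is not None:
            await gate.wait()  # block while the adapter is paused
        tokens, logprobs, stop_reason = await dispatch(
            prompt_ids + accum_tokens,      # resubmit prompt + partial output
            max_tokens - len(accum_tokens)  # only the remaining budget
        )
        accum_tokens.extend(tokens)
        accum_logprobs.extend(logprobs)
    return accum_tokens, accum_logprobs, stop_reason
```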
`generate()`, `chat_completion()`, `completion()` are **untouched** — no
retry, no per-LoRA gating. The multi-tenant Tinker path is the only
production caller that flips to retry mode.
### `worker_dispatch.save_weights_for_sampler` threads `lora_name`
[worker_dispatch.py](skyrl/backends/skyrl_train/workers/worker_dispatch.py)
non-colocate branch forwards `model_id` (set by
`SkyRLTrainBackend.save_sampler_checkpoint` to the LoRA name for
multi-tenant, `None` for FFT) to the pause/resume calls.
## API impact
| API | Change | Migration |
|---|---|---|
| `InferenceEngineInterface.pause_generation` / `resume_generation` |
New optional `lora_name: Optional[str] = None` kwarg. | None — default
preserves existing behavior on every subclass. |
| `RemoteInferenceClient.sample()` | **Refactored** into render +
dispatch (pure refactor; public behavior unchanged). | None. |
| `RemoteInferenceClient.sample_with_retry()` | **New.** Same
signature/return as `sample()`. | Optional; auto-used by the multi-LoRA
path in `SkyRLTrainBackend`. |
| `generate()`, `chat_completion()`, `completion()` | **Unchanged.** No
retry, no per-LoRA gate. | None. |
| `worker_dispatch.save_weights_for_sampler` | Now forwards
`lora_name=model_id` to pause/resume on the non-colocate path. | None —
`model_id=None` for FFT preserves the global keep-mode pause. |
## Tests
### Unit (no GPU) —
`tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py`
7 new cases in `TestTargetedLoraPause`:
- `pause_generation(lora_name=X)` clears the event and fans out abort.
- `pause_generation()` (no lora_name) still drives keep-mode.
- `sample_with_retry` accumulates partial tokens across an abort,
`max_tokens` decrements correctly, logprobs concatenate, final response
shape OK.
- `sample_with_retry` no-abort path is a single shot (refactor
regression).
- `sample_with_retry` blocks on the per-LoRA event until
`resume_generation` is called.
- `sample_with_retry` rejects `num_samples > 1`.
- LoRAs that were never paused never block (no spurious event creation).
### GPU integration —
`tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py`
4 end-to-end cases, run against a real vLLM server with two LoRA
adapters (Meow + Woof) loaded:
1. **`test_pause_lora_does_not_affect_other_lora`** — while lora-meow is
paused, 4 concurrent `sample_with_retry(model="lora-woof")` calls
complete promptly and contain "woof" content (proves the gate doesn't
spill across adapters and weights aren't mixed up).
2. **`test_sample_with_retry_recovers_from_abort`** — 4+4 concurrent
in-flight samples on meow + woof; pausing lora-meow mid-flight aborts
all 4 meow requests, retry resubmits after resume, all 8 complete with
non-abort stop reasons and correct content. **Hard-asserts** 0/8 tasks
completed pre-pause and 0/4 escaped during the pause window — without
these assertions the test would silently no-op if the LoRA emitted EOS
too fast.
3. **`test_pause_swap_weights_resume_mid_sample`** — single sample call
spans a real weight swap: starts with Meow weights, mid-flight calls
`pause → load_lora_adapter("lora-target", woof_path) → resume`, and the
merged output literally shows `meow meow meow ... woof woof woof`.
Proves the abort/retry boundary preserves accumulated state AND that the
retried request observes the newly-loaded weights.
4. **`test_global_pause_still_works`** — `pause_generation()` with no
`lora_name` still drives keep-mode pause (FFT regression).
---------
Signed-off-by: ahao-anyscale <ahao@anyscale.com>

Parent 29f0fdf · 12 files changed: 1309 additions & 32 deletions