Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions skyrl/backends/skyrl_train/inference_engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,22 @@ async def reset_prefix_cache(self):
raise NotImplementedError

@abstractmethod
async def pause_generation(self) -> None:
"""Pause generation, freezing in-flight requests so they can be resumed later."""
async def pause_generation(self, lora_name: Optional[str] = None) -> None:
"""Pause generation, freezing in-flight requests so they can be resumed later.

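When ``lora_name`` is None (default), pauses all generation globally
(vLLM keep-mode pause). When ``lora_name`` is provided, only requests
targeting that specific LoRA adapter are paused. Targeted pause is
supported on the HTTP path only; implementations without it raise
``NotImplementedError`` for a non-None ``lora_name``.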
"""
raise NotImplementedError

@abstractmethod
async def resume_generation(self) -> None:
"""Resume generation after a pause, continuing any frozen in-flight requests."""
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
"""Resume generation after a pause, continuing any frozen in-flight requests.

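``lora_name`` must match the value passed to the corresponding
``pause_generation`` call: ``None`` (default) resumes a global pause,
while an adapter name resumes only the requests that were paused for
that LoRA adapter.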
"""
raise NotImplementedError

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,21 +366,29 @@ def dp_size(self) -> int:
# ----------------------------
# Generation pause and resume
# ----------------------------
async def pause_generation(self) -> None:
async def pause_generation(self, lora_name: Optional[str] = None) -> None:
"""
Pauses generation for all engines using vLLM's native keep mode.

In-flight requests are frozen (not aborted) and will resume from where they left off
when `resume_generation()` is called. New requests are blocked until resume.

``lora_name`` is accepted for interface parity with the HTTP path but
targeted (per-LoRA) pause is HTTP-only; passing a non-None value
raises ``NotImplementedError``.
"""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
await self._run_on_all_engines("pause_generation")

async def resume_generation(self) -> None:
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
"""
Resumes generation for all engines after a keep-mode pause.

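Frozen in-flight requests continue from where they left off, and new requests are unblocked.

``lora_name`` is accepted for interface parity with the HTTP path but
targeted (per-LoRA) resume is HTTP-only; passing a non-None value
raises ``NotImplementedError``.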
"""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
await self._run_on_all_engines("resume_generation")

# ----------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import ray
from packaging import version
Expand Down Expand Up @@ -76,10 +76,14 @@ async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, An
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
return await self.inference_engine_actor.completion.remote(request_payload)

async def pause_generation(self) -> None:
async def pause_generation(self, lora_name: Optional[str] = None) -> None:
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
return await self.inference_engine_actor.pause_generation.remote()

async def resume_generation(self) -> None:
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
return await self.inference_engine_actor.resume_generation.remote()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,10 @@ async def reset_prefix_cache(self):
"body": text,
}

async def pause_generation(self) -> None:
async def pause_generation(self, lora_name: Optional[str] = None) -> None:
"""Pause generation using vLLM's native keep mode, freezing in-flight requests."""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.url}/pause",
Expand All @@ -293,8 +295,10 @@ async def pause_generation(self) -> None:
if resp.status != 200:
raise RuntimeError(f"Failed to pause generation: {result.get('error', result)}")

async def resume_generation(self) -> None:
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
"""Resume generation after a keep-mode pause."""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.url}/resume") as resp:
result = await resp.json()
Expand Down
12 changes: 8 additions & 4 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ def reset_prefix_cache(self):
"""Reset the prefix cache. Subclasses override for async version."""
return self.llm.llm_engine.reset_prefix_cache()

async def pause_generation(self, clear_cache: bool = False) -> None:
async def pause_generation(self, lora_name: Optional[str] = None, clear_cache: bool = False) -> None:
raise NotImplementedError("pause_generation is only supported for AsyncVLLMInferenceEngine.")

async def resume_generation(self) -> None:
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
raise NotImplementedError("resume_generation is only supported for AsyncVLLMInferenceEngine.")


Expand Down Expand Up @@ -648,14 +648,18 @@ async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
return await self._handle_openai_request(request_payload, endpoint="/completions")

async def pause_generation(self, clear_cache: bool = False) -> None:
async def pause_generation(self, lora_name: Optional[str] = None, clear_cache: bool = False) -> None:
"""Pause generation using vLLM's native keep mode, freezing in-flight requests."""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
engine = self._get_engine()
await engine.pause_generation(mode="keep", clear_cache=clear_cache)
logger.info("pause_generation(mode='keep') finished")

async def resume_generation(self) -> None:
async def resume_generation(self, lora_name: Optional[str] = None) -> None:
"""Resume generation after a keep-mode pause."""
if lora_name is not None:
raise NotImplementedError("targeted pause is HTTP-only")
engine = self._get_engine()
await engine.resume_generation()
logger.info("resume_generation() finished")
Expand Down
Loading
Loading