Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
0ab77ba
add release and resume handle in scheduler
klhhhhh Feb 22, 2026
5e0d5aa
implement release and resume memory in gpu worker
klhhhhh Feb 22, 2026
8a42a36
update release and resume api in http server
klhhhhh Feb 22, 2026
97145fb
pre-commit lint
klhhhhh Feb 22, 2026
db77216
remove tags for diffusion models and update io_struct in post training
klhhhhh Feb 23, 2026
5ce36fd
adjust tags in function call
klhhhhh Feb 23, 2026
d28118f
implement new wake sleep func directly use get updated module in pipe…
klhhhhh Feb 23, 2026
abcc12f
run lint
klhhhhh Feb 23, 2026
e78b252
Implement new wake and sleep also add sanitize moving part for modules
klhhhhh Feb 24, 2026
bc1ed2a
refactor weight api
klhhhhh Feb 24, 2026
db57344
Return correct status code when calling generation while sleeping
klhhhhh Feb 24, 2026
8d4820a
update comment
klhhhhh Feb 24, 2026
4b90d58
update lint
klhhhhh Feb 24, 2026
5106056
refactor all the code
klhhhhh Feb 24, 2026
ebc544f
refactor all the code
klhhhhh Feb 24, 2026
94e45bf
refactor gpu_worker and utils in openai entrypoint
klhhhhh Feb 24, 2026
7d7f665
fix bugs in utils
klhhhhh Feb 24, 2026
12d2190
fix bugs in utils
klhhhhh Feb 24, 2026
e2aae15
fix bugs in weight api
klhhhhh Feb 24, 2026
55b0789
fix comment in wake func
klhhhhh Feb 24, 2026
a24b08f
refactor wake func
klhhhhh Feb 24, 2026
2c50a65
add test wake sleep in ci
klhhhhh Feb 25, 2026
4ef1c1d
adds pytest entry
zhaochenyang20 Feb 25, 2026
704309d
fix race condition
zhaochenyang20 Feb 25, 2026
d3480ad
refactor process generation batch
klhhhhh Feb 25, 2026
a93fe56
fix bugs:access output using details
klhhhhh Feb 26, 2026
ad5d455
change test name
klhhhhh Feb 26, 2026
a27fe01
avoid worker execution failure and keep self._sleeping consistent
klhhhhh Feb 26, 2026
d114525
refactor gpu worker
klhhhhh Feb 26, 2026
c9819d2
add roll out function
klhhhhh Feb 26, 2026
f1ca773
update scheduler
klhhhhh Feb 26, 2026
ed9a1b7
refactor weight api
klhhhhh Feb 26, 2026
726e4c6
move modules and rollback
zhaochenyang20 Feb 26, 2026
9acdc6e
refactor: unit test
zhaochenyang20 Feb 26, 2026
d9b283e
refactor: unit test, assert generate correct
zhaochenyang20 Feb 26, 2026
e464185
refactor: _get_module_device
zhaochenyang20 Feb 26, 2026
622c5cc
refactor: logging control
zhaochenyang20 Feb 26, 2026
d756081
refactor: _handle_memory_occupation
zhaochenyang20 Feb 26, 2026
c780a6c
refactor: resume_memory_occupation
zhaochenyang20 Feb 26, 2026
b47c8ed
update docs
zhaochenyang20 Feb 26, 2026
52da0ec
change docs string
zhaochenyang20 Feb 26, 2026
98aac4e
fix rocm ci
zhaochenyang20 Feb 27, 2026
0ef7e50
minor refactor
alphabetc1 Feb 27, 2026
07087b2
add todo for rollback exception
zhaochenyang20 Feb 27, 2026
6fcd9c0
self fixing comments
zhaochenyang20 Mar 3, 2026
a418d46
refactor: pass the request instance instead of class type
alphabetc1 Mar 3, 2026
a9e364d
move RL related tests to post-training dir
zhaochenyang20 Mar 3, 2026
c5b7e7c
fix untouched unit test in CI
zhaochenyang20 Mar 3, 2026
6fe5ef0
extract _get_module_device into utility helper & add TODO for io_stru…
zhaochenyang20 Mar 3, 2026
8ccda32
remove redundant unit tests
zhaochenyang20 Mar 4, 2026
888c15e
move test
alphabetc1 Mar 6, 2026
cff8755
refactor: reduce sleep/wake diff noise
MikukuOvO Apr 13, 2026
32271e7
refactor: simplify sleep/wake error handling
MikukuOvO Apr 13, 2026
8258d79
refactor: simplify sleep/wake worker control flow
MikukuOvO Apr 13, 2026
11cd364
refactor: use structured sleeping error type
MikukuOvO Apr 13, 2026
9498a80
refactor: drop test and utils changes
MikukuOvO Apr 13, 2026
84c2fdd
Merge remote-tracking branch 'upstream/main' into dev/pr-19152-sleep-…
MikukuOvO Apr 13, 2026
c504629
refactor: decouple timer logging from fastapi
MikukuOvO Apr 13, 2026
6c611ae
refactor: simplify generation error mapping
MikukuOvO Apr 13, 2026
9a2298b
refactor: trim defensive sleep/wake checks
MikukuOvO Apr 13, 2026
1dbf16c
refactor: derive sleep state in worker responses
MikukuOvO Apr 13, 2026
83c5f02
refactor: extract memory occupation controller
MikukuOvO Apr 13, 2026
e24b85f
refactor: replace error type with status code
MikukuOvO Apr 13, 2026
408bc0d
refactor: drop sleep wake docs and todos
MikukuOvO Apr 13, 2026
5aff54f
refactor: inline memory occupation scheduler handlers
MikukuOvO Apr 13, 2026
89b26ac
fix: restore gc import in gpu worker
MikukuOvO Apr 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Generator, List, Optional, Union

import httpx
from fastapi import UploadFile
from fastapi import HTTPException, UploadFile

from sglang.multimodal_gen.configs.sample.sampling_params import (
DataType,
Expand Down Expand Up @@ -260,14 +260,12 @@ async def process_generation_batch(
batch,
) -> tuple[list[str], OutputBatch]:
total_start_time = time.perf_counter()

with log_generation_timer(logger, batch.prompt):
result = await scheduler_client.forward([batch])

if result.output is None and result.output_file_paths is None:
error_msg = result.error or "Unknown error"
raise RuntimeError(
f"Model generation returned no output. Error from scheduler: {error_msg}"
)
_raise_generation_error(result)

if result.output_file_paths:
save_file_path_list = result.output_file_paths
Expand Down Expand Up @@ -300,6 +298,18 @@ async def process_generation_batch(
return save_file_path_list, result


def _raise_generation_error(result: OutputBatch) -> None:
    """Raise the appropriate exception for a failed generation result.

    Prefers an HTTPException when the scheduler attached a status code to
    the error; otherwise falls back to a plain RuntimeError.
    """
    message = result.error or "Unknown error"
    status = result.error_status_code
    if status is None:
        raise RuntimeError(
            f"Model generation returned no output. Error from scheduler: {message}"
        )
    raise HTTPException(
        status_code=status,
        detail={"message": message},
    )


def merge_image_input_list(*inputs: Union[List, Any, None]) -> List:
"""
Merge multiple image input sources into a single list.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ class GetWeightsChecksumReqInput:
"""Compute SHA-256 checksum of loaded module weights for verification."""

module_names: list[str] | None = None


@dataclass
class ReleaseMemoryOccupationReqInput:
    """Request to release (sleep) GPU memory occupation for the diffusion engine.

    Carries no payload: the request type alone identifies the operation
    when routed through the scheduler.
    """


@dataclass
class ResumeMemoryOccupationReqInput:
    """Request to resume (wake) GPU memory occupation for the diffusion engine.

    Carries no payload: the request type alone identifies the operation
    when routed through the scheduler.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
GetWeightsChecksumReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
UpdateWeightFromDiskReqInput,
)
from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client
from sglang.srt.utils.json_response import orjson_response
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

router = APIRouter()

logger = init_logger(__name__)


@router.post("/update_weights_from_disk")
async def update_weights_from_disk(request: Request):
Expand All @@ -37,6 +42,16 @@ async def update_weights_from_disk(request: Request):
status_code=500,
)

if response.output is None:
status_code = response.error_status_code or 500
return orjson_response(
{
"success": False,
"message": response.error or "Unknown status",
},
status_code=status_code,
)

result = response.output
success = result.get("success", False)
message = result.get("message", "Unknown status")
Expand All @@ -60,3 +75,41 @@ async def get_weights_checksum(request: Request):
return orjson_response({"error": str(e)}, status_code=500)

return orjson_response(response.output, status_code=200)


async def _handle_memory_occupation_request(
    req: ReleaseMemoryOccupationReqInput | ResumeMemoryOccupationReqInput,
):
    """Handle memory sleep/wake requests forwarded to scheduler.

    Returns an orjson response:
    - 500 when forwarding fails or the scheduler reply is malformed,
    - 200 or 400 depending on the scheduler-reported ``success`` flag.
    """
    try:
        response = await async_scheduler_client.forward(req)
    except Exception as e:
        logger.exception(f"scheduler_client.forward failed for {type(req).__name__}")
        return orjson_response({"success": False, "message": str(e)}, status_code=500)

    # The scheduler must reply with a dict carrying an explicit "success" flag;
    # anything else is treated as a malformed response.
    payload = response.output
    if not isinstance(payload, dict) or "success" not in payload:
        logger.error(f"missing success in scheduler output: {response.output}")
        return orjson_response(
            {
                "success": False,
                "message": f"Missing 'success' field in scheduler response: {response.output}",
            },
            status_code=500,
        )

    success = bool(payload["success"])
    return orjson_response(payload, status_code=200 if success else 400)


@router.post("/release_memory_occupation")
async def release_memory_occupation():
    """Release GPU memory occupation (sleep the engine)."""
    request = ReleaseMemoryOccupationReqInput()
    return await _handle_memory_occupation_request(request)


@router.post("/resume_memory_occupation")
async def resume_memory_occupation():
    """Resume GPU memory occupation (wake the engine)."""
    request = ResumeMemoryOccupationReqInput()
    return await _handle_memory_occupation_request(request)
17 changes: 17 additions & 0 deletions python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
WeightsUpdater,
get_updatable_modules,
)
from sglang.multimodal_gen.runtime.managers.memory_occupation_controller import (
MemoryOccupationController,
)
from sglang.multimodal_gen.runtime.pipelines_core import (
ComposedPipelineBase,
LoRAPipeline,
Expand Down Expand Up @@ -91,6 +94,14 @@ def __init__(
self.cfg_group = get_cfg_group()
self.cfg_cpu_group = self.cfg_group.cpu_group

self.memory_occupation = MemoryOccupationController(
pipeline=self.pipeline,
rank=self.rank,
)

def is_sleeping(self) -> bool:
    # Delegates to the memory-occupation controller, which owns the
    # sleep/wake state for this worker.
    return self.memory_occupation.is_sleeping()

def init_device_and_model(self) -> None:
"""Initialize the device and load the model."""
torch.get_device_module().set_device(self.local_rank)
Expand Down Expand Up @@ -462,6 +473,12 @@ def get_weights_checksum(
)
return checksums

def release_memory_occupation(self) -> dict:
    # Thin wrapper: the actual offload logic lives in MemoryOccupationController.
    return self.memory_occupation.release_memory_occupation()

def resume_memory_occupation(self) -> dict:
    # Thin wrapper: the actual restore logic lives in MemoryOccupationController.
    return self.memory_occupation.resume_memory_occupation()


OOM_MSG = f"""
OOM detected. Possible solutions:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0

import gc

import torch

from sglang.multimodal_gen.runtime.loader.weights_updater import get_updatable_modules
from sglang.multimodal_gen.runtime.pipelines_core import ComposedPipelineBase
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


def _get_module_device(module: torch.nn.Module) -> str:
"""Return best-effort device string for a module."""
param = next(module.parameters(), None)
if param is not None:
return str(param.device)
buffer = next(module.buffers(), None)
if buffer is not None:
return str(buffer.device)

for key, val in vars(module).items():
if key.startswith("_"):
continue
if isinstance(val, torch.Tensor):
return str(val.device)

return "cpu"


def _move_unregistered_tensors(module: torch.nn.Module, device: str) -> None:
"""Move tensor attributes that are not covered by `module.to(device)`."""

def move_tensors(obj):
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, dict):
return {k: move_tensors(v) for k, v in obj.items()}
if isinstance(obj, list):
return [move_tensors(v) for v in obj]
if isinstance(obj, tuple):
return tuple(move_tensors(v) for v in obj)
return obj

attrs = module.__dict__
for attr_name, attr_value in list(attrs.items()):
if attr_name.startswith("_"):
continue
if attr_name in {"_parameters", "_buffers", "_modules"}:
continue

moved_value = move_tensors(attr_value)
if moved_value is not attr_value:
attrs[attr_name] = moved_value


class MemoryOccupationController:
    """Owns the sleep/wake (GPU memory release/resume) state for a pipeline.

    "Sleeping" means the pipeline's updatable modules have been moved to CPU
    so their GPU memory can be reclaimed; waking restores each module to the
    device it occupied at release time.
    """

    def __init__(self, pipeline: ComposedPipelineBase | None, rank: int):
        self.pipeline = pipeline
        self.rank = rank
        # True while modules are offloaded to CPU.
        self._sleeping = False
        # module name -> original device string, captured at release time.
        self._sleep_restore_map: dict[str, str] = {}

    def is_sleeping(self) -> bool:
        """Return True if GPU memory is currently released (modules on CPU)."""
        return self._sleeping

    def _memory_occupation_result(
        self, success: bool, message: str
    ) -> dict[str, bool | str]:
        """Build the uniform response dict returned by release/resume."""
        return {
            "success": success,
            "sleeping": self._sleeping,
            "message": message,
        }

    @staticmethod
    def _clear_torch_device_cache() -> None:
        # Synchronize so in-flight kernels finish before reclaiming memory,
        # drop Python references, then empty the allocator cache.
        device = torch.get_device_module()
        device.synchronize()
        gc.collect()
        device.empty_cache()

    def _move_modules(self, names: list[str], device: str) -> None:
        """
        Move selected modules to device.

        This function has all-or-nothing semantics:
        - Stop on first failure (device query / move / sanitize).
        - Roll back modules already moved in this call.
        - Raise RuntimeError to caller after rollback.
        """
        modules = get_updatable_modules(self.pipeline)
        moved: list[str] = []
        # module name -> device it was on before this call (for rollback).
        src_device_map: dict[str, str] = {}

        try:
            for name in names:
                module = modules[name]
                src_device_map[name] = _get_module_device(module)
                module.to(device)
                # Record the move before sanitizing so a sanitize failure
                # still rolls this module back.
                moved.append(name)
                _move_unregistered_tensors(module, device)
        except Exception as e:
            logger.warning(
                f"[_move_modules] move failed, rollback started: target={device} moved={moved} error={e}",
            )
            # Undo only the modules moved by this call, in move order.
            for name in moved:
                module = modules.get(name)
                src_dev = src_device_map.get(name)
                module.to(src_dev)
                _move_unregistered_tensors(module, src_dev)
            raise RuntimeError(
                f"failed to move modules to {device}; rollback finished: error={e}"
            ) from e

    def _offload_active_modules_to_cpu(self) -> dict[str, str]:
        """Move every non-CPU updatable module to CPU.

        Returns a map of module name -> original device so resume can put
        each module back where it came from.
        """
        restore_map: dict[str, str] = {}
        for name, module in get_updatable_modules(self.pipeline).items():
            device = _get_module_device(module)
            if not device.startswith("cpu"):
                restore_map[name] = device

        self._move_modules(list(restore_map.keys()), "cpu")
        self._clear_torch_device_cache()
        return restore_map

    def _restore_modules_to_original_devices(
        self, module_device_map: dict[str, str]
    ) -> None:
        """Move modules back, batching by target device to minimize calls."""
        grouped: dict[str, list[str]] = {}
        for name, device in module_device_map.items():
            grouped.setdefault(device, []).append(name)

        for device, names in grouped.items():
            self._move_modules(names, device)

    def release_memory_occupation(self) -> dict[str, bool | str]:
        """Offload active modules to CPU and mark the engine as sleeping.

        Idempotent: a second call while already sleeping succeeds as a no-op.
        """
        logger.info(f"[SLEEP] release_memory_occupation rank={self.rank}")
        if self._sleeping:
            return self._memory_occupation_result(
                success=True,
                message="already sleeping",
            )
        if self.pipeline is None:
            return self._memory_occupation_result(
                success=False,
                message="pipeline not initialized",
            )

        # _offload_active_modules_to_cpu raises on failure (after rollback),
        # so _sleeping is only set once the offload fully succeeded.
        self._sleep_restore_map = self._offload_active_modules_to_cpu()
        self._sleeping = True
        return self._memory_occupation_result(
            success=True,
            message="released GPU memory (moved active modules to CPU)",
        )

    def resume_memory_occupation(self) -> dict[str, bool | str]:
        """Restore offloaded modules to their original devices and wake up.

        Idempotent: a call while already awake succeeds as a no-op.
        """
        logger.info(f"[WAKE] resume_memory_occupation rank={self.rank}")
        if not self._sleeping:
            return self._memory_occupation_result(
                success=True,
                message="already awake",
            )
        if self.pipeline is None:
            return self._memory_occupation_result(
                success=False,
                message="pipeline not initialized",
            )

        # Everything was already on CPU at release time; nothing to move back.
        if not self._sleep_restore_map:
            self._sleeping = False
            return self._memory_occupation_result(
                success=True,
                message="no restore map; marked awake",
            )

        self._restore_modules_to_original_devices(self._sleep_restore_map)
        self._sleep_restore_map = {}
        self._sleeping = False
        return self._memory_occupation_result(
            success=True,
            message="resumed GPU memory (restored modules to original devices)",
        )
Loading
Loading