Skip to content

Commit 9138c61

Browse files
committed
import wake/sleep into router
1 parent de7cc3e commit 9138c61

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

src/sglang_diffusion_routing/router/diffusion_router.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def __init__(self, args, verbose: bool = False):
3838
self.worker_video_support: dict[str, bool | None] = {}
3939
# quarantined workers excluded from routing
4040
self.dead_workers: set[str] = set()
41+
# record workers in sleeping status
42+
self.sleeping_workers: set[str] = set()
4143
# video_id -> worker URL mapping for stable query routing
4244
self.video_job_to_worker: dict[str, str] = {}
4345
self._health_task: asyncio.Task | None = None
@@ -81,6 +83,8 @@ def _setup_routes(self) -> None:
8183
self.app.get("/v1/models")(self.get_models)
8284
self.app.get("/health")(self.health)
8385
self.app.post("/update_weights_from_disk")(self.update_weights_from_disk)
86+
self.app.post("/release_memory_occupation")(self.release_memory_occupation)
87+
self.app.post("/resume_memory_occupation")(self.resume_memory_occupation)
8488
self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(
8589
self.proxy
8690
)
@@ -704,6 +708,58 @@ async def update_weights_from_disk(self, request: Request):
704708
)
705709
return JSONResponse(content={"results": results})
706710

711+
async def release_memory_occupation(self, request: Request):
    """Forward a sleep request to the worker pool and record who went to sleep.

    Responds with HTTP 503 when no registered worker is outside
    ``dead_workers``. Otherwise the incoming body and headers are handed to
    ``_broadcast_to_workers`` for the ``release_memory_occupation`` endpoint,
    and every worker whose result reports status code 200 is added to
    ``sleeping_workers``. The per-worker results are returned under the
    ``"results"`` key.

    NOTE(review): the set of workers actually contacted is decided inside
    ``_broadcast_to_workers``; the check here only gates the call.
    """
    has_live_worker = any(
        url not in self.dead_workers for url in self.worker_request_counts
    )
    if not has_live_worker:
        return JSONResponse(
            status_code=503,
            content={"error": "No healthy workers available in the pool"},
        )

    payload = await request.body()
    forward_headers = dict(request.headers)
    # The caller's content-length is not forwarded; a JSON content-type is
    # supplied only when the caller did not send one.
    forward_headers.pop("content-length", None)
    forward_headers.setdefault("content-type", "application/json")

    results = await self._broadcast_to_workers(
        "release_memory_occupation", payload, forward_headers
    )

    # Only workers that acknowledged with 200 are tracked as sleeping.
    self.sleeping_workers.update(
        entry["worker_url"] for entry in results if entry.get("status_code") == 200
    )

    return JSONResponse(content={"results": results})
734+
735+
736+
async def resume_memory_occupation(self, request: Request):
    """Forward a wake request to the worker pool and clear sleeping marks.

    Responds with HTTP 503 when every registered worker is in
    ``dead_workers`` (or none is registered). Otherwise the incoming body and
    headers are handed to ``_broadcast_to_workers`` for the
    ``resume_memory_occupation`` endpoint. Each worker whose result reports
    status code 200 is removed from ``sleeping_workers`` and has its entry in
    ``worker_failure_counts`` reset to zero. The per-worker results are
    returned under the ``"results"`` key.
    """
    if all(url in self.dead_workers for url in self.worker_request_counts):
        return JSONResponse(
            status_code=503,
            content={"error": "No healthy workers available in the pool"},
        )

    payload = await request.body()
    forward_headers = dict(request.headers)
    # The caller's content-length is not forwarded; a JSON content-type is
    # supplied only when the caller did not send one.
    forward_headers.pop("content-length", None)
    forward_headers.setdefault("content-type", "application/json")

    results = await self._broadcast_to_workers(
        "resume_memory_occupation", payload, forward_headers
    )

    for entry in results:
        if entry.get("status_code") != 200:
            continue
        self.sleeping_workers.discard(entry["worker_url"])
        # A successful wake is treated as an explicit recovery point, so
        # failures counted while the worker was intentionally asleep are
        # not carried forward.
        self.worker_failure_counts[entry["worker_url"]] = 0

    return JSONResponse(content={"results": results})
762+
707763
def register_worker(self, url: str) -> None:
708764
"""Register a worker URL if not already known."""
709765
normalized_url = self.normalize_worker_url(url)

0 commit comments

Comments
 (0)