Skip to content

Commit 967959d

Browse files
Merge pull request zhaochenyang20#35 from klhhhhh/main
Support wake up/sleep into router
2 parents 634aba8 + 68b69c8 commit 967959d

3 files changed

Lines changed: 197 additions & 7 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ Video query routing is stable by `video_id`: router caches `video_id -> worker`
258258
| Method | Path | Description |
259259
|---|---|---|
260260
| `POST` | `/update_weights_from_disk` | Reload weights from disk on all healthy workers |
261-
261+
| `POST` | `/release_memory_occupation` | Broadcast sleep to all healthy workers (release GPU memory occupation) |
262+
| `POST` | `/resume_memory_occupation` | Broadcast wake to all healthy workers (resume GPU memory occupation) |
262263

263264
## Acknowledgment
264265

src/sglang_diffusion_routing/router/diffusion_router.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def __init__(self, args, verbose: bool = False):
3838
self.worker_video_support: dict[str, bool | None] = {}
3939
# quarantined workers excluded from routing
4040
self.dead_workers: set[str] = set()
41+
# record workers in sleeping status
42+
self.sleeping_workers: set[str] = set()
4143
# video_id -> worker URL mapping for stable query routing
4244
self.video_job_to_worker: dict[str, str] = {}
4345
self._health_task: asyncio.Task | None = None
@@ -81,6 +83,8 @@ def _setup_routes(self) -> None:
8183
self.app.get("/v1/models")(self.get_models)
8284
self.app.get("/health")(self.health)
8385
self.app.post("/update_weights_from_disk")(self.update_weights_from_disk)
86+
self.app.post("/release_memory_occupation")(self.release_memory_occupation)
87+
self.app.post("/resume_memory_occupation")(self.resume_memory_occupation)
8488
self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(
8589
self.proxy
8690
)
@@ -183,7 +187,9 @@ def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str
183187
raise RuntimeError("No workers registered in the pool")
184188

185189
valid_workers = [
186-
w for w in self.worker_request_counts if w not in self.dead_workers
190+
w
191+
for w in self.worker_request_counts
192+
if w not in self.dead_workers and w not in self.sleeping_workers
187193
]
188194
if worker_urls is not None:
189195
allowed = {w for w in worker_urls if w in self.worker_request_counts}
@@ -306,6 +312,10 @@ async def _forward_to_registered_worker(
306312
return JSONResponse(
307313
status_code=503, content={"error": "Mapped worker is unavailable"}
308314
)
315+
if worker_url in self.sleeping_workers:
316+
return JSONResponse(
317+
status_code=503, content={"error": "Mapped worker is sleeping"}
318+
)
309319
self.worker_request_counts[worker_url] += 1
310320
return await self._forward_to_selected_worker(request, path, worker_url)
311321

@@ -361,7 +371,12 @@ async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
361371
)
362372
if isinstance(task_type, str):
363373
return task_type.upper() not in _IMAGE_TASK_TYPES
364-
except (httpx.RequestError, json.JSONDecodeError):
374+
except (httpx.RequestError, json.JSONDecodeError) as exc:
375+
logger.debug(
376+
"[diffusion-router] video support probe failed: worker=%s error=%s",
377+
worker_url,
378+
exc,
379+
)
365380
return None
366381

367382
async def refresh_worker_video_support(self, worker_url: str) -> None:
@@ -373,8 +388,23 @@ async def refresh_worker_video_support(self, worker_url: str) -> None:
373388
async def _broadcast_to_workers(
374389
self, path: str, body: bytes, headers: dict
375390
) -> list[dict]:
376-
"""Send a request to all healthy workers and collect results."""
377-
urls = [u for u in self.worker_request_counts if u not in self.dead_workers]
391+
"""
392+
Broadcast request to eligible workers.
393+
394+
Rules:
395+
- For resume_memory_occupation (wake): target sleeping workers (even if currently marked dead).
396+
- For all other requests: target only active workers (exclude dead AND sleeping).
397+
"""
398+
if path == "resume_memory_occupation":
399+
# Wake is a recovery point: allow waking workers that were marked dead during sleep.
400+
urls = [u for u in self.sleeping_workers if u in self.worker_request_counts]
401+
else:
402+
urls = [
403+
u
404+
for u in self.worker_request_counts
405+
if u not in self.dead_workers and u not in self.sleeping_workers
406+
]
407+
378408
if not urls:
379409
return []
380410

@@ -494,6 +524,7 @@ def _build_worker_payload(self, worker_url: str) -> dict:
494524
"url": worker_url,
495525
"active_requests": self.worker_request_counts.get(worker_url, 0),
496526
"is_dead": worker_url in self.dead_workers,
527+
"is_sleeping": worker_url in self.sleeping_workers,
497528
"consecutive_failures": self.worker_failure_counts.get(worker_url, 0),
498529
"video_support": self.worker_video_support.get(worker_url),
499530
}
@@ -554,6 +585,7 @@ def _video_capable_workers(self) -> list[str]:
554585
if support
555586
and worker_url in self.worker_request_counts
556587
and worker_url not in self.dead_workers
588+
and worker_url not in self.sleeping_workers
557589
]
558590

559591
@staticmethod
@@ -607,7 +639,7 @@ async def get_models(self, request: Request):
607639
worker_urls = [
608640
url
609641
for url in self.worker_request_counts.keys()
610-
if url not in self.dead_workers
642+
if url not in self.dead_workers and url not in self.sleeping_workers
611643
]
612644
if not worker_urls:
613645
return JSONResponse(
@@ -684,7 +716,8 @@ async def health(self, request: Request):
684716
"""Aggregated health status: healthy if at least one worker is alive."""
685717
total = len(self.worker_request_counts)
686718
dead = len(self.dead_workers)
687-
healthy = total - dead
719+
sleeping = len(self.sleeping_workers)
720+
healthy = total - dead - sleeping
688721
status = "healthy" if healthy > 0 else "unhealthy"
689722
code = 200 if healthy > 0 else 503
690723
return JSONResponse(
@@ -693,6 +726,8 @@ async def health(self, request: Request):
693726
"status": status,
694727
"healthy_workers": healthy,
695728
"total_workers": total,
729+
"dead_workers": dead,
730+
"sleeping_workers": sleeping,
696731
},
697732
)
698733

@@ -714,6 +749,48 @@ async def update_weights_from_disk(self, request: Request):
714749
)
715750
return JSONResponse(content={"results": results})
716751

752+
async def _broadcast_to_pool(self, request: Request, path: str) -> tuple[int, dict]:
753+
if not self.worker_request_counts:
754+
return 503, {"error": "No workers registered in the pool"}
755+
756+
body = await request.body()
757+
headers = dict(request.headers)
758+
759+
results = await self._broadcast_to_workers(path, body, headers)
760+
if not results:
761+
return 503, {"error": "No eligible workers available in the pool"}
762+
763+
return 200, {"results": results}
764+
765+
async def release_memory_occupation(self, request: Request):
766+
status, payload = await self._broadcast_to_pool(
767+
request, "release_memory_occupation"
768+
)
769+
if status != 200:
770+
return JSONResponse(status_code=status, content=payload)
771+
772+
for item in payload["results"]:
773+
if item.get("status_code") == 200:
774+
self.sleeping_workers.add(item["worker_url"])
775+
776+
return JSONResponse(content=payload)
777+
778+
async def resume_memory_occupation(self, request: Request):
779+
status, payload = await self._broadcast_to_pool(
780+
request, "resume_memory_occupation"
781+
)
782+
if status != 200:
783+
return JSONResponse(status_code=status, content=payload)
784+
785+
for item in payload["results"]:
786+
if item.get("status_code") == 200:
787+
url = item["worker_url"]
788+
self.sleeping_workers.discard(url)
789+
self.worker_failure_counts[url] = 0
790+
self.dead_workers.discard(url) # wake success => recover
791+
792+
return JSONResponse(content=payload)
793+
717794
def register_worker(self, url: str) -> None:
718795
"""Register a worker URL if not already known."""
719796
normalized_url = self.normalize_worker_url(url)
@@ -730,6 +807,7 @@ def deregister_worker(self, url: str) -> None:
730807
self.worker_failure_counts.pop(normalized_url, None)
731808
self.worker_video_support.pop(normalized_url, None)
732809
self.dead_workers.discard(normalized_url)
810+
self.sleeping_workers.discard(normalized_url)
733811
stale_video_ids = [
734812
video_id
735813
for video_id, mapped_worker in self.video_job_to_worker.items()

tests/unit/test_router_endpoints.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,114 @@ def test_update_weights_from_disk_returns_503_without_healthy_workers():
310310
response = client.post("/update_weights_from_disk", json={"model_path": "abc"})
311311
assert response.status_code == 503
312312
assert "No healthy workers available" in response.json()["error"]
313+
314+
315+
def test_get_v1_videos_by_query_video_id_returns_503_if_worker_sleeping():
316+
router = DiffusionRouter(make_router_args())
317+
router.register_worker("http://localhost:10090")
318+
router.video_job_to_worker["video_123"] = "http://localhost:10090"
319+
router.sleeping_workers.add("http://localhost:10090")
320+
321+
with TestClient(router.app) as client:
322+
resp = client.get("/v1/videos", params={"video_id": "video_123"})
323+
assert resp.status_code == 503
324+
assert "sleeping" in resp.json()["error"].lower()
325+
326+
327+
def test_release_memory_occupation_broadcasts_and_marks_sleeping_on_200_only():
328+
router = DiffusionRouter(make_router_args())
329+
router.register_worker("http://localhost:10090")
330+
router.register_worker("http://localhost:10091")
331+
332+
async def fake_broadcast(path: str, body: bytes, headers: dict):
333+
assert path == "release_memory_occupation"
334+
assert body in (b"{}", b"")
335+
assert headers.get("content-type", "").startswith("application/json")
336+
return [
337+
{
338+
"worker_url": "http://localhost:10090",
339+
"status_code": 200,
340+
"body": {"success": True},
341+
},
342+
{
343+
"worker_url": "http://localhost:10091",
344+
"status_code": 400,
345+
"body": {"success": False},
346+
},
347+
]
348+
349+
router._broadcast_to_workers = fake_broadcast
350+
351+
with TestClient(router.app) as client:
352+
resp = client.post("/release_memory_occupation", json={})
353+
assert resp.status_code == 200
354+
results = resp.json()["results"]
355+
assert len(results) == 2
356+
357+
assert "http://localhost:10090" in router.sleeping_workers
358+
assert "http://localhost:10091" not in router.sleeping_workers
359+
360+
361+
def test_resume_memory_occupation_unmarks_sleeping_resets_failure_and_revives_dead_on_200():
362+
router = DiffusionRouter(make_router_args())
363+
router.register_worker("http://localhost:10090")
364+
router.register_worker("http://localhost:10091")
365+
366+
# both sleeping
367+
router.sleeping_workers.update(["http://localhost:10090", "http://localhost:10091"])
368+
router.worker_failure_counts["http://localhost:10090"] = 2
369+
router.worker_failure_counts["http://localhost:10091"] = 2
370+
371+
# simulate: worker got marked dead during sleep
372+
router.dead_workers.add("http://localhost:10090")
373+
374+
async def fake_broadcast(path: str, body: bytes, headers: dict):
375+
assert path == "resume_memory_occupation"
376+
return [
377+
{"worker_url": "http://localhost:10090", "status_code": 200, "body": {}},
378+
{"worker_url": "http://localhost:10091", "status_code": 500, "body": {}},
379+
]
380+
381+
router._broadcast_to_workers = fake_broadcast # type: ignore[assignment]
382+
383+
with TestClient(router.app) as client:
384+
resp = client.post("/resume_memory_occupation", json={})
385+
assert resp.status_code == 200
386+
387+
# 200 worker: woken => not sleeping, failure reset, dead cleared
388+
assert "http://localhost:10090" not in router.sleeping_workers
389+
assert router.worker_failure_counts["http://localhost:10090"] == 0
390+
assert "http://localhost:10090" not in router.dead_workers
391+
392+
# non-200 worker: still sleeping, failure unchanged
393+
assert "http://localhost:10091" in router.sleeping_workers
394+
assert router.worker_failure_counts["http://localhost:10091"] == 2
395+
396+
397+
def test_select_worker_excludes_sleeping_workers():
398+
router = DiffusionRouter(make_router_args(routing_algorithm="least-request"))
399+
router.register_worker("http://localhost:10090")
400+
router.register_worker("http://localhost:10091")
401+
402+
# Mark one worker sleeping
403+
router.sleeping_workers.add("http://localhost:10090")
404+
405+
picked = router._select_worker_by_routing()
406+
assert picked == "http://localhost:10091"
407+
408+
409+
def test_select_worker_increments_count_only_for_selected():
410+
router = DiffusionRouter(make_router_args(routing_algorithm="least-request"))
411+
router.register_worker("http://localhost:10090")
412+
router.register_worker("http://localhost:10091")
413+
414+
router.sleeping_workers.add("http://localhost:10090")
415+
416+
before_10090 = router.worker_request_counts["http://localhost:10090"]
417+
before_10091 = router.worker_request_counts["http://localhost:10091"]
418+
419+
picked = router._select_worker_by_routing()
420+
assert picked == "http://localhost:10091"
421+
422+
assert router.worker_request_counts["http://localhost:10090"] == before_10090
423+
assert router.worker_request_counts["http://localhost:10091"] == before_10091 + 1

0 commit comments

Comments
 (0)