Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ Video query routing is stable by `video_id`: router caches `video_id -> worker`
| Method | Path | Description |
|---|---|---|
| `POST` | `/update_weights_from_disk` | Reload weights from disk on all healthy workers |
| `POST` | `/release_memory_occupation` | Broadcast sleep to all healthy workers (release GPU memory occupation) |
| `POST` | `/resume_memory_occupation` | Broadcast wake to all healthy workers (resume GPU memory occupation) |

## Acknowledgment

Expand Down
90 changes: 84 additions & 6 deletions src/sglang_diffusion_routing/router/diffusion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, args, verbose: bool = False):
self.worker_video_support: dict[str, bool | None] = {}
# quarantined workers excluded from routing
self.dead_workers: set[str] = set()
# record workers in sleeping status
self.sleeping_workers: set[str] = set()
# video_id -> worker URL mapping for stable query routing
self.video_job_to_worker: dict[str, str] = {}
self._health_task: asyncio.Task | None = None
Expand Down Expand Up @@ -81,6 +83,8 @@ def _setup_routes(self) -> None:
self.app.get("/v1/models")(self.get_models)
self.app.get("/health")(self.health)
self.app.post("/update_weights_from_disk")(self.update_weights_from_disk)
self.app.post("/release_memory_occupation")(self.release_memory_occupation)
self.app.post("/resume_memory_occupation")(self.resume_memory_occupation)
Comment on lines +86 to +87
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The new endpoints /release_memory_occupation and /resume_memory_occupation lack any authentication or authorization checks. Since these endpoints perform administrative actions (broadcasting sleep/wake commands to all workers and modifying the router's internal state), exposing them without access control allows any user with network access to the router to disrupt the service by putting all workers to sleep. This is a Missing Function-Level Access Control vulnerability.

self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(
self.proxy
)
Expand Down Expand Up @@ -173,7 +177,9 @@ def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str
raise RuntimeError("No workers registered in the pool")

valid_workers = [
w for w in self.worker_request_counts if w not in self.dead_workers
w
for w in self.worker_request_counts
if w not in self.dead_workers and w not in self.sleeping_workers
]
if worker_urls is not None:
allowed = {w for w in worker_urls if w in self.worker_request_counts}
Expand Down Expand Up @@ -296,6 +302,10 @@ async def _forward_to_registered_worker(
return JSONResponse(
status_code=503, content={"error": "Mapped worker is unavailable"}
)
if worker_url in self.sleeping_workers:
return JSONResponse(
status_code=503, content={"error": "Mapped worker is sleeping"}
)
self.worker_request_counts[worker_url] += 1
return await self._forward_to_selected_worker(request, path, worker_url)

Expand Down Expand Up @@ -351,7 +361,12 @@ async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
)
if isinstance(task_type, str):
return task_type.upper() not in _IMAGE_TASK_TYPES
except (httpx.RequestError, json.JSONDecodeError):
except (httpx.RequestError, json.JSONDecodeError) as exc:
logger.debug(
"[diffusion-router] video support probe failed: worker=%s error=%s",
worker_url,
exc,
)
return None

async def refresh_worker_video_support(self, worker_url: str) -> None:
Expand All @@ -363,8 +378,23 @@ async def refresh_worker_video_support(self, worker_url: str) -> None:
async def _broadcast_to_workers(
self, path: str, body: bytes, headers: dict
) -> list[dict]:
"""Send a request to all healthy workers and collect results."""
urls = [u for u in self.worker_request_counts if u not in self.dead_workers]
"""
Broadcast request to eligible workers.

Rules:
- For resume_memory_occupation (wake): target sleeping workers (even if currently marked dead).
- For all other requests: target only active workers (exclude dead AND sleeping).
"""
if path == "resume_memory_occupation":
# Wake is a recovery point: allow waking workers that were marked dead during sleep.
urls = [u for u in self.sleeping_workers if u in self.worker_request_counts]
else:
urls = [
u
for u in self.worker_request_counts
if u not in self.dead_workers and u not in self.sleeping_workers
]

if not urls:
return []

Expand Down Expand Up @@ -484,6 +514,7 @@ def _build_worker_payload(self, worker_url: str) -> dict:
"url": worker_url,
"active_requests": self.worker_request_counts.get(worker_url, 0),
"is_dead": worker_url in self.dead_workers,
"is_sleeping": worker_url in self.sleeping_workers,
"consecutive_failures": self.worker_failure_counts.get(worker_url, 0),
"video_support": self.worker_video_support.get(worker_url),
}
Expand Down Expand Up @@ -544,6 +575,7 @@ def _video_capable_workers(self) -> list[str]:
if support
and worker_url in self.worker_request_counts
and worker_url not in self.dead_workers
and worker_url not in self.sleeping_workers
]

@staticmethod
Expand Down Expand Up @@ -597,7 +629,7 @@ async def get_models(self, request: Request):
worker_urls = [
url
for url in self.worker_request_counts.keys()
if url not in self.dead_workers
if url not in self.dead_workers and url not in self.sleeping_workers
]
if not worker_urls:
return JSONResponse(
Expand Down Expand Up @@ -674,7 +706,8 @@ async def health(self, request: Request):
"""Aggregated health status: healthy if at least one worker is alive."""
total = len(self.worker_request_counts)
dead = len(self.dead_workers)
healthy = total - dead
sleeping = len(self.sleeping_workers)
healthy = total - dead - sleeping
status = "healthy" if healthy > 0 else "unhealthy"
code = 200 if healthy > 0 else 503
return JSONResponse(
Expand All @@ -683,6 +716,8 @@ async def health(self, request: Request):
"status": status,
"healthy_workers": healthy,
"total_workers": total,
"dead_workers": dead,
"sleeping_workers": sleeping,
},
)

Expand All @@ -704,6 +739,48 @@ async def update_weights_from_disk(self, request: Request):
)
return JSONResponse(content={"results": results})

async def _broadcast_to_pool(self, request: Request, path: str) -> tuple[int, dict]:
    """Fan an incoming request out through ``_broadcast_to_workers``.

    Args:
        request: Incoming request whose raw body and headers are passed on.
        path: Path handed to ``_broadcast_to_workers``.

    Returns:
        A ``(status_code, payload)`` pair: ``(200, {"results": [...]})`` when
        the broadcast produced at least one result, otherwise ``503`` with an
        ``"error"`` message.
    """
    # Empty pool: answer before the request body is even read.
    if not self.worker_request_counts:
        return 503, {"error": "No workers registered in the pool"}

    raw_body = await request.body()
    forwarded_headers = dict(request.headers)
    worker_results = await self._broadcast_to_workers(
        path, raw_body, forwarded_headers
    )

    if worker_results:
        return 200, {"results": worker_results}
    # Workers are registered, but the broadcast returned nothing.
    return 503, {"error": "No eligible workers available in the pool"}

async def release_memory_occupation(self, request: Request):
    """Broadcast ``release_memory_occupation`` and record workers that acknowledged it.

    A non-200 outcome from ``_broadcast_to_pool`` is returned to the caller
    unchanged. Otherwise, every result entry whose ``status_code`` is 200 has
    its ``worker_url`` added to ``self.sleeping_workers``, and the full
    result payload is returned.
    """
    code, data = await self._broadcast_to_pool(request, "release_memory_occupation")
    if code != 200:
        return JSONResponse(status_code=code, content=data)

    for entry in data["results"]:
        if entry.get("status_code") != 200:
            # Worker did not confirm the release; leave its state unchanged.
            continue
        self.sleeping_workers.add(entry["worker_url"])

    return JSONResponse(content=data)

async def resume_memory_occupation(self, request: Request):
    """Broadcast ``resume_memory_occupation`` and restore workers that acknowledged it.

    A non-200 outcome from ``_broadcast_to_pool`` is returned to the caller
    unchanged. Otherwise, for every result entry whose ``status_code`` is 200
    the worker is removed from ``self.sleeping_workers``, its failure counter
    is reset to 0, and it is removed from ``self.dead_workers``.
    """
    code, data = await self._broadcast_to_pool(request, "resume_memory_occupation")
    if code != 200:
        return JSONResponse(status_code=code, content=data)

    for entry in data["results"]:
        if entry.get("status_code") != 200:
            # Wake was not confirmed; the worker stays marked as sleeping.
            continue
        woken_url = entry["worker_url"]
        self.sleeping_workers.discard(woken_url)
        self.worker_failure_counts[woken_url] = 0
        # A successful wake counts as recovery, so lift any quarantine too.
        self.dead_workers.discard(woken_url)

    return JSONResponse(content=data)

def register_worker(self, url: str) -> None:
"""Register a worker URL if not already known."""
normalized_url = self.normalize_worker_url(url)
Expand All @@ -720,6 +797,7 @@ def deregister_worker(self, url: str) -> None:
self.worker_failure_counts.pop(normalized_url, None)
self.worker_video_support.pop(normalized_url, None)
self.dead_workers.discard(normalized_url)
self.sleeping_workers.discard(normalized_url)
stale_video_ids = [
video_id
for video_id, mapped_worker in self.video_job_to_worker.items()
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/test_router_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,114 @@ def test_update_weights_from_disk_returns_503_without_healthy_workers():
response = client.post("/update_weights_from_disk", json={"model_path": "abc"})
assert response.status_code == 503
assert "No healthy workers available" in response.json()["error"]


def test_get_v1_videos_by_query_video_id_returns_503_if_worker_sleeping():
    """A video query mapped to a sleeping worker must be answered with 503."""
    worker = "http://localhost:10090"
    router = DiffusionRouter(make_router_args())
    router.register_worker(worker)
    router.video_job_to_worker["video_123"] = worker
    router.sleeping_workers.add(worker)

    with TestClient(router.app) as client:
        response = client.get("/v1/videos", params={"video_id": "video_123"})
        assert response.status_code == 503
        assert "sleeping" in response.json()["error"].lower()


def test_release_memory_occupation_broadcasts_and_marks_sleeping_on_200_only():
    """Only workers whose broadcast result is 200 end up in ``sleeping_workers``."""
    slept = "http://localhost:10090"
    refused = "http://localhost:10091"

    router = DiffusionRouter(make_router_args())
    router.register_worker(slept)
    router.register_worker(refused)

    async def fake_broadcast(path: str, body: bytes, headers: dict):
        # Stand-in for the real fan-out: check what the endpoint hands down,
        # then report one success and one failure.
        assert path == "release_memory_occupation"
        assert body in (b"{}", b"")
        assert headers.get("content-type", "").startswith("application/json")
        canned = [(slept, 200, True), (refused, 400, False)]
        return [
            {"worker_url": url, "status_code": code, "body": {"success": ok}}
            for url, code, ok in canned
        ]

    router._broadcast_to_workers = fake_broadcast

    with TestClient(router.app) as client:
        response = client.post("/release_memory_occupation", json={})
        assert response.status_code == 200
        assert len(response.json()["results"]) == 2

    assert slept in router.sleeping_workers
    assert refused not in router.sleeping_workers


def test_resume_memory_occupation_unmarks_sleeping_resets_failure_and_revives_dead_on_200():
    """A 200 wake result clears sleeping/dead state and the failure count; a non-200 result changes nothing."""
    woken = "http://localhost:10090"
    stuck = "http://localhost:10091"

    router = DiffusionRouter(make_router_args())
    router.register_worker(woken)
    router.register_worker(stuck)

    # Both workers start out asleep with two recorded failures each.
    router.sleeping_workers.update([woken, stuck])
    for url in (woken, stuck):
        router.worker_failure_counts[url] = 2

    # Simulate a worker that was quarantined while it was sleeping.
    router.dead_workers.add(woken)

    async def fake_broadcast(path: str, body: bytes, headers: dict):
        assert path == "resume_memory_occupation"
        return [
            {"worker_url": url, "status_code": code, "body": {}}
            for url, code in ((woken, 200), (stuck, 500))
        ]

    router._broadcast_to_workers = fake_broadcast  # type: ignore[assignment]

    with TestClient(router.app) as client:
        response = client.post("/resume_memory_occupation", json={})
        assert response.status_code == 200

    # Successful wake: no longer sleeping, failures reset, quarantine lifted.
    assert woken not in router.sleeping_workers
    assert router.worker_failure_counts[woken] == 0
    assert woken not in router.dead_workers

    # Failed wake: still sleeping, failure count untouched.
    assert stuck in router.sleeping_workers
    assert router.worker_failure_counts[stuck] == 2


def test_select_worker_excludes_sleeping_workers():
    """With one of two workers asleep, routing must pick the other one."""
    asleep = "http://localhost:10090"
    awake = "http://localhost:10091"

    router = DiffusionRouter(make_router_args(routing_algorithm="least-request"))
    router.register_worker(asleep)
    router.register_worker(awake)

    # Put exactly one worker to sleep.
    router.sleeping_workers.add(asleep)

    assert router._select_worker_by_routing() == awake


def test_select_worker_increments_count_only_for_selected():
    """Selecting a worker bumps only the chosen worker's request count."""
    asleep = "http://localhost:10090"
    awake = "http://localhost:10091"

    router = DiffusionRouter(make_router_args(routing_algorithm="least-request"))
    router.register_worker(asleep)
    router.register_worker(awake)

    router.sleeping_workers.add(asleep)

    # Snapshot both counters before routing so the deltas can be checked.
    asleep_before = router.worker_request_counts[asleep]
    awake_before = router.worker_request_counts[awake]

    chosen = router._select_worker_by_routing()
    assert chosen == awake

    assert router.worker_request_counts[asleep] == asleep_before
    assert router.worker_request_counts[awake] == awake_before + 1