Skip to content

Commit 6a01e51

Browse files
support identical method
1 parent aff26e6 commit 6a01e51

3 files changed

Lines changed: 78 additions & 22 deletions

File tree

README.md

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ sglang-d-router --port 30081 --launcher-config examples/local_launcher.yaml
5555
```yaml
5656
launcher:
5757
backend: local
58-
model: stabilityai/stable-diffusion-3-medium-diffusers
58+
model: Qwen/Qwen-Image
5959
num_workers: 8
6060
num_gpus_per_worker: 1
6161
worker_base_port: 10090
@@ -70,14 +70,14 @@ launcher:
7070

7171
# worker 1
7272
CUDA_VISIBLE_DEVICES=0 sglang serve \
73-
--model-path stabilityai/stable-diffusion-3-medium-diffusers \
73+
--model-path Qwen/Qwen-Image \
7474
--num-gpus 1 \
7575
--host 127.0.0.1 \
7676
--port 30000
7777

7878
# worker 2
7979
CUDA_VISIBLE_DEVICES=1 sglang serve \
80-
--model-path stabilityai/stable-diffusion-3-medium-diffusers \
80+
--model-path Qwen/Qwen-Image \
8181
--num-gpus 1 \
8282
--host 127.0.0.1 \
8383
--port 30002
@@ -139,7 +139,7 @@ print("Saved to output.png")
139139
# so this request will fail. Use a video-capable model instead.
140140

141141
resp = requests.post(f"{ROUTER}/v1/videos", json={
142-
"model": "stabilityai/stable-diffusion-3-medium-diffusers",
142+
"model": "Qwen/Qwen-Image",
143143
"prompt": "a flowing river",
144144
})
145145
print(resp.json())
@@ -149,9 +149,17 @@ if video_id:
149149

150150
# Update weights from disk
151151
resp = requests.post(f"{ROUTER}/update_weights_from_disk", json={
152-
"model_path": "/path/to/new/checkpoint",
152+
"model_path": "Qwen/Qwen-Image-2512",
153153
})
154154
print(resp.json())
155+
156+
# sleep and wake up
157+
resp = requests.post(f"{ROUTER}/release_memory_occupation", json={})
158+
print(resp.json())
159+
160+
161+
resp = requests.post(f"{ROUTER}/resume_memory_occupation", json={})
162+
print(resp.json())
155163
```
156164

157165
### Native Diffusion Generate Endpoint (with Trajectory & Log-Prob)

src/sglang_diffusion_routing/router/diffusion_router.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -771,30 +771,53 @@ async def release_memory_occupation(self, request: Request):
771771
status, payload = await self._broadcast_to_pool(
772772
request, "release_memory_occupation"
773773
)
774-
if status != 200:
775-
return JSONResponse(status_code=status, content=payload)
776-
777-
for item in payload["results"]:
778-
if item.get("status_code") == 200:
779-
self.sleeping_workers.add(item["worker_url"])
774+
if status == 200:
775+
for item in payload["results"]:
776+
if item.get("status_code") == 200:
777+
self.sleeping_workers.add(item["worker_url"])
778+
return JSONResponse(content=payload)
779+
780+
all_already_sleeping = bool(self.worker_request_counts) and all(
781+
url in self.sleeping_workers
782+
for url in self.worker_request_counts
783+
if url not in self.dead_workers
784+
)
785+
if all_already_sleeping:
786+
return JSONResponse(
787+
content={
788+
"message": "All workers are already sleeping",
789+
"sleeping_workers": len(self.sleeping_workers),
790+
}
791+
)
780792

781-
return JSONResponse(content=payload)
793+
return JSONResponse(status_code=status, content=payload)
782794

783795
async def resume_memory_occupation(self, request: Request):
784796
status, payload = await self._broadcast_to_pool(
785797
request, "resume_memory_occupation"
786798
)
787-
if status != 200:
788-
return JSONResponse(status_code=status, content=payload)
789-
790-
for item in payload["results"]:
791-
if item.get("status_code") == 200:
792-
url = item["worker_url"]
793-
self.sleeping_workers.discard(url)
794-
self.worker_failure_counts[url] = 0
795-
self.dead_workers.discard(url) # wake success => recover
799+
if status == 200:
800+
for item in payload["results"]:
801+
if item.get("status_code") == 200:
802+
url = item["worker_url"]
803+
self.sleeping_workers.discard(url)
804+
self.worker_failure_counts[url] = 0
805+
self.dead_workers.discard(url)
806+
return JSONResponse(content=payload)
807+
808+
has_no_sleeping_workers = not self.sleeping_workers and bool(
809+
self.worker_request_counts
810+
)
811+
if has_no_sleeping_workers:
812+
return JSONResponse(
813+
content={
814+
"message": "All workers are already active",
815+
"active_workers": len(self.worker_request_counts)
816+
- len(self.dead_workers),
817+
}
818+
)
796819

797-
return JSONResponse(content=payload)
820+
return JSONResponse(status_code=status, content=payload)
798821

799822
def register_worker(self, url: str) -> None:
800823
"""Register a worker URL if not already known."""

tests/unit/test_router_endpoints.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,31 @@ async def fake_broadcast(path: str, body: bytes, headers: dict):
415415
assert router.worker_failure_counts["http://localhost:10091"] == 2
416416

417417

418+
def test_release_memory_occupation_idempotent_when_all_already_sleeping():
419+
router = DiffusionRouter(make_router_args())
420+
router.register_worker("http://localhost:10090")
421+
router.register_worker("http://localhost:10091")
422+
router.sleeping_workers.update(["http://localhost:10090", "http://localhost:10091"])
423+
424+
with TestClient(router.app) as client:
425+
resp = client.post("/release_memory_occupation", json={})
426+
assert resp.status_code == 200
427+
assert "already sleeping" in resp.json()["message"].lower()
428+
assert resp.json()["sleeping_workers"] == 2
429+
430+
431+
def test_resume_memory_occupation_idempotent_when_all_already_active():
432+
router = DiffusionRouter(make_router_args())
433+
router.register_worker("http://localhost:10090")
434+
router.register_worker("http://localhost:10091")
435+
436+
with TestClient(router.app) as client:
437+
resp = client.post("/resume_memory_occupation", json={})
438+
assert resp.status_code == 200
439+
assert "already active" in resp.json()["message"].lower()
440+
assert resp.json()["active_workers"] == 2
441+
442+
418443
def test_select_worker_excludes_sleeping_workers():
419444
router = DiffusionRouter(make_router_args(routing_algorithm="least-request"))
420445
router.register_worker("http://localhost:10090")

0 commit comments

Comments
 (0)