@@ -38,6 +38,8 @@ def __init__(self, args, verbose: bool = False):
3838 self .worker_video_support : dict [str , bool | None ] = {}
3939 # quarantined workers excluded from routing
4040 self .dead_workers : set [str ] = set ()
41+ # record workers in sleeping status
42+ self .sleeping_workers : set [str ] = set ()
4143 # video_id -> worker URL mapping for stable query routing
4244 self .video_job_to_worker : dict [str , str ] = {}
4345 self ._health_task : asyncio .Task | None = None
@@ -81,6 +83,8 @@ def _setup_routes(self) -> None:
8183 self .app .get ("/v1/models" )(self .get_models )
8284 self .app .get ("/health" )(self .health )
8385 self .app .post ("/update_weights_from_disk" )(self .update_weights_from_disk )
86+ self .app .post ("/release_memory_occupation" )(self .release_memory_occupation )
87+ self .app .post ("/resume_memory_occupation" )(self .resume_memory_occupation )
8488 self .app .api_route ("/{path:path}" , methods = ["GET" , "POST" , "PUT" , "DELETE" ])(
8589 self .proxy
8690 )
@@ -183,7 +187,9 @@ def _select_worker_by_routing(self, worker_urls: list[str] | None = None) -> str
183187 raise RuntimeError ("No workers registered in the pool" )
184188
185189 valid_workers = [
186- w for w in self .worker_request_counts if w not in self .dead_workers
190+ w
191+ for w in self .worker_request_counts
192+ if w not in self .dead_workers and w not in self .sleeping_workers
187193 ]
188194 if worker_urls is not None :
189195 allowed = {w for w in worker_urls if w in self .worker_request_counts }
@@ -306,6 +312,10 @@ async def _forward_to_registered_worker(
306312 return JSONResponse (
307313 status_code = 503 , content = {"error" : "Mapped worker is unavailable" }
308314 )
315+ if worker_url in self .sleeping_workers :
316+ return JSONResponse (
317+ status_code = 503 , content = {"error" : "Mapped worker is sleeping" }
318+ )
309319 self .worker_request_counts [worker_url ] += 1
310320 return await self ._forward_to_selected_worker (request , path , worker_url )
311321
@@ -361,7 +371,12 @@ async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
361371 )
362372 if isinstance (task_type , str ):
363373 return task_type .upper () not in _IMAGE_TASK_TYPES
364- except (httpx .RequestError , json .JSONDecodeError ):
374+ except (httpx .RequestError , json .JSONDecodeError ) as exc :
375+ logger .debug (
376+ "[diffusion-router] video support probe failed: worker=%s error=%s" ,
377+ worker_url ,
378+ exc ,
379+ )
365380 return None
366381
367382 async def refresh_worker_video_support (self , worker_url : str ) -> None :
@@ -373,8 +388,23 @@ async def refresh_worker_video_support(self, worker_url: str) -> None:
373388 async def _broadcast_to_workers (
374389 self , path : str , body : bytes , headers : dict
375390 ) -> list [dict ]:
376- """Send a request to all healthy workers and collect results."""
377- urls = [u for u in self .worker_request_counts if u not in self .dead_workers ]
391+ """
392+ Broadcast request to eligible workers.
393+
394+ Rules:
395+ - For resume_memory_occupation (wake): target sleeping workers (even if currently marked dead).
396+ - For all other requests: target only active workers (exclude dead AND sleeping).
397+ """
398+ if path == "resume_memory_occupation" :
399+ # Wake is a recovery point: allow waking workers that were marked dead during sleep.
400+ urls = [u for u in self .sleeping_workers if u in self .worker_request_counts ]
401+ else :
402+ urls = [
403+ u
404+ for u in self .worker_request_counts
405+ if u not in self .dead_workers and u not in self .sleeping_workers
406+ ]
407+
378408 if not urls :
379409 return []
380410
@@ -494,6 +524,7 @@ def _build_worker_payload(self, worker_url: str) -> dict:
494524 "url" : worker_url ,
495525 "active_requests" : self .worker_request_counts .get (worker_url , 0 ),
496526 "is_dead" : worker_url in self .dead_workers ,
527+ "is_sleeping" : worker_url in self .sleeping_workers ,
497528 "consecutive_failures" : self .worker_failure_counts .get (worker_url , 0 ),
498529 "video_support" : self .worker_video_support .get (worker_url ),
499530 }
@@ -554,6 +585,7 @@ def _video_capable_workers(self) -> list[str]:
554585 if support
555586 and worker_url in self .worker_request_counts
556587 and worker_url not in self .dead_workers
588+ and worker_url not in self .sleeping_workers
557589 ]
558590
559591 @staticmethod
@@ -607,7 +639,7 @@ async def get_models(self, request: Request):
607639 worker_urls = [
608640 url
609641 for url in self .worker_request_counts .keys ()
610- if url not in self .dead_workers
642+ if url not in self .dead_workers and url not in self . sleeping_workers
611643 ]
612644 if not worker_urls :
613645 return JSONResponse (
@@ -684,7 +716,8 @@ async def health(self, request: Request):
684716 """Aggregated health status: healthy if at least one worker is alive."""
685717 total = len (self .worker_request_counts )
686718 dead = len (self .dead_workers )
687- healthy = total - dead
719+ sleeping = len (self .sleeping_workers )
720+ healthy = total - dead - sleeping
688721 status = "healthy" if healthy > 0 else "unhealthy"
689722 code = 200 if healthy > 0 else 503
690723 return JSONResponse (
@@ -693,6 +726,8 @@ async def health(self, request: Request):
693726 "status" : status ,
694727 "healthy_workers" : healthy ,
695728 "total_workers" : total ,
729+ "dead_workers" : dead ,
730+ "sleeping_workers" : sleeping ,
696731 },
697732 )
698733
@@ -714,6 +749,48 @@ async def update_weights_from_disk(self, request: Request):
714749 )
715750 return JSONResponse (content = {"results" : results })
716751
async def _broadcast_to_pool(self, request: Request, path: str) -> tuple[int, dict]:
    """Broadcast an incoming request to the worker pool and summarise the outcome.

    The request body and headers are read once and handed, together with
    ``path``, to ``self._broadcast_to_workers``.

    Args:
        request: Incoming HTTP request whose body and headers are forwarded.
        path: Worker endpoint path passed through to ``_broadcast_to_workers``.

    Returns:
        A ``(status_code, payload)`` pair:
        ``(503, {"error": ...})`` when no worker is registered or the
        broadcast produced no results, otherwise ``(200, {"results": ...})``.
    """
    if self.worker_request_counts:
        raw_body = await request.body()
        forwarded_headers = dict(request.headers)
        outcomes = await self._broadcast_to_workers(path, raw_body, forwarded_headers)
        if outcomes:
            return 200, {"results": outcomes}
        # Workers are registered, but none was eligible for this broadcast.
        return 503, {"error": "No eligible workers available in the pool"}

    return 503, {"error": "No workers registered in the pool"}
764+
async def release_memory_occupation(self, request: Request):
    """Broadcast ``release_memory_occupation`` and record which workers went to sleep.

    The request is fanned out through ``self._broadcast_to_pool``. Every
    worker whose result reports status code 200 is added to
    ``self.sleeping_workers``.

    Args:
        request: Incoming HTTP request; its body and headers are forwarded.

    Returns:
        A ``JSONResponse`` carrying the per-worker ``results`` on success, or
        the error status and payload produced by ``_broadcast_to_pool``.
    """
    status, payload = await self._broadcast_to_pool(
        request, "release_memory_occupation"
    )
    if status != 200:
        return JSONResponse(status_code=status, content=payload)

    for item in payload["results"]:
        if item.get("status_code") != 200:
            continue
        url = item.get("worker_url")
        # The broadcast awaited network I/O, so the worker may have been
        # deregistered in the meantime. Marking an unknown URL as sleeping
        # would leave a stale entry that nothing ever removes.
        if url in self.worker_request_counts:
            self.sleeping_workers.add(url)

    return JSONResponse(content=payload)
777+
async def resume_memory_occupation(self, request: Request):
    """Broadcast ``resume_memory_occupation`` and reactivate workers that woke up.

    The request is fanned out through ``self._broadcast_to_pool``. Every
    worker whose result reports status code 200 is removed from
    ``self.sleeping_workers`` and ``self.dead_workers``, and its failure
    counter is reset to zero.

    Args:
        request: Incoming HTTP request; its body and headers are forwarded.

    Returns:
        A ``JSONResponse`` carrying the per-worker ``results`` on success, or
        the error status and payload produced by ``_broadcast_to_pool``.
    """
    status, payload = await self._broadcast_to_pool(
        request, "resume_memory_occupation"
    )
    if status != 200:
        return JSONResponse(status_code=status, content=payload)

    for item in payload["results"]:
        if item.get("status_code") != 200:
            continue
        url = item.get("worker_url")
        # The broadcast awaited network I/O, so the worker may have been
        # deregistered in the meantime. Skip it so no failure-count entry
        # is recreated for a URL that is no longer in the pool.
        if url not in self.worker_request_counts:
            continue
        self.sleeping_workers.discard(url)
        self.worker_failure_counts[url] = 0
        self.dead_workers.discard(url)  # wake success => recover

    return JSONResponse(content=payload)
793+
717794 def register_worker (self , url : str ) -> None :
718795 """Register a worker URL if not already known."""
719796 normalized_url = self .normalize_worker_url (url )
@@ -730,6 +807,7 @@ def deregister_worker(self, url: str) -> None:
730807 self .worker_failure_counts .pop (normalized_url , None )
731808 self .worker_video_support .pop (normalized_url , None )
732809 self .dead_workers .discard (normalized_url )
810+ self .sleeping_workers .discard (normalized_url )
733811 stale_video_ids = [
734812 video_id
735813 for video_id , mapped_worker in self .video_job_to_worker .items ()
0 commit comments