logger = logging.getLogger(__name__)

# Hostnames/addresses of cloud instance-metadata endpoints.
# NOTE(review): the code that consults this set is outside this view;
# presumably these hosts are rejected when validating worker URLs - confirm.
_METADATA_HOSTS = {"169.254.169.254", "metadata.google.internal"}
# Upper-cased ``task_type`` values (as reported by a worker's /v1/models
# endpoint) that mark the worker as image-only, i.e. not video-capable.
_IMAGE_TASK_TYPES = {"T2I", "I2I", "TI2I"}
1920
2021
2122class DiffusionRouter :
23+
2224 def __init__ (self , args , verbose : bool = False ):
2325 """Initialize the router for load-balancing sglang-diffusion workers."""
2426 self .args = args
@@ -32,6 +34,9 @@ def __init__(self, args, verbose: bool = False):
3234 self .worker_request_counts : dict [str , int ] = {}
3335 # URL -> consecutive health check failures
3436 self .worker_failure_counts : dict [str , int ] = {}
37+ # URL -> whether worker supports video generation
38+ # True: supports, False: image-only, None: unknown/unprobed
39+ self .worker_video_support : dict [str , bool | None ] = {}
3540 # quarantined workers excluded from routing
3641 self .dead_workers : set [str ] = set ()
3742 self ._health_task : asyncio .Task | None = None
@@ -139,14 +144,23 @@ async def _health_check_loop(self) -> None:
139144 )
140145 await asyncio .sleep (5 )
141146
142- def _use_url (self ) -> str :
143- """Select a worker URL based on the configured routing algorithm."""
147+ def _select_worker_by_routing (self , worker_urls : list [str ] | None = None ) -> str :
148+ """Select a worker URL based on routing algorithm and optional candidates.
149+
150+ Args:
151+ worker_urls: Optional list of worker URLs to consider. If provided,
152+ only these workers will be considered for selection. If not provided,
153+ all registered workers will be considered.
154+ """
144155 if not self .worker_request_counts :
145156 raise RuntimeError ("No workers registered in the pool" )
146157
147158 valid_workers = [
148159 w for w in self .worker_request_counts if w not in self .dead_workers
149160 ]
161+ if worker_urls is not None :
162+ allowed = {w for w in worker_urls if w in self .worker_request_counts }
163+ valid_workers = [w for w in valid_workers if w in allowed ]
150164 if not valid_workers :
151165 raise RuntimeError ("No healthy workers available in the pool" )
152166
@@ -202,13 +216,14 @@ def _build_proxy_response(
202216 media_type = content_type ,
203217 )
204218
205- async def _forward_to_worker (self , request : Request , path : str ) -> Response :
206- """Forward a request to a selected worker and return the response."""
219+ async def _forward_to_worker (
220+ self , request : Request , path : str , worker_urls : list [str ] | None = None
221+ ) -> Response :
222+ """Forward request to a selected worker (optionally from candidate URLs)."""
207223 try :
208- worker_url = self ._use_url ( )
224+ worker_url = self ._select_worker_by_routing ( worker_urls = worker_urls )
209225 except RuntimeError as exc :
210226 return JSONResponse (status_code = 503 , content = {"error" : str (exc )})
211-
212227 try :
213228 query = request .url .query
214229 url = (
@@ -243,6 +258,29 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response:
243258 finally :
244259 self ._finish_url (worker_url )
245260
async def _probe_worker_video_support(self, worker_url: str) -> bool | None:
    """Probe ``/v1/models`` on a worker and infer whether it can generate video.

    Only the first entry of the response's ``data`` list is inspected. The
    worker is treated as video-capable when that entry's ``task_type`` is a
    string that, upper-cased, is not one of ``_IMAGE_TASK_TYPES``.

    Args:
        worker_url: Base URL of the worker; ``/v1/models`` is appended to it.

    Returns:
        ``True`` if the worker reports a non-image task type, ``False`` if it
        reports an image-only task type, and ``None`` when the capability
        cannot be determined (request error, non-200 status, undecodable
        body, or a payload that does not have the expected shape).
    """
    try:
        response = await self.client.get(f"{worker_url}/v1/models", timeout=5.0)
        if response.status_code != 200:
            return None
        payload = response.json()
    except (httpx.RequestError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError,
        # so a malformed body degrades to "unknown" instead of propagating.
        return None

    # The probe is best-effort: any unexpected JSON shape means "unknown"
    # rather than an AttributeError bubbling up to the caller.
    data = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(data, list) or not data:
        return None
    first_model = data[0]
    task_type = first_model.get("task_type") if isinstance(first_model, dict) else None
    if not isinstance(task_type, str):
        return None
    return task_type.upper() not in _IMAGE_TASK_TYPES
277+
async def refresh_worker_video_support(self, worker_url: str) -> None:
    """Refresh the cached video capability for a single worker.

    The result is stored under the same key the worker is registered with.
    If ``worker_url`` is not itself a registered key, it is passed through
    ``normalize_worker_url`` first, because ``register_worker`` stores
    workers under their normalized URL. Caching under the raw URL would
    leave the registered entry unprobed and create a key that worker
    selection ignores.

    Args:
        worker_url: Worker URL, either already normalized or in raw form.

    Raises:
        ValueError: If ``worker_url`` is not registered and is rejected by
            ``normalize_worker_url``.
    """
    cache_key = worker_url
    if cache_key not in self.worker_request_counts:
        cache_key = self.normalize_worker_url(worker_url)
    supports_video = await self._probe_worker_video_support(cache_key)
    self.worker_video_support[cache_key] = supports_video
283+
246284 async def _broadcast_to_workers (
247285 self , path : str , body : bytes , headers : dict
248286 ) -> list [dict ]:
@@ -297,7 +335,7 @@ def _sanitize_response_headers(headers) -> dict:
297335 }
298336
299337 @staticmethod
300- def _normalize_worker_url (url : str ) -> str :
338+ def normalize_worker_url (url : str ) -> str :
301339 if not isinstance (url , str ):
302340 raise ValueError ("worker_url must be a string" )
303341
@@ -345,7 +383,22 @@ async def generate(self, request: Request):
345383
async def generate_video(self, request: Request):
    """Route video generation to /v1/videos on a video-capable worker.

    Candidates are the workers whose cached entry in
    ``worker_video_support`` is truthy; workers that are image-only or
    whose capability is still unknown (``None``) are never selected.

    Args:
        request: Incoming request to forward.

    Returns:
        A 400 JSON response when no worker is known to support video,
        otherwise whatever ``_forward_to_worker`` returns for the
        candidate set.
    """
    candidate_workers = [
        worker_url
        for worker_url, support in self.worker_video_support.items()
        if support
    ]

    if not candidate_workers:
        return JSONResponse(
            status_code=400,
            content={
                "error": "No video-capable workers available in current worker pool.",
            },
        )
    # The routing step further restricts candidates to registered,
    # non-dead workers before picking one.
    return await self._forward_to_worker(
        request, "v1/videos", worker_urls=candidate_workers
    )
349402
350403 async def health (self , request : Request ):
351404 """Aggregated health status: healthy if at least one worker is alive."""
@@ -388,10 +441,11 @@ async def update_weights_from_disk(self, request: Request):
388441
def register_worker(self, url: str) -> None:
    """Register a worker URL if not already known.

    The URL is normalized first and all per-worker bookkeeping is keyed by
    the normalized form. Re-registering a known worker is a no-op, so its
    request count, failure count and cached video capability are kept.

    Args:
        url: Worker URL to register.

    Raises:
        ValueError: If ``normalize_worker_url`` rejects ``url``.
    """
    normalized_url = self.normalize_worker_url(url)
    if normalized_url in self.worker_request_counts:
        return

    self.worker_request_counts[normalized_url] = 0
    self.worker_failure_counts[normalized_url] = 0
    # Capability is unknown until the worker has been probed.
    self.worker_video_support[normalized_url] = None
    if self.verbose:
        print(f"[diffusion-router] Added new worker: {normalized_url}")
397451
@@ -422,6 +476,7 @@ async def add_worker(self, request: Request):
422476 self .register_worker (worker_url )
423477 except ValueError as exc :
424478 return JSONResponse (status_code = 400 , content = {"error" : str (exc )})
479+ await self .refresh_worker_video_support (worker_url )
425480 return {
426481 "status" : "success" ,
427482 "worker_urls" : list (self .worker_request_counts .keys ()),
0 commit comments