@@ -48,13 +48,25 @@ def get_request_num_tokens(request: OpenAIRequest) -> int:
4848
4949class ServerState :
5050
def __init__(
        self,
        server: str,
        use_tokens: bool = False,
        session_provider: Optional[Callable[[],
                                            aiohttp.ClientSession]] = None):
    """Initialise load counters and the base URL for one server.

    Args:
        server: Server address, with or without an ``http://`` /
            ``https://`` scheme.
        use_tokens: When true, load accounting also tracks token counts.
        session_provider: Optional zero-argument callable returning the
            HTTP session to use; it is stored and invoked lazily.
    """
    self._server = server
    # Match the full scheme prefix: a bare ``startswith("http")`` would
    # treat a scheme-less host such as "httpd-node:8000" as a complete URL.
    has_scheme = server.startswith(("http://", "https://"))
    self._base_url = server if has_scheme else f"http://{server}"
    self._num_active_requests = 0
    self._num_active_tokens = 0
    self._use_tokens = use_tokens
    self._session_provider = session_provider
    self._lock = asyncio.Lock()
5765
@property
def _session(self) -> Optional[aiohttp.ClientSession]:
    """Return the session produced by the provider, or None without one."""
    provider = self._session_provider
    if not provider:
        return None
    # The provider is invoked on every access, so the caller always sees
    # whatever session the provider currently hands out.
    return provider()
69+
5870 async def increment_load (self , request : OpenAIRequest ):
5971 num_tokens = get_request_num_tokens (request ) if self ._use_tokens else 0
6072 async with self ._lock :
@@ -69,19 +81,23 @@ async def decrement_load(self, request: OpenAIRequest):
6981
async def is_healthy(self) -> bool:
    """Report whether ``GET <base_url>/health`` answers with status 200.

    Returns:
        True on a 200 response; False when no session is available, on
        any other status, or when the request raises.
    """
    try:
        # Read the session once; without one there is nothing to probe, so
        # report unhealthy explicitly rather than via a swallowed
        # AttributeError on None.
        session = self._session
        if session is None:
            return False
        async with session.get(f"{self._base_url}/health") as response:
            return response.status == 200
    except Exception:
        # Best-effort probe: any failure counts as unhealthy.
        return False
7689
7790
7891class KvCacheAwareServerState (ServerState ):
7992
def __init__(
        self,
        server: str,
        use_tokens: bool = False,
        tokens_per_block: int = 32,
        session_provider: Optional[Callable[[],
                                            aiohttp.ClientSession]] = None):
    """Set up base server state plus an empty KV-cache block table.

    Args:
        server: Server address, forwarded to the base initialiser.
        use_tokens: Forwarded to the base initialiser.
        tokens_per_block: Stored as ``_tokens_per_block``.
        session_provider: Forwarded to the base initialiser.
    """
    super().__init__(server, use_tokens, session_provider)
    self._tokens_per_block = tokens_per_block
    # Starts empty; holds integer block hashes.
    self._kv_cache_block_table: set[int] = set()
87103
@@ -108,7 +124,8 @@ def update_with_events(self, events: Iterable[dict]):
108124 self .remove_blocks (event ["block_hashes" ])
109125
async def poll_events(self, session: aiohttp.ClientSession):
    """POST to ``<base_url>/kv_cache_events`` and return the decoded JSON.

    Args:
        session: HTTP session used to issue the request.
    """
    events_url = f"{self._base_url}/kv_cache_events"
    # NOTE(review): the response status is not checked before decoding;
    # confirm whether error responses should be rejected here.
    async with session.post(events_url) as response:
        return await response.json()
114131
@@ -124,19 +141,23 @@ async def matched_tokens(self, block_hashes: list[list[int]]) -> int:
124141 break
125142 return match_count
126143
async def decrement_load(self, request: OpenAIRequest):
    """Remove one request (and its tokens, if tracked) from the load.

    Args:
        request: The finished request; its token count is only computed
            when token tracking is enabled.
    """
    token_delta = 0
    if self._use_tokens:
        token_delta = get_request_num_tokens(request)
    async with self._lock:
        self._num_active_requests -= 1
        self._num_active_tokens -= token_delta
149+
async def poll_and_update(self):
    """Poll KV cache events and update block table. Called outside the critical path.

    Failures (including a missing session) are logged as warnings and
    never propagate to the caller.
    """
    try:
        # Resolve the session once, and validate with a real exception
        # rather than ``assert``, which is stripped under ``python -O``.
        session = self._session
        if session is None:
            raise RuntimeError(
                "session must be set on KvCacheAwareServerState")
        events_raw = await self.poll_events(session)
        if events_raw is None:
            return
        # The lock is taken only for the table update, not for the HTTP call.
        async with self._lock:
            self.update_with_events(events_raw)
    except Exception as e:
        logger.warning(
            f"Failed to poll KV cache events from {self._server}: {e}")
140161
141162 def num_active_tokens (self ):
142163 return self ._num_active_tokens
@@ -165,7 +186,8 @@ def _init_load_balancing(self,
165186 self ._server_state [server ] = self ._create_server_state (server )
166187
def _create_server_state(self, server: str) -> ServerState:
    """Build the per-server state object for ``server``.

    The state receives a callable that reads ``self.session`` at call
    time rather than a session object.
    """

    def current_session():
        return self.session

    return self._server_state_class(server, self._use_tokens,
                                    current_session)
169191
170192 def _get_server_load (self , server : str ) -> int :
171193 state = self ._server_state [server ]
@@ -185,11 +207,12 @@ async def _register_request(self, server: str, request: OpenAIRequest):
185207 await self ._server_state [server ].increment_load (request )
186208 self ._req_routing_table [id (request )] = server
187209
async def _unregister_request(self, request: OpenAIRequest) -> str:
    """Drop ``request`` from the routing table and release its load.

    Returns:
        The server the request was routed to, or ``""`` when the request
        has no routing-table entry.
    """
    request_key = id(request)
    server = self._req_routing_table.pop(request_key, None)
    if server is None:
        return ""
    # Load is only released when state for that server is still present.
    if server in self._server_state:
        await self._server_state[server].decrement_load(request)
    return server
194217
195218 def _select_least_loaded (self ,
@@ -231,6 +254,17 @@ def __init__(
231254 self ._server_preparation_func = server_preparation_func
232255 self ._prepared_ready_servers : set [str ] = set ()
233256
async def close(self):
    """Close the shared HTTP session.

    Errors raised while closing are logged, not propagated.
    """
    # Detach the session before awaiting its close, so nothing else can
    # pick up a half-closed session via ``_session`` and the reference is
    # cleared even if the close is cancelled.
    session, self._session = self._session, None
    if not session:
        return
    try:
        await session.close()
        logger.debug("HTTP session closed")
    except Exception as e:
        logger.error(f"Error closing session: {e}")
267+
234268 @abstractmethod
235269 def _on_servers_updated (self , old_servers , new_servers ):
236270 """Called when the server list changes.
@@ -247,19 +281,21 @@ def _on_servers_updated(self, old_servers, new_servers):
247281 def servers (self ) -> List [str ]:
248282 return self ._servers
249283
@staticmethod
def _ensure_url(server: str) -> str:
    """Return ``server`` as a URL, prefixing ``http://`` when no scheme is given.

    Only a real ``http://`` or ``https://`` prefix counts as a scheme, so a
    host such as ``httpd-node:8000`` still gets prefixed.
    """
    if server.startswith(("http://", "https://")):
        return server
    return f"http://{server}"
287+
async def _fetch_server_info(self, server: str, timeout: float) -> dict:
    """Fetch and decode ``<server>/server_info`` over the shared session.

    Args:
        server: Server address; a scheme is added via ``_ensure_url``.
        timeout: Total request timeout in seconds.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: If the request or decoding fails; the original
            exception is chained as the cause.
    """
    try:
        url = self._ensure_url(server)
        # aiohttp expects a ClientTimeout object for per-request timeouts;
        # passing a bare number is deprecated.
        request_timeout = aiohttp.ClientTimeout(total=timeout)
        async with self.session.get(f"{url}/server_info",
                                    timeout=request_timeout) as response:
            return await response.json()
    except Exception as e:
        logger.warning(
            f"Error fetching server info for server {server}: {e}")
        raise RuntimeError(
            f"Failed to fetch server info for server {server}") from e
263299
264300 async def _prepare_server (self , server : str ):
265301 if server in self ._prepared_ready_servers :
@@ -322,15 +358,17 @@ async def get_next_server(
async def finish_request(self, request: OpenAIRequest):
    """Do nothing when a request finishes.

    Args:
        request: The finished request (unused).
    """
    return None
324360
@property
def session(self) -> aiohttp.ClientSession:
    """Return the shared HTTP session, creating it on first use.

    A session that has been closed but is still referenced is replaced, so
    callers never receive an unusable session.
    """
    current = self._session
    if current is None or current.closed:
        current = self._session = aiohttp.ClientSession()
    return current
366+
325367 async def start_server_monitoring (self , poll_interval : float = 10.0 ):
326368 """Start monitoring servers update from metadata service"""
327369 if not self ._metadata_server :
328370 raise RuntimeError ("Metadata server is not initialized" )
329371
330- # Create a session for health checks if it doesn't exist
331- if not self ._session :
332- self ._session = aiohttp .ClientSession ()
333-
334372 logger .info (
335373 f"Starting server monitoring for { self ._server_role } servers" )
336374 self ._monitor_task = asyncio .create_task (
@@ -348,18 +386,7 @@ async def stop_server_monitoring(self):
348386 pass
349387 self ._monitor_task = None
350388
351- # Close session when stopping monitoring
352- await self .close_session ()
353-
354- async def close_session (self ):
355- if self ._session :
356- try :
357- await self ._session .close ()
358- self ._session = None
359- logger .debug ("HTTP session closed" )
360- except Exception as e :
361- logger .error (f"Error closing session: { e } " )
362- self ._session = None
389+ await self .close ()
363390
364391 async def _monitor_servers (self , poll_interval : float = 10.0 ):
365392 while True :
@@ -515,12 +542,9 @@ async def check_servers_health(self,
515542
516543 async def _check_server_health (self , server_url ) -> bool :
517544 """Check if a server is healthy by querying its health endpoint"""
518- if not self ._session :
519- self ._session = aiohttp .ClientSession ()
520-
521545 assert self ._health_check_timeout is not None , "health_check_timeout is not set"
522546 try :
523- async with self ._session .get (
547+ async with self .session .get (
524548 f"{ server_url } /health" ,
525549 timeout = self ._health_check_timeout ) as response :
526550 if response .status != 200 :
@@ -744,9 +768,10 @@ def __init__(self,
744768 # TODO: use max_num_tokens? per server?
745769 self ._max_batch_size = max_batch_size
746770
def _create_server_state(self, server: str) -> KvCacheAwareServerState:
    """Build a KV-cache-aware state object for ``server``.

    The state receives a callable that reads ``self.session`` at call
    time rather than a session object.
    """

    def current_session():
        return self.session

    return KvCacheAwareServerState(server, self._use_tokens,
                                   self._tokens_per_block,
                                   current_session)
750775
751776 async def get_next_server (
752777 self ,
@@ -792,11 +817,13 @@ async def get_next_server(
792817 "server_info" : self ._server_info .get (server , {}),
793818 }
794819
async def finish_request(self, request: OpenAIRequest):
    """Release a finished request's load, then refresh KV cache events.

    Args:
        request: The finished request.
    """
    # Reuse the shared unregister helper instead of duplicating the
    # routing-table pop and load decrement inline.
    async with self._lock:
        server = await self._unregister_request(request)
    # Poll after releasing the router lock so the HTTP round trip does not
    # block other requests; skip when the request was unknown or its
    # server has no state.
    state = self._server_state.get(server) if server else None
    if state is not None:
        await state.poll_and_update()
800827
801828 def _on_servers_updated (self , old_servers , new_servers ):
802829 new_state = {}
0 commit comments