Skip to content

Commit f3e458e

Browse files
authored
[None][fix] Consolidate aiohttp session management in disagg router (#13408)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 2f745de commit f3e458e

11 files changed

Lines changed: 524 additions & 80 deletions

File tree

tensorrt_llm/serve/openai_server.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -882,9 +882,9 @@ async def get_perf_metrics(self) -> JSONResponse:
882882
async def get_kv_cache_events(self) -> JSONResponse:
883883
events = []
884884
try:
885-
async for event in self.generator.get_kv_cache_events_async(2):
885+
async for event in self.generator.get_kv_cache_events_async(0):
886886
events.append(event)
887-
except IndexError:
887+
except (IndexError, asyncio.QueueEmpty):
888888
# queue is empty, no more events
889889
pass
890890
return JSONResponse(content=events)

tensorrt_llm/serve/router.py

Lines changed: 80 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,25 @@ def get_request_num_tokens(request: OpenAIRequest) -> int:
4848

4949
class ServerState:
5050

51-
def __init__(self, server: str, use_tokens: bool = False):
51+
def __init__(
52+
self,
53+
server: str,
54+
use_tokens: bool = False,
55+
session_provider: Optional[Callable[[],
56+
aiohttp.ClientSession]] = None):
5257
self._server = server
58+
self._base_url = server if server.startswith(
59+
"http") else f"http://{server}"
5360
self._num_active_requests = 0
5461
self._num_active_tokens = 0
5562
self._use_tokens = use_tokens
63+
self._session_provider = session_provider
5664
self._lock = asyncio.Lock()
5765

66+
@property
67+
def _session(self) -> Optional[aiohttp.ClientSession]:
68+
return self._session_provider() if self._session_provider else None
69+
5870
async def increment_load(self, request: OpenAIRequest):
5971
num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
6072
async with self._lock:
@@ -69,19 +81,23 @@ async def decrement_load(self, request: OpenAIRequest):
6981

7082
async def is_healthy(self) -> bool:
7183
try:
72-
async with self._session.get(self._server + "/health") as response:
84+
async with self._session.get(
85+
f"{self._base_url}/health") as response:
7386
return response.status == 200
7487
except Exception:
7588
return False
7689

7790

7891
class KvCacheAwareServerState(ServerState):
7992

80-
def __init__(self,
81-
server: str,
82-
use_tokens: bool = False,
83-
tokens_per_block: int = 32):
84-
super().__init__(server, use_tokens)
93+
def __init__(
94+
self,
95+
server: str,
96+
use_tokens: bool = False,
97+
tokens_per_block: int = 32,
98+
session_provider: Optional[Callable[[],
99+
aiohttp.ClientSession]] = None):
100+
super().__init__(server, use_tokens, session_provider)
85101
self._kv_cache_block_table: set[int] = set()
86102
self._tokens_per_block = tokens_per_block
87103

@@ -108,7 +124,8 @@ def update_with_events(self, events: Iterable[dict]):
108124
self.remove_blocks(event["block_hashes"])
109125

110126
async def poll_events(self, session: aiohttp.ClientSession):
111-
async with session.post(self._server + "/kv_cache_events") as response:
127+
async with session.post(
128+
f"{self._base_url}/kv_cache_events") as response:
112129
events_raw = await response.json()
113130
return events_raw
114131

@@ -124,19 +141,23 @@ async def matched_tokens(self, block_hashes: list[list[int]]) -> int:
124141
break
125142
return match_count
126143

127-
async def decrement_load(self,
128-
request: OpenAIRequest,
129-
session: Optional[aiohttp.ClientSession] = None):
144+
async def decrement_load(self, request: OpenAIRequest):
130145
num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
131-
if session is not None:
132-
events_raw = await self.poll_events(session)
133-
else:
134-
events_raw = None
135146
async with self._lock:
136147
self._num_active_requests -= 1
137148
self._num_active_tokens -= num_tokens
138-
if events_raw is not None:
139-
self.update_with_events(events_raw)
149+
150+
async def poll_and_update(self):
151+
"""Poll KV cache events and update block table. Called outside the critical path."""
152+
try:
153+
assert self._session is not None, "session must be set on KvCacheAwareServerState"
154+
events_raw = await self.poll_events(self._session)
155+
async with self._lock:
156+
if events_raw is not None:
157+
self.update_with_events(events_raw)
158+
except Exception as e:
159+
logger.warning(
160+
f"Failed to poll KV cache events from {self._server}: {e}")
140161

141162
def num_active_tokens(self):
142163
return self._num_active_tokens
@@ -165,7 +186,8 @@ def _init_load_balancing(self,
165186
self._server_state[server] = self._create_server_state(server)
166187

167188
def _create_server_state(self, server: str) -> ServerState:
168-
return self._server_state_class(server, self._use_tokens)
189+
return self._server_state_class(server, self._use_tokens,
190+
lambda: self.session)
169191

170192
def _get_server_load(self, server: str) -> int:
171193
state = self._server_state[server]
@@ -185,11 +207,12 @@ async def _register_request(self, server: str, request: OpenAIRequest):
185207
await self._server_state[server].increment_load(request)
186208
self._req_routing_table[id(request)] = server
187209

188-
async def _unregister_request(self, request: OpenAIRequest,
189-
**kwargs) -> str:
190-
server = self._req_routing_table.pop(id(request))
210+
async def _unregister_request(self, request: OpenAIRequest) -> str:
211+
server = self._req_routing_table.pop(id(request), None)
212+
if server is None:
213+
return ""
191214
if server in self._server_state:
192-
await self._server_state[server].decrement_load(request, **kwargs)
215+
await self._server_state[server].decrement_load(request)
193216
return server
194217

195218
def _select_least_loaded(self,
@@ -231,6 +254,17 @@ def __init__(
231254
self._server_preparation_func = server_preparation_func
232255
self._prepared_ready_servers: set[str] = set()
233256

257+
async def close(self):
258+
"""Close the shared HTTP session."""
259+
if self._session:
260+
try:
261+
await self._session.close()
262+
self._session = None
263+
logger.debug("HTTP session closed")
264+
except Exception as e:
265+
logger.error(f"Error closing session: {e}")
266+
self._session = None
267+
234268
@abstractmethod
235269
def _on_servers_updated(self, old_servers, new_servers):
236270
"""Called when the server list changes.
@@ -247,19 +281,21 @@ def _on_servers_updated(self, old_servers, new_servers):
247281
def servers(self) -> List[str]:
248282
return self._servers
249283

284+
@staticmethod
285+
def _ensure_url(server: str) -> str:
286+
return server if server.startswith("http") else f"http://{server}"
287+
250288
async def _fetch_server_info(self, server: str, timeout: float) -> dict:
251-
session = aiohttp.ClientSession()
252289
try:
253-
async with session.get(f"http://{server}/server_info",
254-
timeout=timeout) as response:
290+
url = self._ensure_url(server)
291+
async with self.session.get(f"{url}/server_info",
292+
timeout=timeout) as response:
255293
return await response.json()
256294
except Exception as e:
257295
logger.warning(
258296
f"Error fetching server info for server {server}: {e}")
259297
raise RuntimeError(
260298
f"Failed to fetch server info for server {server}") from e
261-
finally:
262-
await session.close()
263299

264300
async def _prepare_server(self, server: str):
265301
if server in self._prepared_ready_servers:
@@ -322,15 +358,17 @@ async def get_next_server(
322358
async def finish_request(self, request: OpenAIRequest):
323359
pass
324360

361+
@property
362+
def session(self) -> aiohttp.ClientSession:
363+
if not self._session:
364+
self._session = aiohttp.ClientSession()
365+
return self._session
366+
325367
async def start_server_monitoring(self, poll_interval: float = 10.0):
326368
"""Start monitoring servers update from metadata service"""
327369
if not self._metadata_server:
328370
raise RuntimeError("Metadata server is not initialized")
329371

330-
# Create a session for health checks if it doesn't exist
331-
if not self._session:
332-
self._session = aiohttp.ClientSession()
333-
334372
logger.info(
335373
f"Starting server monitoring for {self._server_role} servers")
336374
self._monitor_task = asyncio.create_task(
@@ -348,18 +386,7 @@ async def stop_server_monitoring(self):
348386
pass
349387
self._monitor_task = None
350388

351-
# Close session when stopping monitoring
352-
await self.close_session()
353-
354-
async def close_session(self):
355-
if self._session:
356-
try:
357-
await self._session.close()
358-
self._session = None
359-
logger.debug("HTTP session closed")
360-
except Exception as e:
361-
logger.error(f"Error closing session: {e}")
362-
self._session = None
389+
await self.close()
363390

364391
async def _monitor_servers(self, poll_interval: float = 10.0):
365392
while True:
@@ -515,12 +542,9 @@ async def check_servers_health(self,
515542

516543
async def _check_server_health(self, server_url) -> bool:
517544
"""Check if a server is healthy by querying its health endpoint"""
518-
if not self._session:
519-
self._session = aiohttp.ClientSession()
520-
521545
assert self._health_check_timeout is not None, "health_check_timeout is not set"
522546
try:
523-
async with self._session.get(
547+
async with self.session.get(
524548
f"{server_url}/health",
525549
timeout=self._health_check_timeout) as response:
526550
if response.status != 200:
@@ -744,9 +768,10 @@ def __init__(self,
744768
# TODO: use max_num_tokens? per server?
745769
self._max_batch_size = max_batch_size
746770

747-
def _create_server_state(self, server):
771+
def _create_server_state(self, server: str) -> KvCacheAwareServerState:
748772
return KvCacheAwareServerState(server, self._use_tokens,
749-
self._tokens_per_block)
773+
self._tokens_per_block,
774+
lambda: self.session)
750775

751776
async def get_next_server(
752777
self,
@@ -792,11 +817,13 @@ async def get_next_server(
792817
"server_info": self._server_info.get(server, {}),
793818
}
794819

795-
async def finish_request(self,
796-
request: OpenAIRequest,
797-
session: Optional[aiohttp.ClientSession] = None):
820+
async def finish_request(self, request: OpenAIRequest):
798821
async with self._lock:
799-
await self._unregister_request(request, session=session)
822+
server = self._req_routing_table.pop(id(request), None)
823+
if server is not None and server in self._server_state:
824+
await self._server_state[server].decrement_load(request)
825+
if server is not None and server in self._server_state:
826+
await self._server_state[server].poll_and_update()
800827

801828
def _on_servers_updated(self, old_servers, new_servers):
802829
new_state = {}
Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
hostname: localhost
2+
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
3+
backend: pytorch
4+
cuda_graph_config: null
5+
disable_overlap_scheduler: true
6+
enable_autotuner: false
7+
perf_metrics_max_requests: 1000
8+
context_servers:
9+
num_instances: 2
10+
tensor_parallel_size: 1
11+
pipeline_parallel_size: 1
12+
router:
13+
type: conversation
14+
return_perf_metrics: true
15+
perf_metrics_max_requests: 1000
16+
kv_cache_config:
17+
enable_block_reuse: true
18+
event_buffer_max_size: 1024
19+
free_gpu_memory_fraction: 0.1
20+
cache_transceiver_config:
21+
backend: DEFAULT
22+
generation_servers:
23+
num_instances: 1
24+
tensor_parallel_size: 1
25+
pipeline_parallel_size: 1
26+
return_perf_metrics: true
27+
perf_metrics_max_requests: 1000
28+
kv_cache_config:
29+
free_gpu_memory_fraction: 0.1
30+
cache_transceiver_config:
31+
backend: DEFAULT
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
2+
hostname: localhost
3+
backend: pytorch
4+
cuda_graph_config: null
5+
free_gpu_memory_fraction: 0.1
6+
disable_overlap_scheduler: true
7+
enable_autotuner: false
8+
context_servers:
9+
num_instances: 2
10+
router:
11+
type: conversation
12+
max_batch_size: 16
13+
max_num_tokens: 3000
14+
max_seq_len: 4096
15+
tensor_parallel_size: 1
16+
pipeline_parallel_size: 1
17+
kv_cache_config:
18+
enable_block_reuse: true
19+
enable_partial_reuse: false
20+
event_buffer_max_size: 1024
21+
free_gpu_memory_fraction: 0.1
22+
cache_transceiver_config:
23+
backend: DEFAULT
24+
generation_servers:
25+
num_instances: 1
26+
max_batch_size: 256
27+
max_num_tokens: 4096
28+
max_seq_len: 4096
29+
tensor_parallel_size: 1
30+
pipeline_parallel_size: 1
31+
cache_transceiver_config:
32+
backend: DEFAULT
33+
kv_cache_config:
34+
free_gpu_memory_fraction: 0.1
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
hostname: localhost
2+
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
3+
backend: pytorch
4+
cuda_graph_config: null
5+
disable_overlap_scheduler: true
6+
enable_autotuner: false
7+
perf_metrics_max_requests: 1000
8+
context_servers:
9+
num_instances: 2
10+
tensor_parallel_size: 1
11+
pipeline_parallel_size: 1
12+
router:
13+
type: kv_cache_aware
14+
return_perf_metrics: true
15+
perf_metrics_max_requests: 1000
16+
kv_cache_config:
17+
enable_block_reuse: true
18+
event_buffer_max_size: 1024
19+
free_gpu_memory_fraction: 0.1
20+
cache_transceiver_config:
21+
backend: DEFAULT
22+
generation_servers:
23+
num_instances: 1
24+
tensor_parallel_size: 1
25+
pipeline_parallel_size: 1
26+
return_perf_metrics: true
27+
perf_metrics_max_requests: 1000
28+
kv_cache_config:
29+
free_gpu_memory_fraction: 0.1
30+
cache_transceiver_config:
31+
backend: DEFAULT
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
hostname: localhost
2+
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
3+
backend: pytorch
4+
cuda_graph_config: null
5+
disable_overlap_scheduler: true
6+
enable_autotuner: false
7+
perf_metrics_max_requests: 1000
8+
context_servers:
9+
num_instances: 2
10+
tensor_parallel_size: 1
11+
pipeline_parallel_size: 1
12+
router:
13+
type: load_balancing
14+
return_perf_metrics: true
15+
perf_metrics_max_requests: 1000
16+
kv_cache_config:
17+
free_gpu_memory_fraction: 0.1
18+
cache_transceiver_config:
19+
backend: DEFAULT
20+
generation_servers:
21+
num_instances: 1
22+
tensor_parallel_size: 1
23+
pipeline_parallel_size: 1
24+
return_perf_metrics: true
25+
perf_metrics_max_requests: 1000
26+
kv_cache_config:
27+
free_gpu_memory_fraction: 0.1
28+
cache_transceiver_config:
29+
backend: DEFAULT

0 commit comments

Comments
 (0)