
Commit a5fa521

Share aiohttp.ClientSessions per worker
Slightly refactor `openAIModelServerClient` to accept a custom `aiohttp.ClientSession` per request, which allows us to use exactly one client session per worker. Prior to this commit, a new `aiohttp.ClientSession` was created for each request. Not only is this inefficient (it lowers throughput); in certain environments it also leads to inotify watch issues:

    aiodns - WARNING - Failed to create DNS resolver channel with automatic
    monitoring of resolver configuration changes. This usually means the
    system ran out of inotify watches. Falling back to socket state callback.
    Consider increasing the system inotify watch limit: Failed to initialize
    c-ares channel

Indeed, because a new DNS resolver is created for each `ClientSession`, creating tons of new `ClientSession`s eventually exhausts the inotify watch limit. Sharing `ClientSession`s solves this issue.

Relevant links:

- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: kubernetes-sigs#247 (doesn't address the issue of worker sharing).
1 parent 651d176 commit a5fa521
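For context, here is a minimal sketch (not part of this commit) of the aiohttp usage pattern the message and the linked request-lifecycle docs recommend: reuse one `ClientSession` for many requests, so the connection pool and the DNS resolver channel are created once per worker rather than once per request. The URL list and `fetch` helper are hypothetical.

```python
import asyncio

import aiohttp

URLS = ["https://example.com/"] * 10  # hypothetical targets

async def fetch(session: aiohttp.ClientSession, url: str) -> int:
    # Every call reuses the session's pooled connections and DNS resolver.
    async with session.get(url) as response:
        await response.read()
        return response.status

async def main() -> None:
    # One ClientSession per worker/event loop, closed exactly once. Creating
    # a session here instead of inside fetch() avoids spawning a connector
    # and c-ares resolver channel (an inotify watch) per request.
    async with aiohttp.ClientSession() as session:
        statuses = await asyncio.gather(*(fetch(session, u) for u in URLS))
        print(statuses)

asyncio.run(main())
```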

File tree: 2 files changed (+67, -43 lines)

inference_perf/client/modelserver/openai_client.py

Lines changed: 64 additions & 42 deletions
@@ -21,6 +21,7 @@
 from typing import List, Optional
 import aiohttp
 import asyncio
+import copy
 import json
 import time
 import logging
@@ -29,7 +30,17 @@
 logger = logging.getLogger(__name__)
 
 
+class openAIHTTPClientSession(aiohttp.ClientSession):
+    def __init__(self, timeout: float | None, max_tcp_connections: int) -> None:
+        super().__init__(
+            timeout=aiohttp.ClientTimeout(total=timeout) if timeout else aiohttp.helpers.sentinel,
+            connector=aiohttp.TCPConnector(limit=max_tcp_connections),
+        )
+
+
 class openAIModelServerClient(ModelServerClient):
+    _session: aiohttp.ClientSession | None = None
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,7 +81,19 @@ def __init__(
             tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
         self.tokenizer = CustomTokenizer(tokenizer_config)
 
-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+    def new_session(self) -> openAIHTTPClientSession:
+        return openAIHTTPClientSession(timeout=self.timeout, max_tcp_connections=self.max_tcp_connections)
+
+    async def process_request(
+        self,
+        data: InferenceAPIData,
+        stage_id: int,
+        scheduled_time: float,
+        session: openAIHTTPClientSession | None = None,
+    ) -> None:
+        custom_session = session is not None
+        session = session or self.new_session()
+
         payload = data.to_payload(
             model_name=self.model_name,
             max_tokens=self.max_completion_tokens,
@@ -86,57 +109,56 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
         headers.update(self.api_config.headers)
 
         request_data = json.dumps(payload)
+        start = time.perf_counter()
 
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
+        try:
+            async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
+                response_info = await data.process_response(
+                    response=response, config=self.api_config, tokenizer=self.tokenizer
+                )
+                response_content = await response.text()
+
+            end_time = time.perf_counter()
+            error = None
+            if response.status != 200:
+                error = ErrorResponseInfo(error_msg=response_content, error_type="Error response")
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
-            start = time.perf_counter()
-            try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
-                    response_info = await data.process_response(
-                        response=response, config=self.api_config, tokenizer=self.tokenizer
-                    )
-                    response_content = await response.text()
-
-                end_time = time.perf_counter()
-                error = None
-                if response.status != 200:
-                    error = ErrorResponseInfo(error_msg=response_content, error_type="Error response")
-
-                self.metrics_collector.record_metric(
-                    RequestLifecycleMetric(
-                        stage_id=stage_id,
-                        request_data=request_data,
-                        response_data=response_content,
-                        info=response_info,
-                        error=error,
-                        start_time=start,
-                        end_time=end_time,
-                        scheduled_time=scheduled_time,
-                    )
-                )
-            except Exception as e:
-                if isinstance(e, asyncio.exceptions.TimeoutError):
-                    logger.error("request timed out:", exc_info=True)
-                else:
-                    logger.error("error occured during request processing:", exc_info=True)
             self.metrics_collector.record_metric(
                 RequestLifecycleMetric(
                     stage_id=stage_id,
                     request_data=request_data,
-                    response_data=response_content if "response_content" in locals() else "",
-                    info=response_info if "response_info" in locals() else InferenceInfo(),
-                    error=ErrorResponseInfo(
-                        error_msg=str(e),
-                        error_type=type(e).__name__,
-                    ),
+                    response_data=response_content,
+                    info=response_info,
+                    error=error,
                     start_time=start,
-                    end_time=time.perf_counter(),
+                    end_time=end_time,
                     scheduled_time=scheduled_time,
                 )
             )
+        except Exception as e:
+            if isinstance(e, asyncio.exceptions.TimeoutError):
+                logger.error("request timed out:", exc_info=True)
+            else:
+                logger.error("error occured during request processing:", exc_info=True)
+            self.metrics_collector.record_metric(
+                RequestLifecycleMetric(
+                    stage_id=stage_id,
+                    request_data=request_data,
+                    response_data=response_content if "response_content" in locals() else "",
+                    info=response_info if "response_info" in locals() else InferenceInfo(),
+                    error=ErrorResponseInfo(
+                        error_msg=str(e),
+                        error_type=type(e).__name__,
+                    ),
+                    start_time=start,
+                    end_time=time.perf_counter(),
+                    scheduled_time=scheduled_time,
+                )
+            )
+        finally:
+            # close our session if it wasn't a shared one.
+            if not custom_session:
+                await session.close()
 
     def get_supported_apis(self) -> List[APIType]:
         return []
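A hypothetical driver showing how the refactored API above might be used: the caller creates one `openAIHTTPClientSession` via `new_session()`, threads it through every `process_request()` call, and closes it at the end. `client` and `requests` are assumed inputs, not defined in this commit.

```python
from typing import Iterable

# Usage sketch built on the API in the diff above. `client` is an
# already-constructed openAIModelServerClient and `requests` is any
# iterable of InferenceAPIData; both are assumptions for illustration.
async def run_requests(client, requests: Iterable, stage_id: int = 0) -> None:
    session = client.new_session()  # one openAIHTTPClientSession for all calls
    try:
        for i, request in enumerate(requests):
            # Passing session= makes custom_session True inside
            # process_request, so the client will not close it for us.
            await client.process_request(
                request, stage_id, scheduled_time=float(i), session=session
            )
    finally:
        await session.close()  # the caller closes the session it created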

inference_perf/loadgen/load_generator.py

Lines changed: 3 additions & 1 deletion
@@ -77,6 +77,7 @@ def __init__(
     async def loop(self) -> None:
         semaphore = Semaphore(self.max_concurrency)
         tasks = []
+        session = self.client.new_session()
         event_loop = get_event_loop()
         item = None
         timeout = 0.5
@@ -118,7 +119,7 @@ async def schedule_client(
                 with self.active_requests_counter.get_lock():
                     self.active_requests_counter.value += 1
                     inflight = True
-                await self.client.process_request(request_data, stage_id, request_time)
+                await self.client.process_request(request_data, stage_id, request_time, session=session)
             except CancelledError:
                 pass
             finally:
@@ -149,6 +150,7 @@ async def schedule_client(
             logger.debug(f"[Worker {self.id}] waiting for next phase")
             self.request_phase.wait()
 
+        await session.close()
         logger.debug(f"[Worker {self.id}] stopped")
 
     def run(self) -> None:
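Condensed into a standalone sketch (hypothetical, not the repo's worker class), the per-worker lifecycle the load generator now follows looks like this: create the session once when the worker loop starts, reuse it for every request, and close it exactly once on shutdown.

```python
import asyncio
from typing import Optional

import aiohttp

async def worker(queue: "asyncio.Queue[Optional[str]]") -> None:
    session = aiohttp.ClientSession()  # exactly one ClientSession per worker
    try:
        while True:
            url = await queue.get()
            if url is None:  # hypothetical shutdown sentinel
                break
            async with session.get(url) as response:
                await response.read()
    finally:
        # one close per worker, mirroring the `await session.close()` above
        await session.close()
```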
