Skip to content

Commit ba8c516

Browse files
committed
Share aiohttp.ClientSessions per worker
Slightly refactor `openAIModelServerClient` to accept a custom `aiohttp.ClientSession` per request, which allows us to use exactly one client session per worker. Prior to this commit, a new `aiohttp.ClientSession` was created for each request. Not only is this inefficient and throughput-lowering; in certain environments it also leads to inotify watch issues: "aiodns - WARNING - Failed to create DNS resolver channel with automatic monitoring of resolver configuration changes. This usually means the system ran out of inotify watches. Falling back to socket state callback. Consider increasing the system inotify watch limit: Failed to initialize c-ares channel". Indeed, because a DNS resolver is created for each new `ClientSession`, creating many new `ClientSession`s eventually exhausts the available inotify watches. Sharing `ClientSession`s solves this issue. Relevant links: - https://docs.aiohttp.org/en/stable/http_request_lifecycle.html - https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread - home-assistant/core#144457 (comment) Relevant PR: kubernetes-sigs#247 (doesn't address the issue of worker sharing).
1 parent dab80ce commit ba8c516

File tree

3 files changed

+59
-13
lines changed

3 files changed

+59
-13
lines changed

inference_perf/client/modelserver/base.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from abc import ABC, abstractmethod
15-
from typing import List, Optional, Tuple
15+
from typing import List, Optional, Tuple, Any
1616
from inference_perf.client.metricsclient.base import MetricsMetadata
1717
from inference_perf.config import APIConfig, APIType
18-
1918
from inference_perf.apis import InferenceAPIData
19+
import aiohttp
2020

2121

2222
class ModelServerPrometheusMetric:
@@ -87,10 +87,35 @@ def get_supported_apis(self) -> List[APIType]:
8787
raise NotImplementedError
8888

8989
@abstractmethod
90-
async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
90+
async def process_request(
91+
self, data: InferenceAPIData, stage_id: int, scheduled_time: float, *args: Any, **kwargs: Any
92+
) -> None:
9193
raise NotImplementedError
9294

9395
@abstractmethod
9496
def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
9597
# assumption: all metrics clients have metrics exported in Prometheus format
9698
raise NotImplementedError
99+
100+
101+
class ReusableHTTPClientSession:
    """
    A wrapper for aiohttp.ClientSession that makes closing on context exit optional.

    This lets many HTTP clients share a single underlying session: callers that
    reuse a shared session enter the wrapper with ``dont_close=True`` so that
    ``__aexit__`` leaves the session open, while one-off sessions are closed
    normally.
    """

    def __init__(self, session: aiohttp.ClientSession, dont_close: bool = False) -> None:
        # The underlying aiohttp session, potentially shared across requests.
        self.session = session
        # One-shot flag: when True, the next __aexit__ skips closing the session.
        self.dont_close = dont_close

    def dont_close_if(self, dont_close: bool = True) -> "ReusableHTTPClientSession":
        """Return a new wrapper over the same session with the given close policy."""
        return ReusableHTTPClientSession(session=self.session, dont_close=dont_close)

    async def __aenter__(self) -> "ReusableHTTPClientSession":
        # Fix: return self (previously returned None), so `async with ... as s`
        # binds the wrapper instead of None. Callers that don't use `as` are
        # unaffected.
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
        if self.dont_close:
            # Consume the one-shot flag so a later exit would close the session.
            self.dont_close = False
            return
        await self.session.close()

inference_perf/client/modelserver/openai_client.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
1818
from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
1919
from inference_perf.utils import CustomTokenizer
20-
from .base import ModelServerClient, PrometheusMetricMetadata
20+
from .base import ModelServerClient, PrometheusMetricMetadata, ReusableHTTPClientSession
2121
from typing import List, Optional
2222
import aiohttp
2323
import asyncio
@@ -30,6 +30,8 @@
3030

3131

3232
class openAIModelServerClient(ModelServerClient):
33+
_session: aiohttp.ClientSession | None = None
34+
3335
def __init__(
3436
self,
3537
metrics_collector: RequestDataCollector,
@@ -70,7 +72,16 @@ def __init__(
7072
tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
7173
self.tokenizer = CustomTokenizer(tokenizer_config)
7274

73-
async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
75+
async def process_request(
76+
self,
77+
data: InferenceAPIData,
78+
stage_id: int,
79+
scheduled_time: float,
80+
session: Optional[ReusableHTTPClientSession] = None,
81+
) -> None:
82+
reusing_session = session is not None
83+
session = session or self.new_reusable_session()
84+
7485
payload = data.to_payload(
7586
model_name=self.model_name,
7687
max_tokens=self.max_completion_tokens,
@@ -87,14 +98,10 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
8798

8899
request_data = json.dumps(payload)
89100

90-
timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
91-
92-
async with aiohttp.ClientSession(
93-
connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
94-
) as session:
101+
async with session.dont_close_if(reusing_session):
95102
start = time.perf_counter()
96103
try:
97-
async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
104+
async with session.session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
98105
response_info = await data.process_response(
99106
response=response, config=self.api_config, tokenizer=self.tokenizer
100107
)
@@ -138,6 +145,14 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
138145
)
139146
)
140147

148+
def new_reusable_session(self) -> ReusableHTTPClientSession:
149+
return ReusableHTTPClientSession(
150+
aiohttp.ClientSession(
151+
timeout=aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel,
152+
connector=aiohttp.TCPConnector(limit=self.max_tcp_connections),
153+
)
154+
)
155+
141156
def get_supported_apis(self) -> List[APIType]:
    """Return the API types this client declares support for (currently none)."""
    supported: List[APIType] = []
    return supported
inference_perf/loadgen/load_generator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from inference_perf.datagen import DataGenerator
1919
from inference_perf.apis import InferenceAPIData
2020
from inference_perf.client.modelserver import ModelServerClient
21+
from inference_perf.client.modelserver.openai_client import openAIModelServerClient
2122
from inference_perf.circuit_breaker import get_circuit_breaker
2223
from inference_perf.config import LoadConfig, LoadStage, LoadType, StageGenType, TraceFormat
2324
from asyncio import (
@@ -31,7 +32,7 @@
3132
set_event_loop_policy,
3233
get_event_loop,
3334
)
34-
from typing import List, Tuple, TypeAlias, Optional
35+
from typing import List, Tuple, TypeAlias, Optional, Any
3536
from types import FrameType
3637
import time
3738
import multiprocessing as mp
@@ -83,6 +84,10 @@ async def loop(self) -> None:
8384
item = None
8485
timeout = 0.5
8586

87+
session: Any = None
88+
if issubclass(type(self.client), openAIModelServerClient):
89+
session = self.client.new_reusable_session()
90+
8691
while not self.stop_signal.is_set():
8792
while self.request_phase.is_set() and not self.cancel_signal.is_set():
8893
await semaphore.acquire()
@@ -120,7 +125,7 @@ async def schedule_client(
120125
with self.active_requests_counter.get_lock():
121126
self.active_requests_counter.value += 1
122127
inflight = True
123-
await self.client.process_request(request_data, stage_id, request_time)
128+
await self.client.process_request(request_data, stage_id, request_time, session=session)
124129
except CancelledError:
125130
pass
126131
finally:
@@ -151,6 +156,7 @@ async def schedule_client(
151156
logger.debug(f"[Worker {self.id}] waiting for next phase")
152157
self.request_phase.wait()
153158

159+
await session.close()
154160
logger.debug(f"[Worker {self.id}] stopped")
155161

156162
def run(self) -> None:

0 commit comments

Comments
 (0)