
Commit bce18cc

Share aiohttp.ClientSessions per worker
Slightly refactor `openAIModelServerClient` to accept a custom `aiohttp.ClientSession` per request, which allows us to use exactly one client session per worker. Prior to this commit, a new `aiohttp.ClientSession` was created for each request. This is not only inefficient, lowering throughput; in certain environments it also leads to inotify watch issues:

```text
aiodns - WARNING - Failed to create DNS resolver channel with automatic monitoring of resolver configuration changes. This usually means the system ran out of inotify watches. Falling back to socket state callback. Consider increasing the system inotify watch limit: Failed to initialize c-ares channel
```

Because a fresh DNS resolver is created for every new `ClientSession`, creating large numbers of `ClientSession`s eventually exhausts the system's inotify watches. Sharing `ClientSession`s solves this issue.

Relevant links:

- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: kubernetes-sigs#247 (doesn't address the issue of worker sharing).
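For context, here is a minimal standalone sketch (not code from this commit) of the session-per-worker pattern the aiohttp lifecycle docs recommend: the session, and with it the connection pool and DNS resolver, is created once and shared by every request the worker issues.

```python
# Minimal sketch of the session-per-worker pattern (illustrative only,
# not code from this commit).
import asyncio

import aiohttp


async def worker(urls: list[str]) -> list[int]:
    # One ClientSession (and therefore one connector and one DNS
    # resolver) for the worker's whole lifetime.
    async with aiohttp.ClientSession() as session:

        async def fetch(url: str) -> int:
            # Each request borrows the shared session instead of
            # constructing and tearing down its own.
            async with session.get(url) as response:
                await response.read()
                return response.status

        # All requests reuse the session's connection pool.
        return list(await asyncio.gather(*(fetch(u) for u in urls)))


if __name__ == "__main__":
    print(asyncio.run(worker(["https://example.com"])))
```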
1 parent 651d176 · commit bce18cc

File tree

3 files changed: +60 −13 lines changed


inference_perf/client/modelserver/base.py

Lines changed: 29 additions & 3 deletions
```diff
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any
 from inference_perf.client.metricsclient.base import MetricsMetadata
 from inference_perf.config import APIConfig, APIType
-
 from inference_perf.apis import InferenceAPIData
+import aiohttp
+import copy
 
 
 class ModelServerPrometheusMetric:
@@ -87,10 +88,35 @@ def get_supported_apis(self) -> List[APIType]:
         raise NotImplementedError
 
     @abstractmethod
-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+    async def process_request(
+        self, data: InferenceAPIData, stage_id: int, scheduled_time: float, *args: Any, **kwargs: Any
+    ) -> None:
         raise NotImplementedError
 
     @abstractmethod
     def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
         # assumption: all metrics clients have metrics exported in Prometheus format
         raise NotImplementedError
+
+
+class ReusableHTTPClientSession:
+    """
+    A wrapper for aiohttp.ClientSession to allow for reusable sessions.
+    This is useful for sharing among many HTTP clients.
+    """
+
+    def __init__(self, session: aiohttp.ClientSession, dont_close: bool = False) -> None:
+        self.session = session
+        self.dont_close = dont_close
+
+    def dont_close_if(self, dont_close: bool = True) -> "ReusableHTTPClientSession":
+        return ReusableHTTPClientSession(session=self.session, dont_close=dont_close)
+
+    async def __aenter__(self) -> None:
+        pass
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
+        if self.dont_close:
+            self.dont_close = False
+            return
+        await self.session.close()
```
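To make the wrapper's semantics concrete, here is a hypothetical usage sketch (not part of the commit). `dont_close_if(True)` hands out a context manager whose `__aexit__` returns early, so a borrowed session outlives the `async with` block; `dont_close_if(False)` closes the underlying session on exit, like a plain session would.

```python
# Hypothetical usage sketch of ReusableHTTPClientSession (not part of
# the commit).
import asyncio

import aiohttp

from inference_perf.client.modelserver.base import ReusableHTTPClientSession


async def demo() -> None:
    shared = ReusableHTTPClientSession(aiohttp.ClientSession())

    # Borrowed: __aexit__ sees dont_close=True and returns without
    # closing, so the underlying session survives for the next request.
    async with shared.dont_close_if(True):
        pass  # a request would use shared.session here

    assert not shared.session.closed

    # One-off: dont_close=False, so __aexit__ closes the session.
    async with shared.dont_close_if(False):
        pass

    assert shared.session.closed


if __name__ == "__main__":
    asyncio.run(demo())
```

This early return is what lets `process_request` in `openai_client.py` treat caller-supplied and self-created sessions uniformly: both flow through the same `async with`, and only the self-created one is torn down.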

inference_perf/client/modelserver/openai_client.py

Lines changed: 23 additions & 8 deletions
```diff
@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, PrometheusMetricMetadata, ReusableHTTPClientSession
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,8 @@
 
 
 class openAIModelServerClient(ModelServerClient):
+    _session: aiohttp.ClientSession | None = None
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,7 +72,24 @@ def __init__(
             tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
         self.tokenizer = CustomTokenizer(tokenizer_config)
 
-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+    def new_reusable_session(self) -> ReusableHTTPClientSession:
+        return ReusableHTTPClientSession(
+            aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel,
+                connector=aiohttp.TCPConnector(limit=self.max_tcp_connections),
+            )
+        )
+
+    async def process_request(
+        self,
+        data: InferenceAPIData,
+        stage_id: int,
+        scheduled_time: float,
+        session: Optional[ReusableHTTPClientSession] = None,
+    ) -> None:
+        reusing_session = session is not None
+        session = session or self.new_reusable_session()
+
         payload = data.to_payload(
             model_name=self.model_name,
             max_tokens=self.max_completion_tokens,
@@ -87,14 +106,10 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
 
         request_data = json.dumps(payload)
 
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
+        async with session.dont_close_if(reusing_session):
             start = time.perf_counter()
             try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
+                async with session.session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
                     response_info = await data.process_response(
                         response=response, config=self.api_config, tokenizer=self.tokenizer
                     )
```
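Putting the pieces together, here is a hedged sketch (not from this commit) of the two call patterns the refactored `process_request` supports; `client`, `data`, and `other_data` are assumed placeholders for a configured `openAIModelServerClient` and two `InferenceAPIData` payloads.

```python
# Hedged sketch of the two call patterns (illustrative only).
from inference_perf.apis import InferenceAPIData
from inference_perf.client.modelserver.openai_client import openAIModelServerClient


async def two_call_patterns(
    client: openAIModelServerClient, data: InferenceAPIData, other_data: InferenceAPIData
) -> None:
    # 1. Per-worker reuse: create one session up front and pass it to
    #    every request. reusing_session is True inside process_request,
    #    so dont_close_if(True) leaves the session open; the caller
    #    closes the underlying aiohttp session when the worker exits.
    session = client.new_reusable_session()
    await client.process_request(data, stage_id=0, scheduled_time=0.0, session=session)
    await client.process_request(other_data, stage_id=0, scheduled_time=0.1, session=session)
    await session.session.close()

    # 2. One-off call: with no session argument, process_request builds
    #    a fresh session and closes it on exit (the pre-commit behavior).
    await client.process_request(data, stage_id=1, scheduled_time=0.0)
```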

inference_perf/loadgen/load_generator.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -16,6 +16,7 @@
 from inference_perf.datagen import DataGenerator
 from inference_perf.apis import InferenceAPIData
 from inference_perf.client.modelserver import ModelServerClient
+from inference_perf.client.modelserver.openai_client import openAIModelServerClient
 from inference_perf.circuit_breaker import get_circuit_breaker
 from inference_perf.config import LoadConfig, LoadStage, LoadType, StageGenType
 from asyncio import (
@@ -29,7 +30,7 @@
     set_event_loop_policy,
     get_event_loop,
 )
-from typing import List, Tuple, TypeAlias, Optional
+from typing import List, Tuple, TypeAlias, Optional, Any
 from types import FrameType
 import time
 import multiprocessing as mp
@@ -81,6 +82,10 @@ async def loop(self) -> None:
         item = None
         timeout = 0.5
 
+        session: Any = None
+        if issubclass(type(self.client), openAIModelServerClient):
+            session = self.client.new_reusable_session()
+
         while not self.stop_signal.is_set():
             while self.request_phase.is_set() and not self.cancel_signal.is_set():
                 await semaphore.acquire()
@@ -118,7 +123,7 @@ async def schedule_client(
             with self.active_requests_counter.get_lock():
                 self.active_requests_counter.value += 1
            inflight = True
            await self.client.process_request(request_data, stage_id, request_time, session=session)
        except CancelledError:
            pass
        finally:
@@ -149,6 +154,7 @@ async def schedule_client(
             logger.debug(f"[Worker {self.id}] waiting for next phase")
             self.request_phase.wait()
 
+        await session.close()
         logger.debug(f"[Worker {self.id}] stopped")
 
     def run(self) -> None:
```
