
Commit 39dbaea (1 parent: 603fe14)

Share aiohttp.ClientSessions per worker
Slightly refactor the model server client so that an HTTP client session can be shared across requests. `ModelServerClient` gains a `new_session()` method returning a `ModelServerClientSession`, and `openAIModelServerClient` overrides it to return an `openAIModelServerClientSession` that owns a single long-lived `aiohttp.ClientSession`. The load generator now creates one such session per worker and reuses it for every request. The existing `process_request` entry point is preserved: it simply runs each request through a freshly created session, matching the previous behavior.

Prior to this commit, a new `aiohttp.ClientSession` was created for each request. Not only is this inefficient and throughput-lowering; in certain environments it also leads to inotify watch issues:

```
aiodns - WARNING - Failed to create DNS resolver channel with automatic monitoring of resolver configuration changes. This usually means the system ran out of inotify watches. Falling back to socket state callback. Consider increasing the system inotify watch limit: Failed to initialize c-ares channel
```

Because each new `ClientSession` gets its own DNS resolver, creating large numbers of short-lived `ClientSession`s eventually exhausts the system's inotify watches. Sharing `ClientSession`s resolves this.

Relevant links:

- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: kubernetes-sigs#247 (doesn't address sharing sessions per worker).
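For context, the per-request versus per-worker lifecycles look roughly like this. A minimal, generic sketch (not code from this repo; `fetch_once`, `worker`, and the example URL are illustrative):

```python
import asyncio
import aiohttp

# What the old code effectively did: each call builds a ClientSession,
# a TCPConnector, and a DNS resolver, then tears them all down again.
async def fetch_once(url: str) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return resp.status

# What this commit moves toward: one ClientSession per worker, reused
# for every request and closed exactly once when the worker stops.
async def worker(urls: list[str]) -> None:
    session = aiohttp.ClientSession()
    try:
        for url in urls:
            async with session.get(url) as resp:
                await resp.read()  # connection is returned to the pool
    finally:
        await session.close()

asyncio.run(worker(["https://example.com"] * 3))
```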

4 files changed: +112 additions, −71 deletions

inference_perf/client/modelserver/__init__.py

8 additions & 2 deletions

```diff
@@ -11,10 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import ModelServerClient
+from .base import ModelServerClient, ModelServerClientSession
 from .mock_client import MockModelServerClient
 from .vllm_client import vLLMModelServerClient
 from .sglang_client import SGlangModelServerClient


-__all__ = ["ModelServerClient", "MockModelServerClient", "vLLMModelServerClient", "SGlangModelServerClient"]
+__all__ = [
+    "ModelServerClient",
+    "ModelServerClientSession",
+    "MockModelServerClient",
+    "vLLMModelServerClient",
+    "SGlangModelServerClient",
+]
```

inference_perf/client/modelserver/base.py

14 additions & 1 deletion

```diff
@@ -15,7 +15,6 @@
 from typing import List, Optional, Tuple
 from inference_perf.client.metricsclient.base import MetricsMetadata
 from inference_perf.config import APIConfig, APIType
-
 from inference_perf.apis import InferenceAPIData


@@ -82,6 +81,9 @@ def __init__(self, api_config: APIConfig, timeout: Optional[float] = None, *args
         self.api_config = api_config
         self.timeout = timeout

+    def new_session(self) -> "ModelServerClientSession":
+        return ModelServerClientSession(self)
+
     @abstractmethod
     def get_supported_apis(self) -> List[APIType]:
         raise NotImplementedError
@@ -94,3 +96,14 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
     def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
         # assumption: all metrics clients have metrics exported in Prometheus format
         raise NotImplementedError
+
+
+class ModelServerClientSession:
+    def __init__(self, client: ModelServerClient):
+        self.client = client
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        await self.client.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:  # noqa - subclasses optionally override this
+        pass
```
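The contract here: the default `ModelServerClientSession` simply delegates back to `client.process_request`, so model servers that never override `new_session()` behave exactly as before, while subclasses can hold per-session state and release it in `close()`. A sketch of the calling convention this implies (the `run_requests` helper is illustrative, not part of the commit):

```python
from typing import Iterable, Tuple

from inference_perf.apis import InferenceAPIData
from inference_perf.client.modelserver import ModelServerClient

# Illustrative helper: one session per worker, many requests, one close --
# the same shape load_generator.py adopts later in this commit.
async def run_requests(
    client: ModelServerClient,
    requests: Iterable[Tuple[InferenceAPIData, int, float]],
) -> None:
    session = client.new_session()
    try:
        for data, stage_id, scheduled_time in requests:
            await session.process_request(data, stage_id, scheduled_time)
    finally:
        await session.close()  # a no-op for the default base session
```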

inference_perf/client/modelserver/openai_client.py

87 additions & 67 deletions

```diff
@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, ModelServerClientSession, PrometheusMetricMetadata
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,8 @@


 class openAIModelServerClient(ModelServerClient):
+    _session: aiohttp.ClientSession | None = None
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,73 +72,11 @@ def __init__(
             tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
             self.tokenizer = CustomTokenizer(tokenizer_config)

-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
-        payload = data.to_payload(
-            model_name=self.model_name,
-            max_tokens=self.max_completion_tokens,
-            ignore_eos=self.ignore_eos,
-            streaming=self.api_config.streaming,
-        )
-        headers = {"Content-Type": "application/json"}
-
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
+    def new_session(self) -> "ModelServerClientSession":
+        return openAIModelServerClientSession(self)

-        if self.api_config.headers:
-            headers.update(self.api_config.headers)
-
-        request_data = json.dumps(payload)
-
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
-            start = time.perf_counter()
-            try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
-                    response_info = await data.process_response(
-                        response=response, config=self.api_config, tokenizer=self.tokenizer
-                    )
-                    response_content = await response.text()
-
-                    end_time = time.perf_counter()
-                    error = None
-                    if response.status != 200:
-                        error = ErrorResponseInfo(error_msg=response_content, error_type="Error response")
-
-                    self.metrics_collector.record_metric(
-                        RequestLifecycleMetric(
-                            stage_id=stage_id,
-                            request_data=request_data,
-                            response_data=response_content,
-                            info=response_info,
-                            error=error,
-                            start_time=start,
-                            end_time=end_time,
-                            scheduled_time=scheduled_time,
-                        )
-                    )
-            except Exception as e:
-                if isinstance(e, asyncio.exceptions.TimeoutError):
-                    logger.error("request timed out:", exc_info=True)
-                else:
-                    logger.error("error occured during request processing:", exc_info=True)
-                self.metrics_collector.record_metric(
-                    RequestLifecycleMetric(
-                        stage_id=stage_id,
-                        request_data=request_data,
-                        response_data=response_content if "response_content" in locals() else "",
-                        info=response_info if "response_info" in locals() else InferenceInfo(),
-                        error=ErrorResponseInfo(
-                            error_msg=str(e),
-                            error_type=type(e).__name__,
-                        ),
-                        start_time=start,
-                        end_time=time.perf_counter(),
-                        scheduled_time=scheduled_time,
-                    )
-                )
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        await self.new_session().process_request(data, stage_id, scheduled_time)

     def get_supported_apis(self) -> List[APIType]:
         return []
@@ -157,3 +97,83 @@ def get_supported_models(self) -> List[str]:
         except Exception as e:
             logger.error(f"Got exception retrieving supported models {e}")
             return []
+
+
+class openAIModelServerClientSession(ModelServerClientSession):
+    def __init__(self, client: openAIModelServerClient):
+        self.client = client
+        self.session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=client.timeout) if client.timeout else aiohttp.helpers.sentinel,
+            connector=aiohttp.TCPConnector(limit=client.max_tcp_connections),
+        )
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        payload = data.to_payload(
+            model_name=self.client.model_name,
+            max_tokens=self.client.max_completion_tokens,
+            ignore_eos=self.client.ignore_eos,
+            streaming=self.client.api_config.streaming,
+        )
+        headers = {"Content-Type": "application/json"}
+
+        if self.client.api_key:
+            headers["Authorization"] = f"Bearer {self.client.api_key}"
+
+        if self.client.api_config.headers:
+            headers.update(self.client.api_config.headers)
+
+        request_data = json.dumps(payload)
+        response_info: InferenceInfo | None = None
+        response_content: str | None = None
+
+        start = time.perf_counter()
+        try:
+            async with self.session.post(self.client.uri + data.get_route(), headers=headers, data=request_data) as response:
+                response_info = await data.process_response(
+                    response=response,
+                    config=self.client.api_config,
+                    tokenizer=self.client.tokenizer,
+                )
+                response_content = await response.text()
+
+                end_time = time.perf_counter()
+                error = None
+                if response.status != 200:
+                    error = ErrorResponseInfo(error_msg=response_content, error_type="Error response")
+
+                self.client.metrics_collector.record_metric(
+                    RequestLifecycleMetric(
+                        stage_id=stage_id,
+                        request_data=request_data,
+                        response_data=response_content,
+                        info=response_info,
+                        error=error,
+                        start_time=start,
+                        end_time=end_time,
+                        scheduled_time=scheduled_time,
+                    )
+                )
+        except Exception as e:
+            if isinstance(e, asyncio.exceptions.TimeoutError):
+                logger.error("request timed out:", exc_info=True)
+            else:
+                logger.error("error occured during request processing:", exc_info=True)
+
+            self.client.metrics_collector.record_metric(
+                RequestLifecycleMetric(
+                    stage_id=stage_id,
+                    request_data=request_data,
+                    response_data=response_content or "",
+                    info=response_info or InferenceInfo(),
+                    error=ErrorResponseInfo(
+                        error_msg=str(e),
+                        error_type=type(e).__name__,
+                    ),
+                    start_time=start,
+                    end_time=time.perf_counter(),
+                    scheduled_time=scheduled_time,
+                )
+            )
+
+    async def close(self) -> None:
+        await self.session.close()
```
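One sharp edge worth noting: the backwards-compatible `process_request` above constructs a fresh `openAIModelServerClientSession` (and with it a fresh `aiohttp.ClientSession`) per call and appears never to close it, which aiohttp reports as "Unclosed client session" when the object is garbage-collected. A possible extension, not part of this commit, is async-context-manager support so callers cannot forget the `close()` (the `ManagedSession` name is hypothetical):

```python
from types import TracebackType
from typing import Optional, Type

from inference_perf.client.modelserver.openai_client import (
    openAIModelServerClientSession,
)

class ManagedSession(openAIModelServerClientSession):
    """Wraps the session in an async context manager so the underlying
    aiohttp.ClientSession is always closed, even on exceptions."""

    async def __aenter__(self) -> "ManagedSession":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

# Usage: async with ManagedSession(client) as session:
#            await session.process_request(data, stage_id, scheduled_time)
```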

inference_perf/loadgen/load_generator.py

3 additions & 1 deletion

```diff
@@ -94,6 +94,7 @@ async def loop(self) -> None:
         event_loop = get_event_loop()
         item = None
         timeout = 0.5
+        session = self.client.new_session()

         while not self.stop_signal.is_set():
             # Check if max_concurrency has been updated and recreate semaphore if needed (concurrent load type)
@@ -154,7 +155,7 @@ async def schedule_client(
                 with self.active_requests_counter.get_lock():
                     self.active_requests_counter.value += 1
                 inflight = True
-                await self.client.process_request(request_data, stage_id, request_time)
+                await session.process_request(request_data, stage_id, request_time)
             except CancelledError:
                 pass
             finally:
@@ -188,6 +189,7 @@ async def schedule_client(
             logger.debug(f"[Worker {self.id}] waiting for next phase")
             self.request_phase.wait()

+        await session.close()
         logger.debug(f"[Worker {self.id}] stopped")

     def run(self) -> None:
```
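Taken together, these hunks give each worker exactly one session spanning every stage of the run: created before the phase loop starts, reused by `schedule_client` for every request, and closed only once the stop signal is set, so stages share connections rather than re-dialing. A condensed paraphrase of the resulting control flow (not the verbatim source; bodies elided with `...`):

```python
async def loop(self) -> None:
    session = self.client.new_session()   # one HTTP session per worker
    while not self.stop_signal.is_set():
        ...                               # schedule requests; each awaits
        ...                               # session.process_request(...)
        self.request_phase.wait()         # block until the next stage begins
    await session.close()                 # sockets and DNS resolver released once
```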
