
Commit bec8172

Share aiohttp.ClientSessions per worker
Slightly refactor `openAIModelServerClient` around a new session abstraction: `ModelServerClient.new_session()` returns a `ModelServerClientSession` whose `process_request` handles a single request over a reusable HTTP client session, which lets the caller share one HTTP client session per worker. The client's own `process_request` now lazily creates an internal session and delegates to it, preserving the previous behavior.

Prior to this commit, a new `aiohttp.ClientSession` was created for each request. Not only is this inefficient and throughput-lowering; in certain environments it also leads to inotify watch issues:

    aiodns - WARNING - Failed to create DNS resolver channel with automatic monitoring of resolver configuration changes. This usually means the system ran out of inotify watches. Falling back to socket state callback. Consider increasing the system inotify watch limit: Failed to initialize c-ares channel

Indeed, because a new DNS resolver is created for each `ClientSession`, creating large numbers of `ClientSession`s eventually exhausts the system's inotify watches. Sharing `ClientSession`s solves this issue.

Relevant links:

- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: kubernetes-sigs#247 (doesn't address the issue of worker sharing).
1 parent c85e5a4 commit bec8172
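
For context, this is the session-reuse pattern the aiohttp request-lifecycle docs linked above recommend. A minimal standalone sketch, assuming a placeholder local endpoint and request count:

```python
import asyncio

import aiohttp


async def main() -> None:
    # One ClientSession (one DNS resolver, one connection pool) shared
    # across all requests, instead of a fresh session per request.
    async with aiohttp.ClientSession() as session:
        for _ in range(100):
            async with session.get("http://localhost:8000/health") as resp:
                await resp.text()


asyncio.run(main())
```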

File tree

3 files changed: +125 -79 lines changed

inference_perf/client/modelserver/__init__.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -11,10 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import ModelServerClient
+from .base import ModelServerClient, ModelServerClientSession
 from .mock_client import MockModelServerClient
 from .vllm_client import vLLMModelServerClient
 from .sglang_client import SGlangModelServerClient


-__all__ = ["ModelServerClient", "MockModelServerClient", "vLLMModelServerClient", "SGlangModelServerClient"]
+__all__ = [
+    "ModelServerClient",
+    "ModelServerClientSession",
+    "MockModelServerClient",
+    "vLLMModelServerClient",
+    "SGlangModelServerClient",
+]
```

inference_perf/client/modelserver/base.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -15,7 +15,6 @@
 from typing import List, Optional, Tuple
 from inference_perf.client.metricsclient.base import MetricsMetadata
 from inference_perf.config import APIConfig, APIType
-
 from inference_perf.apis import InferenceAPIData


@@ -82,6 +81,9 @@ def __init__(self, api_config: APIConfig, timeout: Optional[float] = None, *args
         self.api_config = api_config
         self.timeout = timeout

+    def new_session(self) -> "ModelServerClientSession":
+        return ModelServerClientSession(self)
+
     @abstractmethod
     def get_supported_apis(self) -> List[APIType]:
         raise NotImplementedError
@@ -94,3 +96,14 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
     def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
         # assumption: all metrics clients have metrics exported in Prometheus format
         raise NotImplementedError
+
+
+class ModelServerClientSession:
+    def __init__(self, client: ModelServerClient):
+        self.client = client
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        await self.client.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:  # noqa - subclasses optionally override this
+        pass
```
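
The base class above gives every client a session hook. A sketch of how a per-worker caller might use it (the `worker` function and `requests` iterable are hypothetical, not part of this commit):

```python
# Hypothetical worker loop: one ModelServerClientSession per worker,
# reused for every request and closed when the worker finishes.
async def worker(client: ModelServerClient, requests) -> None:
    session = client.new_session()
    try:
        for data, stage_id, scheduled_time in requests:
            await session.process_request(data, stage_id, scheduled_time)
    finally:
        await session.close()
```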

inference_perf/client/modelserver/openai_client.py

Lines changed: 103 additions & 76 deletions
```diff
@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, ModelServerClientSession, PrometheusMetricMetadata
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,8 @@


 class openAIModelServerClient(ModelServerClient):
+    _session: "openAIModelServerClientSession | None" = None
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -70,82 +72,23 @@ def __init__(
             tokenizer_config = CustomTokenizerConfig(pretrained_model_name_or_path=self.model_name)
             self.tokenizer = CustomTokenizer(tokenizer_config)

-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
-        payload = await data.to_payload(
-            model_name=self.model_name,
-            max_tokens=self.max_completion_tokens,
-            ignore_eos=self.ignore_eos,
-            streaming=self.api_config.streaming,
-        )
-        headers = {"Content-Type": "application/json"}
-
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-        if self.api_config.headers:
-            headers.update(self.api_config.headers)
+    def new_session(self) -> "ModelServerClientSession":
+        return openAIModelServerClientSession(self)

-        request_data = json.dumps(payload)
-
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
-            start = time.perf_counter()
-            try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
-                    response_info = await data.process_response(
-                        response=response, config=self.api_config, tokenizer=self.tokenizer
-                    )
-                    response_content = await response.text()
-
-                    end_time = time.perf_counter()
-                    error = None
-                    if response.status != 200:
-                        error = ErrorResponseInfo(
-                            error_msg=response_content,
-                            error_type=f"{response.status} {response.reason}",
-                        )
-
-                    self.metrics_collector.record_metric(
-                        RequestLifecycleMetric(
-                            stage_id=stage_id,
-                            request_data=request_data,
-                            response_data=response_content,
-                            info=response_info,
-                            error=error,
-                            start_time=start,
-                            end_time=end_time,
-                            scheduled_time=scheduled_time,
-                        )
-                    )
-            except Exception as e:
-                if isinstance(e, asyncio.exceptions.TimeoutError):
-                    logger.error("request timed out:", exc_info=True)
-                else:
-                    logger.error("error occured during request processing:", exc_info=True)
-                failure_info = await data.process_failure(
-                    response=response if "response" in locals() else None,
-                    config=self.api_config,
-                    tokenizer=self.tokenizer,
-                    exception=e,
-                )
-                self.metrics_collector.record_metric(
-                    RequestLifecycleMetric(
-                        stage_id=stage_id,
-                        request_data=request_data,
-                        response_data=response_content if "response_content" in locals() else "",
-                        info=failure_info if failure_info else InferenceInfo(),
-                        error=ErrorResponseInfo(
-                            error_msg=str(e),
-                            error_type=type(e).__name__,
-                        ),
-                        start_time=start,
-                        end_time=time.perf_counter(),
-                        scheduled_time=scheduled_time,
-                    )
-                )
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        """
+        Create an internal client session if not already, then use that to
+        process the request.
+        """
+        if self._session is None:
+            self._session = openAIModelServerClientSession(self)
+        await self._session.process_request(data, stage_id, scheduled_time)
+
+    async def close(self) -> None:
+        """Close the internal session created by process_request, if any."""
+        if self._session is not None:
+            await self._session.close()
+            self._session = None

     def get_supported_apis(self) -> List[APIType]:
         return []
@@ -166,3 +109,87 @@ def get_supported_models(self) -> List[str]:
         except Exception as e:
             logger.error(f"Got exception retrieving supported models {e}")
             return []
+
+
+class openAIModelServerClientSession(ModelServerClientSession):
+    def __init__(self, client: openAIModelServerClient):
+        self.client = client
+        self.session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=client.timeout) if client.timeout else aiohttp.helpers.sentinel,
+            connector=aiohttp.TCPConnector(limit=client.max_tcp_connections),
+        )
+
+    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        payload = await data.to_payload(
+            model_name=self.client.model_name,
+            max_tokens=self.client.max_completion_tokens,
+            ignore_eos=self.client.ignore_eos,
+            streaming=self.client.api_config.streaming,
+        )
+        headers = {"Content-Type": "application/json"}
+
+        if self.client.api_key:
+            headers["Authorization"] = f"Bearer {self.client.api_key}"
+
+        if self.client.api_config.headers:
+            headers.update(self.client.api_config.headers)
+
+        request_data = json.dumps(payload)
+
+        start = time.perf_counter()
+        try:
+            async with self.session.post(self.client.uri + data.get_route(), headers=headers, data=request_data) as response:
+                response_info = await data.process_response(
+                    response=response, config=self.client.api_config, tokenizer=self.client.tokenizer
+                )
+                response_content = await response.text()
+
+                end_time = time.perf_counter()
+                error = None
+                if response.status != 200:
+                    error = ErrorResponseInfo(
+                        error_msg=response_content,
+                        error_type=f"{response.status} {response.reason}",
+                    )
+
+                self.client.metrics_collector.record_metric(
+                    RequestLifecycleMetric(
+                        stage_id=stage_id,
+                        request_data=request_data,
+                        response_data=response_content,
+                        info=response_info,
+                        error=error,
+                        start_time=start,
+                        end_time=end_time,
+                        scheduled_time=scheduled_time,
+                    )
+                )
+        except Exception as e:
+            if isinstance(e, asyncio.exceptions.TimeoutError):
+                logger.error("request timed out:", exc_info=True)
+            else:
+                logger.error("error occured during request processing:", exc_info=True)
+            failure_info = await data.process_failure(
+                response=response if "response" in locals() else None,
+                config=self.client.api_config,
+                tokenizer=self.client.tokenizer,
+                exception=e,
+            )
+            self.client.metrics_collector.record_metric(
+                RequestLifecycleMetric(
+                    stage_id=stage_id,
+                    request_data=request_data,
+                    response_data=response_content if "response_content" in locals() else "",
+                    info=failure_info if failure_info else InferenceInfo(),
+                    error=ErrorResponseInfo(
+                        error_msg=str(e),
+                        error_type=type(e).__name__,
+                    ),
+                    start_time=start,
+                    end_time=time.perf_counter(),
+                    scheduled_time=scheduled_time,
+                )
+            )
+
+    async def close(self) -> None:
+        await self.session.close()
```
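
One lifecycle note on the lazy path: when callers use plain `process_request` rather than `new_session()`, the internally created session should be closed via `close()` at the end of the run, since an unclosed `aiohttp.ClientSession` triggers a warning at shutdown. A hypothetical driver (`run` and `load` are illustrative names, not part of this commit):

```python
async def run(client: openAIModelServerClient, load) -> None:
    try:
        for data, stage_id, scheduled_time in load:
            await client.process_request(data, stage_id, scheduled_time)
    finally:
        # Close the internal session that process_request created lazily.
        await client.close()
```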
