1717from inference_perf .config import APIConfig , APIType , CustomTokenizerConfig
1818from inference_perf .apis import InferenceAPIData , InferenceInfo , RequestLifecycleMetric , ErrorResponseInfo
1919from inference_perf .utils import CustomTokenizer
20- from .base import ModelServerClient , PrometheusMetricMetadata
20+ from .base import ModelServerClient , ModelServerClientSession , PrometheusMetricMetadata
2121from typing import List , Optional
2222import aiohttp
2323import asyncio
3030
3131
3232class openAIModelServerClient (ModelServerClient ):
33+ _session : "openAIModelServerClientSession | None" = None
34+
3335 def __init__ (
3436 self ,
3537 metrics_collector : RequestDataCollector ,
@@ -70,82 +72,23 @@ def __init__(
7072 tokenizer_config = CustomTokenizerConfig (pretrained_model_name_or_path = self .model_name )
7173 self .tokenizer = CustomTokenizer (tokenizer_config )
7274
73- async def process_request (self , data : InferenceAPIData , stage_id : int , scheduled_time : float ) -> None :
74- payload = await data .to_payload (
75- model_name = self .model_name ,
76- max_tokens = self .max_completion_tokens ,
77- ignore_eos = self .ignore_eos ,
78- streaming = self .api_config .streaming ,
79- )
80- headers = {"Content-Type" : "application/json" }
81-
82- if self .api_key :
83- headers ["Authorization" ] = f"Bearer { self .api_key } "
84-
85- if self .api_config .headers :
86- headers .update (self .api_config .headers )
75+ def new_session (self ) -> "ModelServerClientSession" :
76+ return openAIModelServerClientSession (self )
8777
88- request_data = json .dumps (payload )
89-
90- timeout = aiohttp .ClientTimeout (total = self .timeout ) if self .timeout else aiohttp .helpers .sentinel
91-
92- async with aiohttp .ClientSession (
93- connector = aiohttp .TCPConnector (limit = self .max_tcp_connections ), timeout = timeout
94- ) as session :
95- start = time .perf_counter ()
96- try :
97- async with session .post (self .uri + data .get_route (), headers = headers , data = request_data ) as response :
98- response_info = await data .process_response (
99- response = response , config = self .api_config , tokenizer = self .tokenizer
100- )
101- response_content = await response .text ()
102-
103- end_time = time .perf_counter ()
104- error = None
105- if response .status != 200 :
106- error = ErrorResponseInfo (
107- error_msg = response_content ,
108- error_type = f"{ response .status } { response .reason } " ,
109- )
110-
111- self .metrics_collector .record_metric (
112- RequestLifecycleMetric (
113- stage_id = stage_id ,
114- request_data = request_data ,
115- response_data = response_content ,
116- info = response_info ,
117- error = error ,
118- start_time = start ,
119- end_time = end_time ,
120- scheduled_time = scheduled_time ,
121- )
122- )
123- except Exception as e :
124- if isinstance (e , asyncio .exceptions .TimeoutError ):
125- logger .error ("request timed out:" , exc_info = True )
126- else :
127- logger .error ("error occured during request processing:" , exc_info = True )
128- failure_info = await data .process_failure (
129- response = response if "response" in locals () else None ,
130- config = self .api_config ,
131- tokenizer = self .tokenizer ,
132- exception = e ,
133- )
134- self .metrics_collector .record_metric (
135- RequestLifecycleMetric (
136- stage_id = stage_id ,
137- request_data = request_data ,
138- response_data = response_content if "response_content" in locals () else "" ,
139- info = failure_info if failure_info else InferenceInfo (),
140- error = ErrorResponseInfo (
141- error_msg = str (e ),
142- error_type = type (e ).__name__ ,
143- ),
144- start_time = start ,
145- end_time = time .perf_counter (),
146- scheduled_time = scheduled_time ,
147- )
148- )
78+ async def process_request (self , data : InferenceAPIData , stage_id : int , scheduled_time : float ) -> None :
79+ """
80+ Create an internal client session if one does not already exist, then
81+ use that session to process the request.
82+ """
83+ if self ._session is None :
84+ self ._session = openAIModelServerClientSession (self )
85+ await self ._session .process_request (data , stage_id , scheduled_time )
86+
87+ async def close (self ) -> None :
88+ """Close the internal session created by process_request, if any."""
89+ if self ._session is not None :
90+ await self ._session .close ()
91+ self ._session = None
14992
15093 def get_supported_apis (self ) -> List [APIType ]:
15194 return []
@@ -166,3 +109,87 @@ def get_supported_models(self) -> List[str]:
166109 except Exception as e :
167110 logger .error (f"Got exception retrieving supported models { e } " )
168111 return []
112+
113+
114+ class openAIModelServerClientSession (ModelServerClientSession ):
115+ def __init__ (self , client : openAIModelServerClient ):
116+ self .client = client
117+ self .session = aiohttp .ClientSession (
118+ timeout = aiohttp .ClientTimeout (total = client .timeout ) if client .timeout else aiohttp .helpers .sentinel ,
119+ connector = aiohttp .TCPConnector (limit = client .max_tcp_connections ),
120+ )
121+
122+ async def process_request (self , data : InferenceAPIData , stage_id : int , scheduled_time : float ) -> None :
123+ payload = await data .to_payload (
124+ model_name = self .client .model_name ,
125+ max_tokens = self .client .max_completion_tokens ,
126+ ignore_eos = self .client .ignore_eos ,
127+ streaming = self .client .api_config .streaming ,
128+ )
129+ headers = {"Content-Type" : "application/json" }
130+
131+ if self .client .api_key :
132+ headers ["Authorization" ] = f"Bearer { self .client .api_key } "
133+
134+ if self .client .api_config .headers :
135+ headers .update (self .client .api_config .headers )
136+
137+ request_data = json .dumps (payload )
138+
139+ start = time .perf_counter ()
140+ try :
141+ async with self .session .post (self .client .uri + data .get_route (), headers = headers , data = request_data ) as response :
142+ response_info = await data .process_response (
143+ response = response , config = self .client .api_config , tokenizer = self .client .tokenizer
144+ )
145+ response_content = await response .text ()
146+
147+ end_time = time .perf_counter ()
148+ error = None
149+ if response .status != 200 :
150+ error = ErrorResponseInfo (
151+ error_msg = response_content ,
152+ error_type = f"{ response .status } { response .reason } " ,
153+ )
154+
155+ self .client .metrics_collector .record_metric (
156+ RequestLifecycleMetric (
157+ stage_id = stage_id ,
158+ request_data = request_data ,
159+ response_data = response_content ,
160+ info = response_info ,
161+ error = error ,
162+ start_time = start ,
163+ end_time = end_time ,
164+ scheduled_time = scheduled_time ,
165+ )
166+ )
167+ except Exception as e :
168+ if isinstance (e , asyncio .exceptions .TimeoutError ):
169+ logger .error ("request timed out:" , exc_info = True )
170+ else :
171+ logger .error ("error occurred during request processing:" , exc_info = True )
172+ failure_info = await data .process_failure (
173+ response = response if "response" in locals () else None ,
174+ config = self .client .api_config ,
175+ tokenizer = self .client .tokenizer ,
176+ exception = e ,
177+ )
178+ self .client .metrics_collector .record_metric (
179+ RequestLifecycleMetric (
180+ stage_id = stage_id ,
181+ request_data = request_data ,
182+ response_data = response_content if "response_content" in locals () else "" ,
183+ info = failure_info if failure_info else InferenceInfo (),
184+ error = ErrorResponseInfo (
185+ error_msg = str (e ),
186+ error_type = type (e ).__name__ ,
187+ ),
188+ start_time = start ,
189+ end_time = time .perf_counter (),
190+ scheduled_time = scheduled_time ,
191+ )
192+ )
193+
194+ async def close (self ) -> None :
195+ await self .session .close ()
0 commit comments