11import os
22import json
33import time
4+ import copy
45import logging
56import logging .config
67import traceback
@@ -60,7 +61,6 @@ def start(self):
6061 self ._openai_endpoint = None
6162 self ._embedding_endpoint = None
6263 if self ._inference_worker .model is not None and self ._inference_worker .model .model_type == ModelType .EMBEDDING :
63- assert isinstance (self ._inference_worker .model , AsyncModel ), "only support embedding model in async mode"
6464 self ._embedding_endpoint = EmbeddingEndpoint (self ._inference_worker .model )
6565 else :
6666 self ._openai_endpoint = OpenaiEndopoint (self ._inference_worker .model )
@@ -111,25 +111,25 @@ async def stream_response(
111111 self ._access_logger .log_exception_access (request , e , id )
112112 kmonitor .report (AccMetrics .ERROR_QPS_METRIC , 1 )
113113 yield response_data_prefix + \
114- json .dumps (InferenceServer .handler_exceptions (e ), ensure_ascii = False ) + "\r \n \r \n "
114+ json .dumps (InferenceServer .format_exception (e ), ensure_ascii = False ) + "\r \n \r \n "
115115
@staticmethod
def format_exception(e: Exception) -> Dict[str, Any]:
    """Convert an exception into the error payload sent back to clients.

    Args:
        e: The exception raised while serving a request.

    Returns:
        A dict with two keys, ``'error_code'`` and ``"message"``. The code is
        taken from the exception itself for ``FtRuntimeException`` and from
        ``ExceptionType`` for every other case.
    """
    # Plain nested function on purpose: wrapping a local helper in
    # @staticmethod yields a staticmethod object, which is not callable
    # before Python 3.10 and would make every call below raise TypeError.
    def _format(errcode: int, message: str) -> Dict[str, Any]:
        return {'error_code': errcode, "message": message}

    if isinstance(e, FtRuntimeException):
        # The exception carries its own error code and message.
        return _format(e.expcetion_type, e.message)
    if isinstance(e, ConcurrencyException):
        return _format(ExceptionType.CONCURRENCY_LIMIT_ERROR, str(e))
    if isinstance(e, (LoraCountException, LoraPathException)):
        return _format(ExceptionType.UPDATE_ERROR, str(e))
    if isinstance(e, Exception):
        # Unclassified errors also carry the traceback of the exception
        # currently being handled, to help debugging.
        error_msg = f'ErrorMsg: {str(e)} \n Traceback: {traceback.format_exc()}'
        return _format(ExceptionType.UNKNOWN_ERROR, error_msg)
    # Anything that is not an Exception subclass (e.g. a BaseException).
    return _format(ExceptionType.UNKNOWN_ERROR, str(e))
133133
134134 def update (self , version_info : VersionInfo ):
135135 id = self ._atomic_count .increment ()
@@ -146,7 +146,7 @@ def update(self, version_info: VersionInfo):
146146 self ._access_logger .log_exception_access (version_info .__dict__ , e , id )
147147 kmonitor .report (AccMetrics .ERROR_UPDATE_QPS_METRIC , 1 )
148148 error_code = 500
149- rep = JSONResponse (self .handler_exceptions (e ), status_code = error_code )
149+ rep = JSONResponse (self .format_exception (e ), status_code = error_code )
150150 return rep
151151
152152 async def inference (self , req : Union [str ,Dict [Any , Any ]], raw_request : RawRequest ):
@@ -165,17 +165,7 @@ async def _infer_wrap(self, req: Dict[Any, Any], raw_request: RawRequest, genera
165165 try :
166166 rep = await self ._infer_impl (req , id , raw_request , generate_call )
167167 except Exception as e :
168- self ._access_logger .log_exception_access (req , e , id )
169- if isinstance (e , ConcurrencyException ):
170- kmonitor .report (AccMetrics .CONFLICT_QPS_METRIC )
171- error_code = 409
172- elif isinstance (e , asyncio .CancelledError ):
173- kmonitor .report (AccMetrics .CANCAL_QPS_METRIC , 1 )
174- error_code = 499
175- else :
176- error_code = 500
177- kmonitor .report (AccMetrics .ERROR_QPS_METRIC , 1 )
178- rep = JSONResponse (self .handler_exceptions (e ), status_code = error_code )
168+ rep = self ._handle_exception (req , e , id )
179169 return rep
180170
181171 async def chat_completion (self , request : ChatCompletionRequest , raw_request : Request ):
@@ -187,26 +177,45 @@ def generate_call():
187177 return await self ._infer_wrap (request .model_dump (), raw_request , generate_call )
188178
async def embedding(self, request: Union[Dict[str, Any], str, OpenAIEmbeddingRequest], raw_request: Request):
    """Serve an embedding request through the embedding endpoint.

    Args:
        request: The embedding request (dict, raw string or
            ``OpenAIEmbeddingRequest``), forwarded as-is to the endpoint.
        raw_request: The underlying HTTP request; not used in this method.

    Returns:
        A ``JSONResponse`` with the dumped result on success, otherwise the
        error response produced by ``_handle_exception``.
    """
    id = self._atomic_count.increment()
    kmonitor.report(AccMetrics.QPS_METRIC, 1)
    with self._controller:
        try:
            assert self._embedding_endpoint is not None, "embedding pipeline should not be None"
            result = await self._embedding_endpoint.embedding(request)
            # do not log result since too big: log a copy with `data` emptied
            # so the response object itself keeps its payload.
            log_result = copy.copy(result)
            log_result.data = []
            self._access_logger.log_success_access(request, log_result, id)
            return JSONResponse(result.model_dump(exclude_none=True))
        except Exception as e:
            # The error response must be returned; dropping it makes the
            # handler fall through and return None to the web framework.
            return self._handle_exception(request, e, id)
193+
async def similarity(self, request: Union[Dict[str, Any], str, SimilarityRequest], raw_request: Request):
    """Serve a similarity request through the embedding endpoint.

    Args:
        request: The similarity request (dict, raw string or
            ``SimilarityRequest``), forwarded as-is to the endpoint.
        raw_request: The underlying HTTP request; not used in this method.

    Returns:
        A ``JSONResponse`` with the dumped result on success, otherwise the
        error response produced by ``_handle_exception``.
    """
    id = self._atomic_count.increment()
    kmonitor.report(AccMetrics.QPS_METRIC, 1)
    with self._controller:
        try:
            assert self._embedding_endpoint is not None, "embedding pipeline should not be None"
            result = await self._embedding_endpoint.similarity(request)
            # Dump once and reuse for both the access log and the response.
            payload = result.model_dump(exclude_none=True)
            self._access_logger.log_success_access(request, payload, id)
            return JSONResponse(payload)
        except Exception as e:
            # The error response must be returned; dropping it makes the
            # handler fall through and return None to the web framework.
            return self._handle_exception(request, e, id)
205+
def _handle_exception(self, request: Union[Dict[str, Any], str, BaseModel], e: Exception, id: int):
    """Log a failed request, report a metric and build the error response.

    Args:
        request: The request that failed, passed to the access logger.
        e: The exception raised while serving the request.
        id: The request id passed to the access logger.

    Returns:
        A ``JSONResponse`` carrying ``format_exception(e)`` with status 409
        for ``ConcurrencyException``, 499 for ``asyncio.CancelledError`` and
        500 for everything else.
    """
    self._access_logger.log_exception_access(request, e, id)

    if isinstance(e, ConcurrencyException):
        status = 409
        kmonitor.report(AccMetrics.CONFLICT_QPS_METRIC)
    elif isinstance(e, asyncio.CancelledError):
        status = 499
        kmonitor.report(AccMetrics.CANCAL_QPS_METRIC, 1)
    else:
        status = 500
        kmonitor.report(AccMetrics.ERROR_QPS_METRIC, 1)

    return JSONResponse(self.format_exception(e), status_code=status)
210219
211220 async def _call_generate_with_report (self , generate_call : Callable [[], CompleteResponseAsyncGenerator ]):
212221 async def __gen_response_with_report (start_time , response_generator ):
@@ -283,4 +292,4 @@ def tokenizer_encode(self, req: Union[str,Dict[Any, Any]]):
283292 response = TokenizerEncodeResponse (token_ids = token_ids , tokens = tokens )
284293 return JSONResponse (content = response .model_dump (exclude_none = True ))
285294 except Exception as e :
286- return JSONResponse (self .handler_exceptions (e ), status_code = 500 )
295+ return JSONResponse (self .format_exception (e ), status_code = 500 )
0 commit comments