Commit 6def4b3

Author: rtp-llm
Message: update - embedding add exception and metrics
Parent: 1fb7fd0

4 files changed: +52 / -38 lines

maga_transformer/access_logger/access_logger.py

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 import logging
+from pydantic import BaseModel
 from typing import Any, Union, Dict

 from maga_transformer.access_logger.json_util import dump_json
@@ -57,14 +58,18 @@ def log_query_access(self, request: Union[Dict[str, Any], str], id: int) -> None
         access_log = PyAccessLog(request = request_log, response = response_log, id = id)
         self.query_logger.info(dump_json(access_log))

-    def log_success_access(self, request: Union[Dict[str, Any], str], response: Any, id: int) -> None:
+    def log_success_access(self, request: Union[Dict[str, Any], str, BaseModel], response: Any, id: int) -> None:
+        if isinstance(request, BaseModel):
+            request = request.model_dump()
         if not self.is_private_request(request):
             response_log = ResponseLog()
             if LOG_RESPONSE:
                 response_log.add_response(response)
             self.log_access(request, response_log, id)

-    def log_exception_access(self, request: Union[Dict[str, Any], str], exception: BaseException, id: int) -> None:
+    def log_exception_access(self, request: Union[Dict[str, Any], str, BaseModel], exception: BaseException, id: int) -> None:
+        if isinstance(request, BaseModel):
+            request = request.model_dump()
         response_log = ResponseLog()
         response_log.add_exception(exception)
         if not self.is_private_request(request):
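
Both logger methods now accept a Pydantic request object directly and normalize it to a plain dict before the privacy check and JSON dump. A minimal standalone sketch of that normalization follows; `ExampleRequest` is a made-up stand-in, and only the isinstance/model_dump pattern mirrors the commit:

```python
# Sketch of the normalization added to log_success_access / log_exception_access.
from typing import Any, Dict, Union
from pydantic import BaseModel

class ExampleRequest(BaseModel):
    input: str
    model: str = ""

def normalize_request(request: Union[Dict[str, Any], str, BaseModel]) -> Union[Dict[str, Any], str]:
    # Pydantic v2: model_dump() turns the model into a plain dict so the
    # existing is_private_request()/dump_json() paths keep working unchanged.
    if isinstance(request, BaseModel):
        return request.model_dump()
    return request

print(normalize_request(ExampleRequest(input="hello")))
# -> {'input': 'hello', 'model': ''}
```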

maga_transformer/embedding/api_datatype.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@

 class OpenAIEmbeddingRequest(PyDanticModelBase):
     input: Union[str, List[str]]
-    model: str
+    model: str = ""
     encoding_format: str = 'float'
     user: str = ""
     embedding_config: EmbeddingGenerateConfig = EmbeddingGenerateConfig()
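
With the default value, `model` becomes optional, so embedding requests that omit it now validate instead of failing. A quick illustration with a simplified stand-in model (plain `BaseModel` here replaces PyDanticModelBase, and `embedding_config` is left out):

```python
# Simplified stand-in for OpenAIEmbeddingRequest to show the effect of the new default.
from typing import List, Union
from pydantic import BaseModel

class OpenAIEmbeddingRequestSketch(BaseModel):
    input: Union[str, List[str]]
    model: str = ""          # previously required; now optional
    encoding_format: str = 'float'
    user: str = ""

req = OpenAIEmbeddingRequestSketch(input=["hello", "world"])
print(req.model)  # "" -- a payload without "model" no longer raises a validation error
```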

maga_transformer/server/inference_server.py

Lines changed: 40 additions & 31 deletions
@@ -1,6 +1,7 @@
 import os
 import json
 import time
+import copy
 import logging
 import logging.config
 import traceback
@@ -60,7 +61,6 @@ def start(self):
         self._openai_endpoint = None
         self._embedding_endpoint = None
         if self._inference_worker.model is not None and self._inference_worker.model.model_type == ModelType.EMBEDDING:
-            assert isinstance(self._inference_worker.model, AsyncModel), "only support embedding model in async mode"
             self._embedding_endpoint = EmbeddingEndpoint(self._inference_worker.model)
         else:
             self._openai_endpoint = OpenaiEndopoint(self._inference_worker.model)
@@ -111,25 +111,25 @@ async def stream_response(
             self._access_logger.log_exception_access(request, e, id)
             kmonitor.report(AccMetrics.ERROR_QPS_METRIC, 1)
             yield response_data_prefix + \
-                json.dumps(InferenceServer.handler_exceptions(e), ensure_ascii=False) + "\r\n\r\n"
+                json.dumps(InferenceServer.format_exception(e), ensure_ascii=False) + "\r\n\r\n"

     @staticmethod
-    def format_exception(errcode: int, message: str) -> Dict[str, Any]:
-        return {'error_code': errcode, "message": message}
+    def format_exception(e: Exception):
+        @staticmethod
+        def _format(errcode: int, message: str) -> Dict[str, Any]:
+            return {'error_code': errcode, "message": message}

-    @staticmethod
-    def handler_exceptions(e: Exception):
         if isinstance(e, FtRuntimeException):
-            return InferenceServer.format_exception(e.expcetion_type, e.message)
+            return _format(e.expcetion_type, e.message)
         elif isinstance(e, ConcurrencyException):
-            return InferenceServer.format_exception(ExceptionType.CONCURRENCY_LIMIT_ERROR, str(e))
+            return _format(ExceptionType.CONCURRENCY_LIMIT_ERROR, str(e))
         elif isinstance(e, LoraCountException) or isinstance(e, LoraPathException):
-            return InferenceServer.format_exception(ExceptionType.UPDATE_ERROR, str(e))
+            return _format(ExceptionType.UPDATE_ERROR, str(e))
         elif isinstance(e, Exception):
             error_msg = f'ErrorMsg: {str(e)} \n Traceback: {traceback.format_exc()}'
-            return InferenceServer.format_exception(ExceptionType.UNKNOWN_ERROR, error_msg)
+            return _format(ExceptionType.UNKNOWN_ERROR, error_msg)
         else:
-            return InferenceServer.format_exception(ExceptionType.UNKNOWN_ERROR, str(e))
+            return _format(ExceptionType.UNKNOWN_ERROR, str(e))

     def update(self, version_info: VersionInfo):
         id = self._atomic_count.increment()
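
The old two-method split (`format_exception` building the dict, `handler_exceptions` mapping exception types) is folded into a single `format_exception(e)` that turns an exception into an `{'error_code', 'message'}` payload. A self-contained sketch of that mapping, using a plain nested helper and placeholder exception class and codes where the real ones come from maga_transformer.config.exceptions:

```python
# Hedged sketch of the exception-to-payload mapping; ConcurrencyException and
# the numeric codes below are placeholders for the project's ExceptionType constants.
import traceback
from typing import Any, Dict

class ConcurrencyException(Exception):
    pass

CONCURRENCY_LIMIT_ERROR = 409
UNKNOWN_ERROR = 500

def format_exception(e: Exception) -> Dict[str, Any]:
    def _format(errcode: int, message: str) -> Dict[str, Any]:
        return {'error_code': errcode, "message": message}

    if isinstance(e, ConcurrencyException):
        return _format(CONCURRENCY_LIMIT_ERROR, str(e))
    # unknown errors keep the traceback so the JSON error body stays debuggable
    return _format(UNKNOWN_ERROR, f'ErrorMsg: {e} \n Traceback: {traceback.format_exc()}')

try:
    raise ConcurrencyException("too many concurrent requests")
except Exception as exc:
    print(format_exception(exc))
# -> {'error_code': 409, 'message': 'too many concurrent requests'}
```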
@@ -146,7 +146,7 @@ def update(self, version_info: VersionInfo):
             self._access_logger.log_exception_access(version_info.__dict__, e, id)
             kmonitor.report(AccMetrics.ERROR_UPDATE_QPS_METRIC, 1)
             error_code = 500
-            rep = JSONResponse(self.handler_exceptions(e), status_code=error_code)
+            rep = JSONResponse(self.format_exception(e), status_code=error_code)
         return rep

     async def inference(self, req: Union[str,Dict[Any, Any]], raw_request: RawRequest):
@@ -165,17 +165,7 @@ async def _infer_wrap(self, req: Dict[Any, Any], raw_request: RawRequest, genera
         try:
            rep = await self._infer_impl(req, id, raw_request, generate_call)
         except Exception as e:
-            self._access_logger.log_exception_access(req, e, id)
-            if isinstance(e, ConcurrencyException):
-                kmonitor.report(AccMetrics.CONFLICT_QPS_METRIC)
-                error_code = 409
-            elif isinstance(e, asyncio.CancelledError):
-                kmonitor.report(AccMetrics.CANCAL_QPS_METRIC, 1)
-                error_code = 499
-            else:
-                error_code = 500
-                kmonitor.report(AccMetrics.ERROR_QPS_METRIC, 1)
-            rep = JSONResponse(self.handler_exceptions(e), status_code=error_code)
+            rep = self._handle_exception(req, e, id)
         return rep

     async def chat_completion(self, request: ChatCompletionRequest, raw_request: Request):
@@ -187,26 +177,45 @@ def generate_call():
         return await self._infer_wrap(request.model_dump(), raw_request, generate_call)

     async def embedding(self, request: Union[Dict[str, Any], str, OpenAIEmbeddingRequest], raw_request: Request):
+        id = self._atomic_count.increment()
+        kmonitor.report(AccMetrics.QPS_METRIC, 1)
         with self._controller:
             try:
                 assert self._embedding_endpoint is not None, "embedding pipeline should not be None"
                 result = await self._embedding_endpoint.embedding(request)
+                log_result = copy.copy(result)
+                # do not log result since too big
+                log_result.data = []
+                self._access_logger.log_success_access(request, log_result, id)
                 return JSONResponse(result.model_dump(exclude_none=True))
-            except FtRuntimeException:
-                raise
             except Exception as e:
-                raise FtRuntimeException(ExceptionType.UNKNOWN_ERROR, str(e))
-
+                self._handle_exception(request, e, id)
+
     async def similarity(self, request: Union[Dict[str, Any], str, SimilarityRequest], raw_request: Request):
+        id = self._atomic_count.increment()
+        kmonitor.report(AccMetrics.QPS_METRIC, 1)
         with self._controller:
             try:
                 assert self._embedding_endpoint is not None, "embedding pipeline should not be None"
                 result = await self._embedding_endpoint.similarity(request)
+                self._access_logger.log_success_access(request, result.model_dump(exclude_none=True), id)
                 return JSONResponse(result.model_dump(exclude_none=True))
-            except FtRuntimeException:
-                raise
             except Exception as e:
-                raise FtRuntimeException(ExceptionType.UNKNOWN_ERROR, str(e))
+                self._handle_exception(request, e, id)
+
+    def _handle_exception(self, request: Union[Dict[str, Any], str, BaseModel], e: Exception, id: int):
+        self._access_logger.log_exception_access(request, e, id)
+        if isinstance(e, ConcurrencyException):
+            kmonitor.report(AccMetrics.CONFLICT_QPS_METRIC)
+            error_code = 409
+        elif isinstance(e, asyncio.CancelledError):
+            kmonitor.report(AccMetrics.CANCAL_QPS_METRIC, 1)
+            error_code = 499
+        else:
+            error_code = 500
+            kmonitor.report(AccMetrics.ERROR_QPS_METRIC, 1)
+        rep = JSONResponse(self.format_exception(e), status_code=error_code)
+        return rep

     async def _call_generate_with_report(self, generate_call: Callable[[], CompleteResponseAsyncGenerator]):
         async def __gen_response_with_report(start_time, response_generator):
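
Both embedding routes now count a QPS metric up front, log a success access record with the embedding payload stripped (since it can be large), and funnel failures through the shared `_handle_exception`, which picks the HTTP status and error metric. A rough sketch of that flow with toy stand-ins for kmonitor, the access logger, and the response object (the real handler returns FastAPI's JSONResponse):

```python
# Rough sketch of the new embedding-route flow: count QPS, log the request with
# the (potentially huge) embedding data stripped, and map failures to a status
# code plus error payload. Names below are stand-ins, not the project's API.
import asyncio
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

class ConcurrencyException(Exception):
    pass

def report_metric(name: str) -> None:
    # stand-in for kmonitor.report(AccMetrics...)
    print(f"metric: {name}")

@dataclass
class EmbeddingResult:
    # stand-in for the embedding endpoint's response object
    data: List[List[float]] = field(default_factory=list)
    usage: Dict[str, int] = field(default_factory=dict)

def handle_exception(e: Exception) -> Tuple[int, Dict[str, Any]]:
    # mirrors _handle_exception: 409 for concurrency limits, 499 for client
    # cancellation, 500 for everything else, plus the matching error metric
    if isinstance(e, ConcurrencyException):
        report_metric("conflict_qps")
        status = 409
    elif isinstance(e, asyncio.CancelledError):
        report_metric("cancel_qps")
        status = 499
    else:
        report_metric("error_qps")
        status = 500
    return status, {'error_code': status, "message": str(e)}

def embed(texts: List[str]) -> Tuple[int, Dict[str, Any]]:
    report_metric("qps")  # counted before the request is processed
    try:
        result = EmbeddingResult(data=[[0.1, 0.2] for _ in texts],
                                 usage={"prompt_tokens": len(texts)})
        log_result = copy.copy(result)
        log_result.data = []      # do not log result since too big
        print(f"access log: {log_result}")
        return 200, {"data": result.data, "usage": result.usage}
    except Exception as e:
        return handle_exception(e)

print(embed(["hello", "world"]))
```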
@@ -283,4 +292,4 @@ def tokenizer_encode(self, req: Union[str,Dict[Any, Any]]):
             response = TokenizerEncodeResponse(token_ids=token_ids, tokens=tokens)
             return JSONResponse(content=response.model_dump(exclude_none=True))
         except Exception as e:
-            return JSONResponse(self.handler_exceptions(e), status_code=500)
+            return JSONResponse(self.format_exception(e), status_code=500)

maga_transformer/server/inference_worker.py

Lines changed: 4 additions & 4 deletions
@@ -18,7 +18,7 @@
 from maga_transformer.config.exceptions import FtRuntimeException, ExceptionType
 from maga_transformer.models.base_model import GenerateResponse, GenerateConfig
 from maga_transformer.model_factory import ModelFactory, AsyncModel
-from maga_transformer.structure.request_extractor import RequestExtractor
+from maga_transformer.structure.request_extractor import RequestExtractor, Request

 from pydantic import BaseModel

@@ -52,7 +52,7 @@ def __init__(self) -> None:
         if not torch.cuda.is_available():
             raise Exception("GPU not found")

-        self.model = ModelFactory.create_from_env()
+        self.model: AsyncModel = ModelFactory.create_from_env()
         self.pipeline = Pipeline(self.model, self.model.tokenizer)
         logging.info("Load model done.")

@@ -78,7 +78,7 @@ def inference(self, **kwargs: Any) -> CompleteResponseAsyncGenerator:
         return CompleteResponseAsyncGenerator(response_generator, complete_response_collect_func)


-    def _inference(self, request, **kwargs):
+    def _inference(self, request: Request, **kwargs: Any):
         if len(request.input_texts) > 1 or request.batch_infer or request.num_return_sequences > 0:
             num_return_sequences = request.generate_configs[0].num_return_sequences
             generators = [self._yield_generate(text, images, generate_config=generate_config, **kwargs)
@@ -127,7 +127,7 @@ def is_streaming(self, req: Dict[str, Any]):
         return RequestExtractor.is_streaming(req) or req.get('stream', False)

     def update(self, version_info: VersionInfo):
-        lora_infos = dict()
+        lora_infos: Dict[str, Any] = dict()
         if version_info.peft_info != None:
             lora_infos = version_info.peft_info.get("lora_info", {})
         return self.model.update(lora_infos)
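
These are typing-only changes: the worker's model, the `_inference` request parameter, and `lora_infos` gain explicit annotations so static checkers can follow the update path. A small illustration of the `lora_infos` case, using a hypothetical stand-in for VersionInfo:

```python
# Hypothetical stand-in showing the lora_infos annotation; the explicit
# Dict[str, Any] documents the expected shape and keeps strict type checkers
# from complaining about the bare dict() before the reassignment below.
from typing import Any, Dict, Optional
from pydantic import BaseModel

class VersionInfo(BaseModel):  # stand-in for the real VersionInfo model
    peft_info: Optional[Dict[str, Any]] = None

def extract_lora_infos(version_info: VersionInfo) -> Dict[str, Any]:
    lora_infos: Dict[str, Any] = dict()
    if version_info.peft_info is not None:
        lora_infos = version_info.peft_info.get("lora_info", {})
    return lora_infos

print(extract_lora_infos(VersionInfo(peft_info={"lora_info": {"my_lora": "/path/to/lora"}})))
# -> {'my_lora': '/path/to/lora'}
```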
