diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index d08af43401..76c9d55120 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -125,6 +125,7 @@ def add_parser_api_server():
                             'engine’s tasks once the maximum number of concurrent requests is '
                             'reached, regardless of any additional requests sent by clients '
                             'concurrently during that time. Default to None.')
+        parser.add_argument('--metrics', action='store_true', default=False, help='Whether to log stats to Prometheus')
         # common args
         ArgumentHelper.backend(parser)
         ArgumentHelper.log_level(parser)
@@ -337,6 +338,7 @@ def api_server(args):
                        proxy_url=args.proxy_url,
                        max_log_len=args.max_log_len,
                        disable_fastapi_docs=args.disable_fastapi_docs,
+                       metrics=args.metrics,
                        max_concurrent_requests=args.max_concurrent_requests,
                        reasoning_parser=args.reasoning_parser,
                        tool_call_parser=args.tool_call_parser)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index ae8f9eb58f..a860787b8e 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -22,6 +22,7 @@
 from lmdeploy.logger import RequestLogger
 from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig
 from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model
+from lmdeploy.serve.metrics import IterTimer, Metrics
 from lmdeploy.serve.utils import LogitsMixin
 from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger
@@ -89,6 +90,11 @@ def _append_response(dst: Response, src: Response):
     return dst
 
 
+def is_error(status: ResponseType):
+    """Whether the response status is an error."""
+    return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
+
+
 class Session:
     """Session for AsyncEngine.chat.
 
@@ -246,6 +252,7 @@ class AsyncEngine(LogitsMixin):
             Default to None.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
+        metrics (bool): Whether to log stats to Prometheus.
""" def __init__(self, @@ -255,6 +262,7 @@ def __init__(self, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, max_log_len: int = None, + metrics: bool = False, **kwargs) -> None: logger.info(f'input backend={backend}, backend_config={backend_config}') logger.info(f'input chat_template_config={chat_template_config}') @@ -297,6 +305,8 @@ def __init__(self, self.request_logger = RequestLogger(max_log_len) self.internal_thread = _EventLoopThread(daemon=True) self.limiter: asyncio.Semaphore = None + self.metrics = Metrics(metrics) + self.metrics.info(self.backend_config) def close(self): self.internal_thread.close() @@ -585,6 +595,10 @@ async def model_inst(self, session_id: int): inst._active.set() free_insts.put_nowait(inst) + async def handle_exception(self, session_id: int): + self.metrics.failure_frame() + await self.stop_session(session_id) + @asynccontextmanager async def safe_run(self, inst, session_id, **kwargs): generator = inst.async_stream_infer(session_id, **kwargs) @@ -593,7 +607,7 @@ async def safe_run(self, inst, session_id, **kwargs): except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa logger.error(f'[safe_run] exception caught: {type(e).__name__} {e}') # TODO: remove session_id from async cancel - await inst.async_cancel(session_id) + await self.handle_exception(session_id) finally: await generator.aclose() @@ -629,61 +643,70 @@ async def generate( """ if (messages is not None) ^ (input_ids is None): raise ValueError('You must specify exactly one of messages or input_ids') - if session_id not in self.id2step: - self.id2step[session_id] = 0 - if step != 0: - self.id2step[session_id] = step - if gen_config is None: - gen_config = GenerationConfig() - else: - gen_config = deepcopy(gen_config) - gen_config.convert_stop_bad_words_to_ids(self.tokenizer) - if gen_config.stop_token_ids is None: - gen_config.stop_token_ids = self.stop_words - gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id) - if not gen_config.do_sample: - logger.warning(f'GenerationConfig: {gen_config}') - logger.warning('Since v0.6.0, lmdeploy add `do_sample` in ' - 'GenerationConfig. It defaults to False, meaning greedy ' - 'decoding. Please set `do_sample=True` if sampling ' - ' decoding is needed') - # greedy decode - gen_config.top_k = 1 - # avoid unnecessary process - gen_config.temperature = 1.0 - gen_config.repetition_penalty = 1.0 - # set random if it is not set and sequence_start is True - elif gen_config.random_seed is None and sequence_start: - gen_config.random_seed = random.getrandbits(64) - if gen_config.n > 1: - logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. 
" - f'Fallback to 1') - gen_config.n = 1 - if messages: - prompt = messages - self.request_logger.log_prompt(session_id=session_id, prompt=prompt) - prompt_input = await self._get_prompt_input(prompt, - do_preprocess, - sequence_start, - adapter_name, - tools=tools) - prompt = prompt_input['prompt'] - input_ids = prompt_input['input_ids'] - self.request_logger.log_inputs(session_id=session_id, - prompt=prompt, - prompt_token_ids=input_ids, - gen_config=gen_config, - adapter_name=adapter_name) - logger.info(f'session={session_id}, ' - f'history_tokens={self.id2step[session_id]}, ' - f'input_tokens={len(input_ids)}, ' - f'max_new_tokens={gen_config.max_new_tokens}, ' - f'seq_start={sequence_start}, seq_end={sequence_end}, ' - f'step={step}, prep={do_preprocess}') - else: - # TODO(lvhan) VLM doesn't support input_ids as an argument. - # Figure out a graceful way to handle the invalid input - prompt_input = dict(input_ids=input_ids) + + async def get_inputs_genconfig(gen_config): + if session_id not in self.id2step: + self.id2step[session_id] = 0 + if step != 0: + self.id2step[session_id] = step + if gen_config is None: + gen_config = GenerationConfig() + else: + gen_config = deepcopy(gen_config) + gen_config.convert_stop_bad_words_to_ids(self.tokenizer) + if gen_config.stop_token_ids is None: + gen_config.stop_token_ids = self.stop_words + gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id) + if not gen_config.do_sample: + logger.warning(f'GenerationConfig: {gen_config}') + logger.warning('Since v0.6.0, lmdeploy add `do_sample` in ' + 'GenerationConfig. It defaults to False, meaning greedy ' + 'decoding. Please set `do_sample=True` if sampling ' + ' decoding is needed') + # greedy decode + gen_config.top_k = 1 + # avoid unnecessary process + gen_config.temperature = 1.0 + gen_config.repetition_penalty = 1.0 + # set random if it is not set and sequence_start is True + elif gen_config.random_seed is None and sequence_start: + gen_config.random_seed = random.getrandbits(64) + if gen_config.n > 1: + logger.error(f"n({gen_config.n}) > 1 hasn't been supported yet. " + f'Fallback to 1') + gen_config.n = 1 + if messages: + prompt = messages + self.request_logger.log_prompt(session_id=session_id, prompt=prompt) + prompt_input = await self._get_prompt_input(prompt, + do_preprocess, + sequence_start, + adapter_name, + tools=tools) + prompt = prompt_input['prompt'] + input_ids = prompt_input['input_ids'] + self.request_logger.log_inputs(session_id=session_id, + prompt=prompt, + prompt_token_ids=input_ids, + gen_config=gen_config, + adapter_name=adapter_name) + logger.info(f'session={session_id}, ' + f'history_tokens={self.id2step[session_id]}, ' + f'input_tokens={len(input_ids)}, ' + f'max_new_tokens={gen_config.max_new_tokens}, ' + f'seq_start={sequence_start}, seq_end={sequence_end}, ' + f'step={step}, prep={do_preprocess}') + else: + # TODO(lvhan) VLM doesn't support input_ids as an argument. 
+                # Figure out a graceful way to handle the invalid input
+                prompt_input = dict(input_ids=input_ids)
+            return prompt_input, gen_config
+
+        arrival_frame = self.metrics.insert_frame()
+        prompt_input, gen_config = await get_inputs_genconfig(gen_config)
+        input_ids = prompt_input['input_ids']
+        self.metrics.update_preprocess(arrival_frame)
+
         if gen_config.max_new_tokens is None:
             # for interactive endpoint, will try maximum possible token num
             gen_config.max_new_tokens = max(128, self.session_len - self.id2step[session_id] - len(input_ids))
@@ -697,15 +720,14 @@ async def generate(
             await self.end_session(session_id)
             return
 
-        def is_error(status):
-            return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
-
         # used to skip / rewind stop words in interactive mode
         stop_ids = []
         if skip_stop_tokens and not gen_config.ignore_eos:
             stop_ids = gen_config.stop_token_ids or []
 
+        start_frame = self.metrics.insert_frame()
         async with self.model_inst(session_id) as inst:
+            self.metrics.update_queue_waiting(start_frame)
             token_ids = input_ids.copy()
             history_len = self.id2step[session_id]
             input_len = len(input_ids)
@@ -723,9 +745,12 @@ def is_error(status):
                                      sequence_start=sequence_start,
                                      sequence_end=sequence_end,
                                      step=history_len) as gen:
+                if self.metrics.applied is True:
+                    gen = IterTimer(gen)
                 prev_len = 0
                 hit_stop_token = 0
                 async for outputs in gen:
+                    is_first_token = prev_len == 0
                     # decode res
                     if is_error(outputs.status):
                         break
@@ -771,7 +796,8 @@ def is_error(status):
                         out.logits = outputs.logits
                         if hit_stop_token:
                             out.logits = out.logits[:-hit_stop_token]
-
+                    if is_first_token:
+                        self.metrics.update_FTL(arrival_frame)
                     yield out
                 # end of generator loop
 
@@ -807,6 +833,7 @@ def is_error(status):
                     # rewind the step to the token before the stop token
                     output_len = gen_len
             self.id2step[session_id] += input_len + output_len
+            self.metrics.last_token_frame(gen)
 
     def _run(self, fn=None, coro=None, loop=None):
         assert (fn or coro) and not (fn and coro)
diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py
new file mode 100644
index 0000000000..2c2e5c6ae9
--- /dev/null
+++ b/lmdeploy/serve/metrics.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dataclasses
+import threading
+import time
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import psutil
+import pynvml
+from prometheus_client import REGISTRY, Counter, Gauge, Histogram, Info, disable_created_metrics
+
+disable_created_metrics()
+
+
+class IterTimer:
+    """A timer that accumulates the time spent in iteration."""
+
+    def __init__(self, iterable):
+        self._iterable = iterable
+        self._duration = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        start = time.perf_counter()
+        item = next(iter(self._iterable))
+        self._duration += (time.perf_counter() - start)
+        return item
+
+    def get_duration(self):
+        """Get the total duration of iteration.
+
+        Also known as the model forwarding latency.
+        """
+        return self._duration
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        start = time.perf_counter()
+        item = await self._iterable.__anext__()
+        self._duration += (time.perf_counter() - start)
+        return item
+
+
+@dataclass
+class Stats:
+    """Snapshot of the system status."""
+    # system status
+    cpu_utilization: Optional[float] = None
+    cpu_memory_used_bytes: Optional[float] = None
+    gpu_utilization: Optional[Dict] = None
+    gpu_memory_used_bytes: Optional[Dict] = None
+
+    def refresh(self):
+        """Refresh the system status."""
+        p = psutil.Process()
+        self.cpu_utilization = psutil.cpu_percent()
+        self.cpu_memory_used_bytes = p.memory_info().rss
+        pynvml.nvmlInit()
+        self.gpu_memory_used_bytes = {}
+        self.gpu_utilization = {}
+        for i in range(pynvml.nvmlDeviceGetCount()):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(int(i))
+            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
+            self.gpu_memory_used_bytes[str(i)] = str(mem_info.used)
+            self.gpu_utilization[str(i)] = str(utilization.gpu)
+
+
+def refresh_system(metrics):
+    """A long-lived thread function that keeps collecting hardware information."""
+    while True:
+        time.sleep(1)
+        # Log to prometheus.
+        stats = metrics.stats
+        stats.refresh()
+        # Info gpu stats
+        metrics.info_gpu_utilization.info(stats.gpu_utilization)
+        metrics.info_gpu_memory_used_bytes.info(stats.gpu_memory_used_bytes)
+        # Set system stat gauges.
+        metrics.gauge_cpu_utilization.set(stats.cpu_utilization)
+        metrics.gauge_cpu_memory_used_bytes.set(stats.cpu_memory_used_bytes)
+
+
+class Metrics:
+    """The metrics for serving."""
+
+    def __init__(self, applied: bool = False, labelnames: Optional[List[str]] = []):
+        self.applied = applied
+        # Unregister any existing lmdeploy collectors
+        for collector in list(REGISTRY._collector_to_names):
+            if hasattr(collector, '_name') and 'lmdeploy' in collector._name:
+                REGISTRY.unregister(collector)
+
+        # Config Information
+        self.info_backend_config = Info(name='lmdeploy:backend_config', documentation='information of backend_config')
+
+        # System stats
+        self.info_gpu_utilization = Info(name='lmdeploy:gpu_utilization',
+                                         documentation='GPU utilization in percent (0-100).')
+        self.info_gpu_memory_used_bytes = Info(name='lmdeploy:gpu_memory_used_bytes',
+                                               documentation='GPU memory used bytes.')
+        self.gauge_cpu_utilization = Gauge(name='lmdeploy:cpu_utilization',
+                                           documentation='CPU utilization in percent (0-100).',
+                                           labelnames=labelnames)
+        self.gauge_cpu_memory_used_bytes = Gauge(name='lmdeploy:cpu_memory_used_bytes',
+                                                 documentation='CPU memory used bytes.',
+                                                 labelnames=labelnames)
+
+        # requests
+        self.counter_request_success = Counter(name='lmdeploy:request_success',
+                                               documentation='Number of successful requests.',
+                                               labelnames=labelnames)
+        self.counter_request_failure = Counter(name='lmdeploy:request_failure',
+                                               documentation='Number of failed requests.',
+                                               labelnames=labelnames)
+        self.counter_request_total = Counter(name='lmdeploy:request_total',
+                                             documentation='Number of total requests.',
+                                             labelnames=labelnames)
+
+        # latency metrics
+        self.histogram_duration_queue = Histogram(
+            name='lmdeploy:duration_queue_seconds',
+            documentation=  # noqa
+            'Average duration of requests waiting in the queue in s.',
+            labelnames=labelnames,
+        )
+        self.histogram_duration_infer = Histogram(
+            name='lmdeploy:duration_infer_seconds',
+            documentation='Average inference time in s.',
+            labelnames=labelnames,
+        )
+        self.histogram_duration_preprocess = Histogram(
+            name='lmdeploy:duration_preprocess_seconds',
+            documentation='Average duration of processing inputs in s.',
+            labelnames=labelnames,
+        )
+        self.histogram_first_token_latency = Histogram(
+            name='lmdeploy:first_token_latency_seconds',
+            documentation='Average first token latency in s.',
+            labelnames=labelnames,
+        )
+        self.stats = Stats()
+        self.refresh_thread = threading.Thread(target=refresh_system, args=(self, ), daemon=True)
+        self.refresh_thread.start()
+
+    def info(self, backend_config: object) -> None:
+        if self.applied:
+            config_dict = {key: str(value) for key, value in dataclasses.asdict(backend_config).items()}
+            self.info_backend_config.info(config_dict)
+
+    def failure_frame(self):
+        """Log a failure frame."""
+        if self.applied:
+            self.counter_request_failure.inc()
+            self.counter_request_total.inc()
+
+    def last_token_frame(self, iterator):
+        """Log the last token frame."""
+        if self.applied:
+            self.histogram_duration_infer.observe(iterator.get_duration())
+            self.counter_request_success.inc()
+            self.counter_request_total.inc()
+
+    def insert_frame(self):
+        """Insert a frame."""
+        if self.applied:
+            return time.time()
+        return None
+
+    def update_preprocess(self, start_frame):
+        """Update preprocess duration."""
+        if self.applied:
+            self.histogram_duration_preprocess.observe(time.time() - start_frame)
+
+    def update_queue_waiting(self, start_frame):
+        """Update queue waiting time."""
+        if self.applied:
+            self.histogram_duration_queue.observe(time.time() - start_frame)
+
+    def update_FTL(self, start_frame):
+        """Update first token latency."""
+        if self.applied:
+            self.histogram_first_token_latency.observe(time.time() - start_frame)
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 90965fb8b4..09b155028e 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -13,6 +13,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+from prometheus_client import make_asgi_app
 from starlette.middleware.base import BaseHTTPMiddleware
 
 from lmdeploy.archs import get_task
@@ -500,7 +501,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await VariableInterface.async_engine.stop_session(request.session_id)
+            await VariableInterface.async_engine.handle_exception(request.session_id)
             return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
         final_res = res
         text += res.response
@@ -715,7 +716,7 @@ async def _inner_call(i, generator):
         async for res in generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await VariableInterface.async_engine.stop_session(request.session_id)
+                await VariableInterface.async_engine.handle_exception(request.session_id)
                 return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
             final_res = res
             text += res.response
@@ -842,7 +843,7 @@ async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = N
     """
     if request.cancel:
         if request.session_id != -1:
-            await VariableInterface.async_engine.stop_session(request.session_id)
+            await VariableInterface.async_engine.handle_exception(request.session_id)
             return {'text': '', 'tokens': 0, 'input_tokens': 0, 'history_tokens': 0, 'finish_reason': 'stop'}
         else:
             return create_error_response(HTTPStatus.BAD_REQUEST, 'please set a session_id to cancel a request')
@@ -918,7 +919,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         async for out in generation:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await async_engine.stop_session(request.session_id)
+                await async_engine.handle_exception(request.session_id)
                 return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
             text += out.response
             tokens = out.generate_token_len
@@ -1017,6 +1018,7 @@ def serve(model_path: str,
           proxy_url: Optional[str] = None,
           max_log_len: int = None,
           disable_fastapi_docs: bool = False,
+          metrics: bool = False,
           max_concurrent_requests: Optional[int] = None,
           reasoning_parser: Optional[str] = None,
           tool_call_parser: Optional[str] = None,
@@ -1064,6 +1066,7 @@ def serve(model_path: str,
         proxy_url (str): The proxy url to register the api_server.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
+        metrics (bool): Whether to log stats to Prometheus.
         max_concurrent_requests: This refers to the number of concurrent
             requests that the server can handle. The server is designed to
             process the engine’s tasks once the maximum number of concurrent
@@ -1087,6 +1090,11 @@ def serve(model_path: str,
 
     app.include_router(router)
 
+    if metrics is True:
+        # Add prometheus asgi middleware to route /metrics requests
+        metrics_app = make_asgi_app()
+        app.mount('/metrics', metrics_app)
+
     if allow_origins:
         app.add_middleware(
             CORSMiddleware,
@@ -1118,6 +1126,7 @@ def serve(model_path: str,
                                                 backend_config=backend_config,
                                                 chat_template_config=chat_template_config,
                                                 max_log_len=max_log_len,
+                                                metrics=metrics,
                                                 **kwargs)
     # set reasoning parser and tool parser
     set_parsers(reasoning_parser, tool_call_parser)
diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
index 9169c6ff30..d4125dd209 100644
--- a/requirements/runtime_cuda.txt
+++ b/requirements/runtime_cuda.txt
@@ -8,6 +8,7 @@ openai
 outlines
 peft<=0.14.0
 pillow
+prometheus_client >= 0.18.0
 protobuf
 pydantic>2.0.0
 pynvml
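Reviewer note: a quick way to smoke-test this patch is to start the server with the new flag (e.g. `lmdeploy serve api_server <model> --metrics`) and scrape the mounted endpoint. The minimal sketch below assumes the server is reachable at lmdeploy's default 127.0.0.1:23333 and uses `requests`; the address and the client library are assumptions on my part, only the `/metrics` mount and the `lmdeploy:*` metric names come from this diff.

# Hypothetical smoke test; adjust the URL to wherever the api_server listens.
import requests
from prometheus_client.parser import text_string_to_metric_families

resp = requests.get('http://127.0.0.1:23333/metrics', timeout=5)
resp.raise_for_status()

# Print only the collectors registered in lmdeploy/serve/metrics.py,
# e.g. lmdeploy:request_total, lmdeploy:first_token_latency_seconds.
for family in text_string_to_metric_families(resp.text):
    if family.name.startswith('lmdeploy:'):
        for sample in family.samples:
            print(sample.name, sample.labels, sample.value)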