From 1347fa7c78cccc3985b3cf0c6afe2485d1da7fa4 Mon Sep 17 00:00:00 2001
From: AllentDan
Date: Fri, 15 Nov 2024 15:05:16 +0800
Subject: [PATCH 1/7] Add metrics monitor

---
 lmdeploy/cli/serve.py               |   5 +
 lmdeploy/serve/async_engine.py      | 139 ++++++++++++++++++----------
 lmdeploy/serve/openai/api_server.py |  16 +++-
 requirements/runtime.txt            |   1 +
 4 files changed, 109 insertions(+), 52 deletions(-)

diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index 68f9de8c15..a745ab45bd 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -139,6 +139,10 @@ def add_parser_api_server():
                             type=str,
                             default=None,
                             help='The proxy url for api server.')
+        parser.add_argument('--metrics',
+                            action='store_true',
+                            default=False,
+                            help='Whether to log stats to Prometheus')
         # common args
         ArgumentHelper.backend(parser)
         ArgumentHelper.log_level(parser)
@@ -352,6 +356,7 @@ def api_server(args):
                            ssl=args.ssl,
                            proxy_url=args.proxy_url,
                            max_log_len=args.max_log_len,
+                           metrics=args.metrics,
                            disable_fastapi_docs=args.disable_fastapi_docs)
 
     @staticmethod
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 3c8f193cd5..64e8ee7a44 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -4,6 +4,7 @@
 import json
 import os
 import random
+import time
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from itertools import count
@@ -15,6 +16,7 @@
 from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response,
                                ResponseType, TurbomindEngineConfig)
 from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
+from lmdeploy.serve.metrics import IterTimer, Metrics, Stats
 from lmdeploy.serve.utils import LogitsMixin, _get_event_loop
 from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger
@@ -128,6 +130,7 @@ class AsyncEngine(LogitsMixin):
             Default to None.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
+        metrics (bool): Whether to log the stats to Prometheus.
""" def __init__(self, @@ -138,6 +141,7 @@ def __init__(self, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, max_log_len: int = None, + metrics: bool = False, **kwargs) -> None: logger.info( f'input backend={backend}, backend_config={backend_config}') @@ -186,6 +190,12 @@ def __init__(self, self._session_id = count(0) self.request_logger = RequestLogger(max_log_len) + self.apply_metrics = metrics + if self.apply_metrics: + self.stats = Stats(now=time.time()) + self.metrics = Metrics() + self.metrics.info(self.backend_config) + def _build_turbomind( self, model_path: str, @@ -242,6 +252,12 @@ def __call__(self, use_tqdm=use_tqdm, **kwargs) + async def handle_exception(self, session_id: int): + if self.metrics: + self.stats.request_failure += 1 + self.stats.request_total += 1 + await self.stop_session(session_id) + async def stop_session(self, session_id: int): """Stop a session by a session_id.""" if str(session_id) in self.id2generator: @@ -266,7 +282,7 @@ async def safe_run(self, session_id: Optional[int] = None): yield except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa # TODO: find out why await would block the coroutine here - _get_event_loop().create_task(self.stop_session(session_id)) + _get_event_loop().create_task(self.handle_exception(session_id)) raise e if str(session_id) in self.id2generator: self.gens_set.add(self.id2generator[str(session_id)]) @@ -274,6 +290,8 @@ async def safe_run(self, session_id: Optional[int] = None): async def get_generator(self, stop: bool, session_id: int): """Only return the model instance if it is available.""" + if self.apply_metrics: + start = time.time() if stop: return self.engine.create_instance() # waiting no generator is available or the same session_id is running @@ -282,6 +300,8 @@ async def get_generator(self, stop: bool, session_id: int): generator = self.gens_set.pop() self.id2generator[str(session_id)] = generator self.running_session_ids.add(session_id) + if self.apply_metrics: + self.stats.duration_queue += time.time() - start return generator def batch_infer(self, @@ -489,43 +509,54 @@ async def generate( do_preprocess (bool): whether pre-process the messages. Default to True, which means chat_template will be applied. """ - if str(session_id) not in self.id2step: - self.id2step[str(session_id)] = 0 - if step != 0: - self.id2step[str(session_id)] = step - if gen_config is None: - gen_config = GenerationConfig() - else: - gen_config = deepcopy(gen_config) - gen_config.convert_stop_bad_words_to_ids(self.tokenizer) - if gen_config.stop_token_ids is None: - gen_config.stop_token_ids = self.stop_words - if not gen_config.do_sample: - logger.warn(f'GenerationConfig: {gen_config}') - logger.warn( - 'Since v0.6.0, lmdeploy add `do_sample` in ' - 'GenerationConfig. It defaults to False, meaning greedy ' - 'decoding. Please set `do_sample=True` if sampling ' - ' decoding is needed') - # greedy decode - gen_config.top_k = 1 - # avoid unnecessary process - gen_config.temperature = 1.0 - gen_config.repetition_penalty = 1.0 - # set random if it is not set and sequence_start is True - elif gen_config.random_seed is None and sequence_start: - gen_config.random_seed = random.getrandbits(64) - if gen_config.n > 1: - logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. 
" - f'Fallback to 1') - gen_config.n = 1 - prompt = messages - self.request_logger.log_prompt(session_id=session_id, prompt=prompt) - prompt_input = await self._get_prompt_input(prompt, - do_preprocess, - sequence_start, - adapter_name, - tools=tools) + + async def get_inputs_genconfig(gen_config): + if self.apply_metrics: + start = time.time() + if str(session_id) not in self.id2step: + self.id2step[str(session_id)] = 0 + if step != 0: + self.id2step[str(session_id)] = step + if gen_config is None: + gen_config = GenerationConfig() + else: + gen_config = deepcopy(gen_config) + gen_config.convert_stop_bad_words_to_ids(self.tokenizer) + if gen_config.stop_token_ids is None: + gen_config.stop_token_ids = self.stop_words + if not gen_config.do_sample: + logger.warn(f'GenerationConfig: {gen_config}') + logger.warn( + 'Since v0.6.0, lmdeploy add `do_sample` in ' + 'GenerationConfig. It defaults to False, meaning greedy ' + 'decoding. Please set `do_sample=True` if sampling ' + ' decoding is needed') + # greedy decode + gen_config.top_k = 1 + # avoid unnecessary process + gen_config.temperature = 1.0 + gen_config.repetition_penalty = 1.0 + # set random if it is not set and sequence_start is True + elif gen_config.random_seed is None and sequence_start: + gen_config.random_seed = random.getrandbits(64) + if gen_config.n > 1: + logger.error( + f"n({gen_config.n}) > 1 hasn't been supported yet. " + f'Fallback to 1') + gen_config.n = 1 + prompt = messages + self.request_logger.log_prompt(session_id=session_id, + prompt=prompt) + prompt_input = await self._get_prompt_input(prompt, + do_preprocess, + sequence_start, + adapter_name, + tools=tools) + if self.apply_metrics: + self.stats.duration_preprocess += time.time() - start + return prompt_input, gen_config + + prompt_input, gen_config = await get_inputs_genconfig(gen_config) prompt = prompt_input['prompt'] input_ids = prompt_input['input_ids'] finish_reason = None @@ -568,19 +599,24 @@ def is_error(status): ] generator = await self.get_generator(False, session_id) + iterator = generator.async_stream_infer( + session_id=session_id, + **prompt_input, + gen_config=gen_config, + adapter_name=adapter_name, + stream_output=stream_response, + sequence_start=sequence_start, + sequence_end=sequence_end, + step=self.id2step[str(session_id)]) + if self.apply_metrics: + iterator = IterTimer(iterator) async with self.safe_run(session_id): state = DetokenizeState(len(input_ids)) start_ids_offset = state.ids_offset response = '' - async for outputs in generator.async_stream_infer( - session_id=session_id, - **prompt_input, - gen_config=gen_config, - adapter_name=adapter_name, - stream_output=stream_response, - sequence_start=sequence_start, - sequence_end=sequence_end, - step=self.id2step[str(session_id)]): + async for outputs in iterator: + if self.apply_metrics: + start = time.perf_counter() # decode res if is_error(outputs.status): tokens = 0 @@ -600,7 +636,9 @@ def is_error(status): if outputs.logprobs: log_offset = ids_offset - start_ids_offset logprobs = outputs.logprobs[log_offset:] - + if self.apply_metrics: + self.stats.duration_postprocess += time.perf_counter( + ) - start # response, history token len, # input token len, gen token len yield GenOut(response, self.id2step[str(session_id)], @@ -632,6 +670,11 @@ def is_error(status): # TODO modify pytorch or turbomind api if self.backend == 'pytorch' and sequence_end: await self.end_session(session_id) + if self.apply_metrics: + self.stats.duration_infer += iterator.get_duration() + 
self.stats.request_success += 1 + self.stats.request_total += 1 + self.metrics.log(self.stats) def parse_tool_response(self, text, tools, **kwargs): """Parse model response containing tool information. diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index a12cadaa7d..363f18aec0 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from prometheus_client import make_asgi_app from lmdeploy.archs import get_task from lmdeploy.messages import (GenerationConfig, LogitsProcessor, @@ -483,7 +484,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await VariableInterface.async_engine.stop_session( + await VariableInterface.async_engine.handle_exception( request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') @@ -711,7 +712,7 @@ async def _inner_call(i, generator): async for res in generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await VariableInterface.async_engine.stop_session( + await VariableInterface.async_engine.handle_exception( request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') @@ -843,7 +844,7 @@ async def chat_interactive_v1(request: GenerateRequest, """ if request.cancel: if request.session_id != -1: - await VariableInterface.async_engine.stop_session( + await VariableInterface.async_engine.handle_exception( request.session_id) return { 'text': '', @@ -927,7 +928,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for out in generation: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await async_engine.stop_session(request.session_id) + await async_engine.handle_exception(request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') text += out.response @@ -990,6 +991,7 @@ def serve(model_path: str, proxy_url: Optional[str] = None, max_log_len: int = None, disable_fastapi_docs: bool = False, + metrics: bool = False, **kwargs): """An example to perform model inference through the command line interface. @@ -1034,6 +1036,7 @@ def serve(model_path: str, proxy_url (str): The proxy url to register the api_server. max_log_len (int): Max number of prompt characters or prompt tokens being printed in log. Default: Unlimited + metrics (bool): Whether log stats to prometheus. 
""" if os.getenv('TM_LOG_LEVEL') is None: os.environ['TM_LOG_LEVEL'] = log_level @@ -1050,6 +1053,11 @@ def serve(model_path: str, app.include_router(router) + if metrics is True: + # Add prometheus asgi middleware to route /metrics requests + metrics_app = make_asgi_app() + app.mount('/metrics', metrics_app) + if allow_origins: app.add_middleware( CORSMiddleware, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 400c492b09..b69f7ce3c3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -8,6 +8,7 @@ openai outlines<0.1.0 peft<=0.11.1 pillow +prometheus_client >= 0.18.0 protobuf pydantic>2.0.0 pynvml From af0bacd8b26fd285ca8b89f8a98e38ec0f8a1d68 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 15 Nov 2024 15:05:52 +0800 Subject: [PATCH 2/7] fix --- lmdeploy/serve/metrics.py | 180 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 lmdeploy/serve/metrics.py diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py new file mode 100644 index 0000000000..b0ae8528b2 --- /dev/null +++ b/lmdeploy/serve/metrics.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dataclasses +import time +from dataclasses import dataclass +from typing import Dict, List, Optional + +import psutil +import pynvml +from prometheus_client import REGISTRY, Gauge, Info, disable_created_metrics + +disable_created_metrics() + + +class IterTimer: + + def __init__(self, iterable): + self._iterable = iterable + self._duration = 0 + + def __iter__(self): + return self + + def __next__(self): + start = time.perf_counter() + item = next(iter(self._iterable)) + self._duration += (time.perf_counter() - start) + return item + + def get_duration(self): + return self._duration + + def __aiter__(self): + return self + + async def __anext__(self): + start = time.perf_counter() + item = await self._iterable.__anext__() + self._duration += (time.perf_counter() - start) + return item + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # request stats + request_success: int = 0 + request_failure: int = 0 + request_total: int = 0 + request_responding: int = 0 + request_waiting: int = 0 + + # latency stats + duration_queue: float = 0 + duration_infer: float = 0 + duration_preprocess: float = 0 + duration_postprocess: float = 0 + + # system status + cpu_utilization: Optional[float] = None + cpu_memory_used_bytes: Optional[float] = None + gpu_utilization: Optional[Dict] = None + gpu_memory_used_bytes: Optional[Dict] = None + + def refresh(self): + """Fresh system status.""" + p = psutil.Process() + self.cpu_utilization = p.cpu_percent() + self.cpu_memory_used_bytes = p.memory_info().rss + pynvml.nvmlInit() + self.gpu_memory_used_bytes = {} + self.gpu_utilization = {} + for i in range(pynvml.nvmlDeviceGetCount()): + handle = pynvml.nvmlDeviceGetHandleByIndex(int(i)) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) + self.gpu_memory_used_bytes[str(i)] = str(mem_info.used) + self.gpu_utilization[str(i)] = str(utilization.gpu) + + +class Metrics: + + def __init__(self, labelnames: Optional[List[str]] = []): + # Unregister any existing lmdeploy collectors + for collector in list(REGISTRY._collector_to_names): + if hasattr(collector, '_name') and 'lmdeploy' in collector._name: + REGISTRY.unregister(collector) + + # Config Information + self.info_backend_config = Info( + name='lmdeploy:backend_config', + 
documentation='information of backend_config') + + # System stats + self.info_gpu_utilization = Info( + name='lmdeploy:gpu_utilization', + documentation='GPU utilization. 1 means 100 percent usage.') + self.info_gpu_memory_used_bytes = Info( + name='lmdeploy:gpu_memory_used_bytes', + documentation='GPU memory used bytes.') + self.gauge_cpu_utilization = Gauge( + name='lmdeploy:cpu_utilization', + documentation='CPU utilization. 1 means 100 percent usage.', + labelnames=labelnames) + self.gauge_cpu_memory_used_bytes = Gauge( + name='lmdeploy:cpu_memory_used_bytes', + documentation='CPU memory used bytes.', + labelnames=labelnames) + + # requests + self.gauge_request_success = Gauge( + name='lmdeploy:request_success', + documentation='Number of successful requests.', + labelnames=labelnames) + self.gauge_request_failure = Gauge( + name='lmdeploy:request_failure', + documentation='Number of failed requests.', + labelnames=labelnames) + self.gauge_request_total = Gauge( + name='lmdeploy:request_total', + documentation='Number of total requests.', + labelnames=labelnames) + + # latency metrics + self.gauge_duration_queue = Gauge( + name='lmdeploy:duration_queue', + documentation= # noqa + 'Avarate duration waiting in the queue of requests in s.', + labelnames=labelnames, + ) + self.gauge_duration_infer = Gauge( + name='lmdeploy:duration_infer', + documentation='Average inference time in s.', + labelnames=labelnames, + ) + self.gauge_duration_preprocess = Gauge( + name='lmdeploy:duration_preprocess', + documentation='Average duration of processing inputs in s.', + labelnames=labelnames, + ) + self.gauge_duration_postprocess = Gauge( + name='lmdeploy:duration_postprocess', + documentation='Average duration of processing outputs in s.', + labelnames=labelnames, + ) + + def info(self, backend_config: object) -> None: + config_dict = { + key: str(value) + for key, value in dataclasses.asdict(backend_config).items() + } + self.info_backend_config.info(config_dict) + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + + Logs to prometheus and tracked stats every iteration. Logs to Stdout + every self.local_interval seconds. + """ + + # Log to prometheus. + stats.refresh() + # Info gpu stats + self.info_gpu_utilization.info(stats.gpu_utilization) + self.info_gpu_memory_used_bytes.info(stats.gpu_memory_used_bytes) + # Set system stat gauges. + self.gauge_cpu_utilization.set(stats.cpu_utilization) + self.gauge_cpu_memory_used_bytes.set(stats.cpu_memory_used_bytes) + + # Add to request counters. 
+ self.gauge_request_total.set(stats.request_total) + self.gauge_request_success.set(stats.request_success) + self.gauge_request_failure.set(stats.request_failure) + + # duration gauges + self.gauge_duration_infer.set(stats.duration_infer) + self.gauge_duration_queue.set(stats.duration_queue) + self.gauge_duration_preprocess.set(stats.duration_preprocess) + self.gauge_duration_postprocess.set(stats.duration_postprocess) From d4c535d04e847f4aff7260b7556501fe8769c362 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 15 Nov 2024 15:35:31 +0800 Subject: [PATCH 3/7] fix --- lmdeploy/serve/async_engine.py | 2 +- lmdeploy/serve/metrics.py | 34 +++++++++++++++++++---------- lmdeploy/serve/openai/api_server.py | 1 + 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 64e8ee7a44..baa7672873 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -193,7 +193,7 @@ def __init__(self, self.apply_metrics = metrics if self.apply_metrics: self.stats = Stats(now=time.time()) - self.metrics = Metrics() + self.metrics = Metrics(self.stats) self.metrics.info(self.backend_config) def _build_turbomind( diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py index b0ae8528b2..8875884883 100644 --- a/lmdeploy/serve/metrics.py +++ b/lmdeploy/serve/metrics.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import dataclasses +import threading import time from dataclasses import dataclass from typing import Dict, List, Optional @@ -66,7 +67,7 @@ class Stats: def refresh(self): """Fresh system status.""" p = psutil.Process() - self.cpu_utilization = p.cpu_percent() + self.cpu_utilization = psutil.cpu_percent() self.cpu_memory_used_bytes = p.memory_info().rss pynvml.nvmlInit() self.gpu_memory_used_bytes = {} @@ -79,9 +80,24 @@ def refresh(self): self.gpu_utilization[str(i)] = str(utilization.gpu) +def refresh_system(metrics): + + while True: + time.sleep(1) + # Log to prometheus. + stats = metrics.stats + stats.refresh() + # Info gpu stats + metrics.info_gpu_utilization.info(stats.gpu_utilization) + metrics.info_gpu_memory_used_bytes.info(stats.gpu_memory_used_bytes) + # Set system stat gauges. + metrics.gauge_cpu_utilization.set(stats.cpu_utilization) + metrics.gauge_cpu_memory_used_bytes.set(stats.cpu_memory_used_bytes) + + class Metrics: - def __init__(self, labelnames: Optional[List[str]] = []): + def __init__(self, stats: Stats, labelnames: Optional[List[str]] = []): # Unregister any existing lmdeploy collectors for collector in list(REGISTRY._collector_to_names): if hasattr(collector, '_name') and 'lmdeploy' in collector._name: @@ -144,6 +160,11 @@ def __init__(self, labelnames: Optional[List[str]] = []): documentation='Average duration of processing outputs in s.', labelnames=labelnames, ) + self.stats = stats + self.refresh_thread = threading.Thread(target=refresh_system, + args=(self, ), + daemon=True) + self.refresh_thread.start() def info(self, backend_config: object) -> None: config_dict = { @@ -159,15 +180,6 @@ def log(self, stats: Stats) -> None: every self.local_interval seconds. """ - # Log to prometheus. - stats.refresh() - # Info gpu stats - self.info_gpu_utilization.info(stats.gpu_utilization) - self.info_gpu_memory_used_bytes.info(stats.gpu_memory_used_bytes) - # Set system stat gauges. - self.gauge_cpu_utilization.set(stats.cpu_utilization) - self.gauge_cpu_memory_used_bytes.set(stats.cpu_memory_used_bytes) - # Add to request counters. 
self.gauge_request_total.set(stats.request_total) self.gauge_request_success.set(stats.request_success) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 363f18aec0..75bac72d5d 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1085,6 +1085,7 @@ def serve(model_path: str, backend_config=backend_config, chat_template_config=chat_template_config, max_log_len=max_log_len, + metrics=metrics, **kwargs) if proxy_url is not None: From 129d77f97ef3cd760e8aed4d6fe658b7dd4cfc3b Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 19 Nov 2024 16:59:24 +0800 Subject: [PATCH 4/7] add FTL and refactor metrics --- lmdeploy/serve/async_engine.py | 54 +++++++------------- lmdeploy/serve/metrics.py | 93 ++++++++++++++++++++++++++++------ 2 files changed, 96 insertions(+), 51 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 8470add8d4..af680276a5 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -5,7 +5,6 @@ import os import random import re -import time from contextlib import asynccontextmanager from copy import deepcopy from itertools import count @@ -17,7 +16,7 @@ from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig) from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model -from lmdeploy.serve.metrics import IterTimer, Metrics, Stats +from lmdeploy.serve.metrics import IterTimer, Metrics from lmdeploy.serve.utils import LogitsMixin, _get_event_loop from lmdeploy.tokenizer import DetokenizeState from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger @@ -103,6 +102,11 @@ def __repr__(self) -> str: return res +def is_error(status: ResponseType): + """Whether is an error response.""" + return status not in [ResponseType.SUCCESS, ResponseType.FINISH] + + class AsyncEngine(LogitsMixin): """Async inference engine. Maintaining a bunch of tm_model instances. 
@@ -191,11 +195,8 @@ def __init__(self, self._session_id = count(0) self.request_logger = RequestLogger(max_log_len) - self.apply_metrics = metrics - if self.apply_metrics: - self.stats = Stats(now=time.time()) - self.metrics = Metrics(self.stats) - self.metrics.info(self.backend_config) + self.metrics = Metrics(metrics) + self.metrics.info(self.backend_config) def _build_turbomind( self, @@ -254,9 +255,7 @@ def __call__(self, **kwargs) async def handle_exception(self, session_id: int): - if self.metrics: - self.stats.request_failure += 1 - self.stats.request_total += 1 + await self.metrics.failure_frame() await self.stop_session(session_id) async def stop_session(self, session_id: int): @@ -291,8 +290,6 @@ async def safe_run(self, session_id: Optional[int] = None): async def get_generator(self, stop: bool, session_id: int): """Only return the model instance if it is available.""" - if self.apply_metrics: - start = time.time() if stop: return self.engine.create_instance() # waiting no generator is available or the same session_id is running @@ -301,8 +298,6 @@ async def get_generator(self, stop: bool, session_id: int): generator = self.gens_set.pop() self.id2generator[str(session_id)] = generator self.running_session_ids.add(session_id) - if self.apply_metrics: - self.stats.duration_queue += time.time() - start return generator def batch_infer(self, @@ -512,8 +507,6 @@ async def generate( """ async def get_inputs_genconfig(gen_config): - if self.apply_metrics: - start = time.time() if str(session_id) not in self.id2step: self.id2step[str(session_id)] = 0 if step != 0: @@ -553,11 +546,11 @@ async def get_inputs_genconfig(gen_config): sequence_start, adapter_name, tools=tools) - if self.apply_metrics: - self.stats.duration_preprocess += time.time() - start return prompt_input, gen_config + arrival_frame = await self.metrics.insert_frame() prompt_input, gen_config = await get_inputs_genconfig(gen_config) + await self.metrics.update_preprocess(arrival_frame) prompt = prompt_input['prompt'] input_ids = prompt_input['input_ids'] finish_reason = None @@ -593,13 +586,9 @@ async def get_inputs_genconfig(gen_config): if sequence_end is True and sequence_start is False: await self.end_session(session_id) else: - - def is_error(status): - return status not in [ - ResponseType.SUCCESS, ResponseType.FINISH - ] - + start_frame = await self.metrics.insert_frame() generator = await self.get_generator(False, session_id) + await self.metrics.update_queue_waiting(start_frame) iterator = generator.async_stream_infer( session_id=session_id, **prompt_input, @@ -609,15 +598,16 @@ def is_error(status): sequence_start=sequence_start, sequence_end=sequence_end, step=self.id2step[str(session_id)]) - if self.apply_metrics: + if self.metrics.applied is True: iterator = IterTimer(iterator) async with self.safe_run(session_id): state = DetokenizeState(len(input_ids)) start_ids_offset = state.ids_offset response = '' async for outputs in iterator: - if self.apply_metrics: - start = time.perf_counter() + start_frame = await self.metrics.insert_frame() + if state.prev_tokens is None: + await self.metrics.update_FTL(arrival_frame) # decode res if is_error(outputs.status): tokens = 0 @@ -637,9 +627,7 @@ def is_error(status): if outputs.logprobs: log_offset = ids_offset - start_ids_offset logprobs = outputs.logprobs[log_offset:] - if self.apply_metrics: - self.stats.duration_postprocess += time.perf_counter( - ) - start + await self.metrics.update_postprocess(start_frame) # response, history token len, # input token len, gen 
token len yield GenOut(response, self.id2step[str(session_id)], @@ -671,11 +659,7 @@ def is_error(status): # TODO modify pytorch or turbomind api if self.backend == 'pytorch' and sequence_end: await self.end_session(session_id) - if self.apply_metrics: - self.stats.duration_infer += iterator.get_duration() - self.stats.request_success += 1 - self.stats.request_total += 1 - self.metrics.log(self.stats) + await self.metrics.last_token_frame(iterator) def parse_tool_response(self, text, tools, **kwargs): """Parse model response containing tool information. diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py index 8875884883..2b84bcb7c5 100644 --- a/lmdeploy/serve/metrics.py +++ b/lmdeploy/serve/metrics.py @@ -13,6 +13,7 @@ class IterTimer: + """"The timer to count all the time of iteration.""" def __init__(self, iterable): self._iterable = iterable @@ -28,6 +29,10 @@ def __next__(self): return item def get_duration(self): + """Get the whole duration of iteration. + + Known as model forwarding latency. + """ return self._duration def __aiter__(self): @@ -43,7 +48,6 @@ async def __anext__(self): @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" - now: float # request stats request_success: int = 0 @@ -57,6 +61,7 @@ class Stats: duration_infer: float = 0 duration_preprocess: float = 0 duration_postprocess: float = 0 + first_token_latency: float = 0 # system status cpu_utilization: Optional[float] = None @@ -81,7 +86,7 @@ def refresh(self): def refresh_system(metrics): - + """A thread life long function to get hardware information.""" while True: time.sleep(1) # Log to prometheus. @@ -96,8 +101,12 @@ def refresh_system(metrics): class Metrics: + """The metrics for serving.""" - def __init__(self, stats: Stats, labelnames: Optional[List[str]] = []): + def __init__(self, + applied: bool = False, + labelnames: Optional[List[str]] = []): + self.applied = applied # Unregister any existing lmdeploy collectors for collector in list(REGISTRY._collector_to_names): if hasattr(collector, '_name') and 'lmdeploy' in collector._name: @@ -160,33 +169,85 @@ def __init__(self, stats: Stats, labelnames: Optional[List[str]] = []): documentation='Average duration of processing outputs in s.', labelnames=labelnames, ) - self.stats = stats + self.gauge_first_token_latency = Gauge( + name='lmdeploy:first_token_latency', + documentation='Average first token latency in s.', + labelnames=labelnames, + ) + self.stats = Stats() self.refresh_thread = threading.Thread(target=refresh_system, args=(self, ), daemon=True) self.refresh_thread.start() def info(self, backend_config: object) -> None: - config_dict = { - key: str(value) - for key, value in dataclasses.asdict(backend_config).items() - } - self.info_backend_config.info(config_dict) - - def log(self, stats: Stats) -> None: + if self.applied: + config_dict = { + key: str(value) + for key, value in dataclasses.asdict(backend_config).items() + } + self.info_backend_config.info(config_dict) + + async def failure_frame(self): + """log the failaure frame.""" + if self.applied: + self.stats.request_failure += 1 + self.stats.request_total += 1 + + async def last_token_frame(self, iterator): + """log the last token frame.""" + if self.applied: + self.stats.duration_infer += iterator.get_duration() + self.stats.request_success += 1 + self.stats.request_total += 1 + self.log() + + async def insert_frame(self): + """Insert a frame.""" + if self.applied: + return time.time() + return None + + async def update_postprocess(self, start_frame): + 
"""Update postprocess duration.""" + if self.applied: + self.stats.duration_postprocess += time.time() - start_frame + + async def update_preprocess(self, start_frame): + """Update preprocess duration.""" + if self.applied: + self.stats.duration_preprocess += time.time() - start_frame + + async def update_queue_waiting(self, start_frame): + """Update queue waiting time.""" + if self.applied: + self.stats.duration_queue += time.time() - start_frame + + async def update_FTL(self, start_frame): + """Update first token latency.""" + if self.applied: + self.stats.first_token_latency += time.time() - start_frame + + def log(self) -> None: """Called by LLMEngine. Logs to prometheus and tracked stats every iteration. Logs to Stdout every self.local_interval seconds. """ - + stats = self.stats # Add to request counters. self.gauge_request_total.set(stats.request_total) self.gauge_request_success.set(stats.request_success) self.gauge_request_failure.set(stats.request_failure) # duration gauges - self.gauge_duration_infer.set(stats.duration_infer) - self.gauge_duration_queue.set(stats.duration_queue) - self.gauge_duration_preprocess.set(stats.duration_preprocess) - self.gauge_duration_postprocess.set(stats.duration_postprocess) + self.gauge_duration_infer.set(stats.duration_infer / + stats.request_total) + self.gauge_duration_queue.set(stats.duration_queue / + stats.request_total) + self.gauge_duration_preprocess.set(stats.duration_preprocess / + stats.request_total) + self.gauge_duration_postprocess.set(stats.duration_postprocess / + stats.request_total) + self.gauge_first_token_latency.set(stats.first_token_latency / + stats.request_total) From 66cb69bf71d8108d9722746961f4a564e22298e7 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 19 Nov 2024 17:25:17 +0800 Subject: [PATCH 5/7] refine --- lmdeploy/serve/async_engine.py | 21 +++++++++++---------- lmdeploy/serve/metrics.py | 14 +++++++------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index af680276a5..f6a72954fc 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -255,7 +255,7 @@ def __call__(self, **kwargs) async def handle_exception(self, session_id: int): - await self.metrics.failure_frame() + self.metrics.failure_frame() await self.stop_session(session_id) async def stop_session(self, session_id: int): @@ -548,9 +548,9 @@ async def get_inputs_genconfig(gen_config): tools=tools) return prompt_input, gen_config - arrival_frame = await self.metrics.insert_frame() + arrival_frame = self.metrics.insert_frame() prompt_input, gen_config = await get_inputs_genconfig(gen_config) - await self.metrics.update_preprocess(arrival_frame) + self.metrics.update_preprocess(arrival_frame) prompt = prompt_input['prompt'] input_ids = prompt_input['input_ids'] finish_reason = None @@ -586,9 +586,9 @@ async def get_inputs_genconfig(gen_config): if sequence_end is True and sequence_start is False: await self.end_session(session_id) else: - start_frame = await self.metrics.insert_frame() + start_frame = self.metrics.insert_frame() generator = await self.get_generator(False, session_id) - await self.metrics.update_queue_waiting(start_frame) + self.metrics.update_queue_waiting(start_frame) iterator = generator.async_stream_infer( session_id=session_id, **prompt_input, @@ -605,9 +605,8 @@ async def get_inputs_genconfig(gen_config): start_ids_offset = state.ids_offset response = '' async for outputs in iterator: - start_frame = await 
self.metrics.insert_frame() - if state.prev_tokens is None: - await self.metrics.update_FTL(arrival_frame) + start_frame = self.metrics.insert_frame() + is_first_token = state.prev_tokens is None # decode res if is_error(outputs.status): tokens = 0 @@ -627,7 +626,9 @@ async def get_inputs_genconfig(gen_config): if outputs.logprobs: log_offset = ids_offset - start_ids_offset logprobs = outputs.logprobs[log_offset:] - await self.metrics.update_postprocess(start_frame) + self.metrics.update_postprocess(start_frame) + if is_first_token: + self.metrics.update_FTL(arrival_frame) # response, history token len, # input token len, gen token len yield GenOut(response, self.id2step[str(session_id)], @@ -659,7 +660,7 @@ async def get_inputs_genconfig(gen_config): # TODO modify pytorch or turbomind api if self.backend == 'pytorch' and sequence_end: await self.end_session(session_id) - await self.metrics.last_token_frame(iterator) + self.metrics.last_token_frame(iterator) def parse_tool_response(self, text, tools, **kwargs): """Parse model response containing tool information. diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py index 2b84bcb7c5..a0713174d6 100644 --- a/lmdeploy/serve/metrics.py +++ b/lmdeploy/serve/metrics.py @@ -188,13 +188,13 @@ def info(self, backend_config: object) -> None: } self.info_backend_config.info(config_dict) - async def failure_frame(self): + def failure_frame(self): """log the failaure frame.""" if self.applied: self.stats.request_failure += 1 self.stats.request_total += 1 - async def last_token_frame(self, iterator): + def last_token_frame(self, iterator): """log the last token frame.""" if self.applied: self.stats.duration_infer += iterator.get_duration() @@ -202,28 +202,28 @@ async def last_token_frame(self, iterator): self.stats.request_total += 1 self.log() - async def insert_frame(self): + def insert_frame(self): """Insert a frame.""" if self.applied: return time.time() return None - async def update_postprocess(self, start_frame): + def update_postprocess(self, start_frame): """Update postprocess duration.""" if self.applied: self.stats.duration_postprocess += time.time() - start_frame - async def update_preprocess(self, start_frame): + def update_preprocess(self, start_frame): """Update preprocess duration.""" if self.applied: self.stats.duration_preprocess += time.time() - start_frame - async def update_queue_waiting(self, start_frame): + def update_queue_waiting(self, start_frame): """Update queue waiting time.""" if self.applied: self.stats.duration_queue += time.time() - start_frame - async def update_FTL(self, start_frame): + def update_FTL(self, start_frame): """Update first token latency.""" if self.applied: self.stats.first_token_latency += time.time() - start_frame From e70f4780533f4bab3cd43754f69e668b4698c57a Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 28 Nov 2024 15:37:34 +0800 Subject: [PATCH 6/7] use Counter and Histogram, remove posprocess --- lmdeploy/serve/async_engine.py | 2 - lmdeploy/serve/metrics.py | 90 +++++++++++----------------------- 2 files changed, 29 insertions(+), 63 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index f6a72954fc..b6655c3c66 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -605,7 +605,6 @@ async def get_inputs_genconfig(gen_config): start_ids_offset = state.ids_offset response = '' async for outputs in iterator: - start_frame = self.metrics.insert_frame() is_first_token = state.prev_tokens is None # decode res if 
is_error(outputs.status): @@ -626,7 +625,6 @@ async def get_inputs_genconfig(gen_config): if outputs.logprobs: log_offset = ids_offset - start_ids_offset logprobs = outputs.logprobs[log_offset:] - self.metrics.update_postprocess(start_frame) if is_first_token: self.metrics.update_FTL(arrival_frame) # response, history token len, diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py index a0713174d6..1d4eaeef6b 100644 --- a/lmdeploy/serve/metrics.py +++ b/lmdeploy/serve/metrics.py @@ -2,12 +2,13 @@ import dataclasses import threading import time -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List, Optional import psutil import pynvml -from prometheus_client import REGISTRY, Gauge, Info, disable_created_metrics +from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, + disable_created_metrics) disable_created_metrics() @@ -57,11 +58,11 @@ class Stats: request_waiting: int = 0 # latency stats - duration_queue: float = 0 - duration_infer: float = 0 - duration_preprocess: float = 0 - duration_postprocess: float = 0 - first_token_latency: float = 0 + duration_queue: list = field(default_factory=list) + duration_infer: list = field(default_factory=list) + duration_preprocess: list = field(default_factory=list) + duration_postprocess: list = field(default_factory=list) + first_token_latency: list = field(default_factory=list) # system status cpu_utilization: Optional[float] = None @@ -134,43 +135,38 @@ def __init__(self, labelnames=labelnames) # requests - self.gauge_request_success = Gauge( + self.counter_request_success = Counter( name='lmdeploy:request_success', documentation='Number of successful requests.', labelnames=labelnames) - self.gauge_request_failure = Gauge( + self.counter_request_failure = Counter( name='lmdeploy:request_failure', documentation='Number of failed requests.', labelnames=labelnames) - self.gauge_request_total = Gauge( + self.counter_request_total = Counter( name='lmdeploy:request_total', documentation='Number of total requests.', labelnames=labelnames) # latency metrics - self.gauge_duration_queue = Gauge( - name='lmdeploy:duration_queue', + self.histogram_duration_queue = Histogram( + name='lmdeploy:duration_queue_seconds', documentation= # noqa 'Avarate duration waiting in the queue of requests in s.', labelnames=labelnames, ) - self.gauge_duration_infer = Gauge( - name='lmdeploy:duration_infer', + self.histogram_duration_infer = Histogram( + name='lmdeploy:duration_infer_seconds', documentation='Average inference time in s.', labelnames=labelnames, ) - self.gauge_duration_preprocess = Gauge( - name='lmdeploy:duration_preprocess', + self.histogram_duration_preprocess = Histogram( + name='lmdeploy:duration_preprocess_seconds', documentation='Average duration of processing inputs in s.', labelnames=labelnames, ) - self.gauge_duration_postprocess = Gauge( - name='lmdeploy:duration_postprocess', - documentation='Average duration of processing outputs in s.', - labelnames=labelnames, - ) - self.gauge_first_token_latency = Gauge( - name='lmdeploy:first_token_latency', + self.histogram_first_token_latency = Histogram( + name='lmdeploy:first_token_latency_seconds', documentation='Average first token latency in s.', labelnames=labelnames, ) @@ -191,16 +187,15 @@ def info(self, backend_config: object) -> None: def failure_frame(self): """log the failaure frame.""" if self.applied: - self.stats.request_failure += 1 - self.stats.request_total += 1 + self.counter_request_failure.inc() + 
self.counter_request_total.inc() def last_token_frame(self, iterator): """log the last token frame.""" if self.applied: - self.stats.duration_infer += iterator.get_duration() - self.stats.request_success += 1 - self.stats.request_total += 1 - self.log() + self.histogram_duration_infer.observe(iterator.get_duration()) + self.counter_request_success.inc() + self.counter_request_total.inc() def insert_frame(self): """Insert a frame.""" @@ -208,46 +203,19 @@ def insert_frame(self): return time.time() return None - def update_postprocess(self, start_frame): - """Update postprocess duration.""" - if self.applied: - self.stats.duration_postprocess += time.time() - start_frame - def update_preprocess(self, start_frame): """Update preprocess duration.""" if self.applied: - self.stats.duration_preprocess += time.time() - start_frame + self.histogram_duration_preprocess.observe(time.time() - + start_frame) def update_queue_waiting(self, start_frame): """Update queue waiting time.""" if self.applied: - self.stats.duration_queue += time.time() - start_frame + self.histogram_duration_queue.observe(time.time() - start_frame) def update_FTL(self, start_frame): """Update first token latency.""" if self.applied: - self.stats.first_token_latency += time.time() - start_frame - - def log(self) -> None: - """Called by LLMEngine. - - Logs to prometheus and tracked stats every iteration. Logs to Stdout - every self.local_interval seconds. - """ - stats = self.stats - # Add to request counters. - self.gauge_request_total.set(stats.request_total) - self.gauge_request_success.set(stats.request_success) - self.gauge_request_failure.set(stats.request_failure) - - # duration gauges - self.gauge_duration_infer.set(stats.duration_infer / - stats.request_total) - self.gauge_duration_queue.set(stats.duration_queue / - stats.request_total) - self.gauge_duration_preprocess.set(stats.duration_preprocess / - stats.request_total) - self.gauge_duration_postprocess.set(stats.duration_postprocess / - stats.request_total) - self.gauge_first_token_latency.set(stats.first_token_latency / - stats.request_total) + self.histogram_first_token_latency.observe(time.time() - + start_frame) From abd7705e2fd572723422efb95d985237cea2081b Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 28 Nov 2024 15:53:10 +0800 Subject: [PATCH 7/7] remove useless variables --- lmdeploy/serve/metrics.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/lmdeploy/serve/metrics.py b/lmdeploy/serve/metrics.py index 1d4eaeef6b..22977d47c2 100644 --- a/lmdeploy/serve/metrics.py +++ b/lmdeploy/serve/metrics.py @@ -2,7 +2,7 @@ import dataclasses import threading import time -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Dict, List, Optional import psutil @@ -48,22 +48,7 @@ async def __anext__(self): @dataclass class Stats: - """Created by LLMEngine for use by StatLogger.""" - - # request stats - request_success: int = 0 - request_failure: int = 0 - request_total: int = 0 - request_responding: int = 0 - request_waiting: int = 0 - - # latency stats - duration_queue: list = field(default_factory=list) - duration_infer: list = field(default_factory=list) - duration_preprocess: list = field(default_factory=list) - duration_postprocess: list = field(default_factory=list) - first_token_latency: list = field(default_factory=list) - + """Log system information.""" # system status cpu_utilization: Optional[float] = None cpu_memory_used_bytes: Optional[float] = None
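
A quick way to exercise the endpoint this series exposes: start the server with the new flag, e.g. `lmdeploy serve api_server <model_path> --metrics`, then read back the Prometheus exposition text from `/metrics`. The snippet below is a minimal sketch and is not part of the patches; the address `0.0.0.0:23333`, the five-second timeout, and the `dump_lmdeploy_metrics` helper name are assumptions for illustration only.

# check_metrics.py -- smoke test for the /metrics endpoint (illustrative only)
from urllib.request import urlopen

# Assumed server address; adjust to wherever api_server is listening.
METRICS_URL = 'http://0.0.0.0:23333/metrics'


def dump_lmdeploy_metrics(url: str = METRICS_URL) -> None:
    """Print only the lmdeploy-specific samples from the exposition text."""
    text = urlopen(url, timeout=5).read().decode('utf-8')
    for line in text.splitlines():
        # '# HELP' / '# TYPE' comment lines and samples from other
        # collectors (python_*, process_*) are skipped.
        if line.startswith('lmdeploy'):
            print(line)


if __name__ == '__main__':
    dump_lmdeploy_metrics()

A Prometheus scrape job or Grafana panel pointed at the same `/metrics` path works equally well. Matching on the `lmdeploy` prefix rather than exact sample names keeps the check robust to the `_total`/`_created` suffixes that prometheus_client may append to the Counter metrics introduced in PATCH 6/7.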