diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index b4fde50ebe..4ed085a6bb 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -125,6 +125,11 @@ def add_parser_api_server():
                             'engine’s tasks once the maximum number of concurrent requests is '
                             'reached, regardless of any additional requests sent by clients '
                             'concurrently during that time. Default to None.')
+    # FIXME: change default value to False
+    parser.add_argument('--enable-metrics',
+                        action='store_true',
+                        default=True,
+                        help='Whether to log stats to CLI / Prometheus')
     # common args
     ArgumentHelper.backend(parser)
     ArgumentHelper.log_level(parser)
@@ -272,7 +277,8 @@ def gradio(args):
                                                  device_type=args.device,
                                                  quant_policy=args.quant_policy,
                                                  eager_mode=args.eager_mode,
-                                                 max_prefill_token_num=args.max_prefill_token_num)
+                                                 max_prefill_token_num=args.max_prefill_token_num,
+                                                 enable_metrics=args.enable_metrics)
         else:
             backend_config = TurbomindEngineConfig(dtype=args.dtype,
                                                    tp=args.tp,
@@ -369,6 +375,7 @@ def api_server(args):
                                   max_log_len=args.max_log_len,
                                   disable_fastapi_docs=args.disable_fastapi_docs,
                                   max_concurrent_requests=args.max_concurrent_requests,
+                                  enable_metrics=args.enable_metrics,
                                   reasoning_parser=args.reasoning_parser,
                                   tool_call_parser=args.tool_call_parser)

diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index b3fed70361..d8bd020014 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import enum
+import time
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Literal, Optional

@@ -9,6 +10,7 @@
 from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
 from lmdeploy.pytorch.disagg.request import MigrationRequest

+from .metrics.stats import IterationStats, RequestStateStats, SchedulerStats
 from .tokenizer import Tokenizer
 from .utils import get_logger

@@ -310,6 +312,7 @@ class PytorchEngineConfig:
             'Decode']. Default to `EngineRole.Hybrid`.
         migration_backend: migration backend. options: ['DLSlime'].
             Default to `MigrationBackend.DLSlime`.
+        enable_metrics (bool): Whether to log stats to CLI / Prometheus.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -338,6 +341,7 @@

     role: EngineRole = EngineRole.Hybrid
     migration_backend: MigrationBackend = MigrationBackend.DLSlime
+    enable_metrics: bool = False

     def __post_init__(self):
         """Check input validation."""
@@ -407,6 +411,34 @@ class Response:
     last_hidden_state: torch.Tensor = None
     index: int = 0
+    scheduler_stats: SchedulerStats = None
+    iteration_stats: IterationStats = None
+
+
+# copy from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
+class EngineCoreEventType(enum.IntEnum):
+    """The type of engine core request event."""
+    QUEUED = 1
+    SCHEDULED = 2
+    PREEMPTED = 3  # FIXME, currently ignored for simplicity
+
+
+# copy from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
+@dataclass
+class EngineCoreEvent():
+    """A timestamped engine core event associated with a request.
+
+    The timestamp is monotonic and is used by the engine frontend to calculate intervals between engine
+    core events. These timestamps should not be compared with timestamps from other processes.
+ """ + type: EngineCoreEventType + timestamp: float + + @classmethod + def new_event(cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None) -> 'EngineCoreEvent': + timestamp = time.monotonic() if timestamp is None else timestamp + return cls(event_type, timestamp) + @dataclass class EngineOutput: @@ -431,6 +463,27 @@ class EngineOutput: cache_block_ids: Optional[List[int]] = None + # engine-side time stamp, for logging + timestamp: float = 0.0 + scheduler_stats: SchedulerStats = None + iteration_stats: IterationStats = None + events: List[EngineCoreEvent] = None + + def __post_init__(self): + if self.timestamp == 0.0: + self.timestamp = time.monotonic() + + +@dataclass +class RequestState: + """per request state.""" + + def __init__(self, arrival_time: float, prompt_len: int, is_prefilling: bool, enable_metrics: bool): + + self.prompt_len: int = prompt_len + self.is_prefilling: bool = is_prefilling + self.stats = RequestStateStats(arrival_time=arrival_time) if enable_metrics else None + @dataclass class VisionConfig: diff --git a/lmdeploy/metrics/__init__.py b/lmdeploy/metrics/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/metrics/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/metrics/loggers.py b/lmdeploy/metrics/loggers.py new file mode 100644 index 0000000000..e986b28df3 --- /dev/null +++ b/lmdeploy/metrics/loggers.py @@ -0,0 +1,341 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/loggers.py + +import time +from abc import ABC, abstractmethod +from typing import List, Optional + +import numpy as np +import prometheus_client + +from lmdeploy.metrics.stats import IterationStats, SchedulerStats +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +prometheus_client.disable_created_metrics() + + +class StatLoggerBase(ABC): + + @abstractmethod + def record(self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]): + ... + + def log(self): # noqa + pass + + +class LoggingStatLogger(StatLoggerBase): + + def __init__(self, engine_index: int = 0): + self.engine_index = engine_index + self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() + + def _reset(self, now): + self.last_log_time = now + + # Tracked stats over current local logging interval. + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] + + def _track_iteration_stats(self, iteration_stats: IterationStats): + # Save tracked stats for token counters. + self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens) + self.num_generation_tokens.append(iteration_stats.num_generation_tokens) + + def _get_throughput(self, tracked_stats: list[int], now: float) -> float: + # Compute summary metrics for tracked stats + return float(np.sum(tracked_stats) / (now - self.last_log_time)) + + def record(self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]): + """Log Stats to standard output.""" + + if iteration_stats: + self._track_iteration_stats(iteration_stats) + + self.last_scheduler_stats = scheduler_stats + + def log(self): + now = time.monotonic() + prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) + generation_throughput = self._get_throughput(self.num_generation_tokens, now) + + self._reset(now) + + scheduler_stats = self.last_scheduler_stats + + # Format and print output. 
+        logger.info(
+            'Avg prompt throughput: %.1f tokens/s, '
+            'Avg generation throughput: %.1f tokens/s, '
+            'Running: %d reqs, Waiting: %d reqs, '
+            'GPU KV cache usage: %.1f%%, ',
+            prompt_throughput,
+            generation_throughput,
+            scheduler_stats.num_running_reqs,
+            scheduler_stats.num_waiting_reqs,
+            scheduler_stats.gpu_cache_usage * 100,
+        )
+
+
+class PrometheusStatLogger(StatLoggerBase):
+
+    def __init__(self, labelnames: Optional[List[str]] = []):
+
+        # unregister any existing lmdeploy collectors
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
+            if hasattr(collector, '_name') and 'lmdeploy' in collector._name:
+                prometheus_client.REGISTRY.unregister(collector)
+
+        max_model_len = 4096  # FIXME, hard code now, get from model config
+
+        # config Information
+        self.info_backend_config = prometheus_client.Info(name='lmdeploy:backend_config',
+                                                          documentation='information of backend_config')
+
+        #
+        # Scheduler state
+        #
+        self.gauge_scheduler_running = prometheus_client.Gauge(
+            name='lmdeploy:num_requests_running',
+            documentation='Number of requests in model execution batches.',
+            labelnames=labelnames)
+
+        self.gauge_scheduler_waiting = prometheus_client.Gauge(
+            name='lmdeploy:num_requests_waiting',
+            documentation='Number of requests waiting to be processed.',
+            labelnames=labelnames)
+
+        #
+        # GPU cache
+        #
+        self.gauge_gpu_cache_usage = prometheus_client.Gauge(
+            name='lmdeploy:gpu_cache_usage_perc',
+            documentation='GPU KV-cache usage. 1 means 100 percent usage.',
+            labelnames=labelnames)
+
+        #
+        # Counters
+        #
+        self.counter_prompt_tokens = prometheus_client.Counter(name='lmdeploy:prompt_tokens_total',
+                                                               documentation='Number of prefill tokens processed.',
+                                                               labelnames=labelnames)
+
+        self.counter_generation_tokens = prometheus_client.Counter(
+            name='lmdeploy:generation_tokens_total',
+            documentation='Number of generation tokens processed.',
+            labelnames=labelnames)
+
+        # from lmdeploy.messages import ResponseType
+        # self.counter_request_success: dict[ResponseType,
+        #                                    prometheus_client.Counter] = {}
+        # counter_request_success_base = prometheus_client.Counter(
+        #     name="lmdeploy:request_success_total",
+        #     documentation="Count of successfully processed requests.",
+        #     labelnames=labelnames + ["finished_reason"])
+        # for reason in FinishReason:
+        #     self.counter_request_success[
+        #         reason] = counter_request_success_base.labels(*(labelvalues +
+        #                                                         [str(reason)]))
+
+        #
+        # Histograms of counts
+        #
+        self.histogram_num_prompt_tokens_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_prompt_tokens',
+                documentation='Number of prefill tokens processed.',
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames)
+
+        self.histogram_num_generation_tokens_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_generation_tokens',
+                documentation='Number of generation tokens processed.',
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames)
+
+        # FIXME, build_cudagraph_buckets
+        # self.histogram_iteration_tokens = \
+        #     prometheus_client.Histogram(
+        #         name="lmdeploy:iteration_tokens_total",
+        #         documentation="Histogram of number of tokens per engine_step.",
+        #         buckets=build_cudagraph_buckets(vllm_config),
+        #         labelnames=labelnames)
+
+        self.histogram_max_num_generation_tokens_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_max_num_generation_tokens',
+                documentation='Histogram of maximum number of requested generation tokens.',
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames)
+
+        self.histogram_n_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_params_n',
+                documentation='Histogram of the n request parameter.',
+                buckets=[1, 2, 5, 10, 20],
+                labelnames=labelnames)
+
+        self.histogram_max_tokens_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_params_max_tokens',
+                documentation='Histogram of the max_tokens request parameter.',
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames)
+
+        #
+        # Histogram of timing intervals
+        #
+        self.histogram_time_to_first_token = \
+            prometheus_client.Histogram(
+                name='lmdeploy:time_to_first_token_seconds',
+                documentation='Histogram of time to first token in seconds.',
+                buckets=[
+                    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
+                    640.0, 2560.0
+                ],
+                labelnames=labelnames)
+
+        self.histogram_time_per_output_token = \
+            prometheus_client.Histogram(
+                name='lmdeploy:time_per_output_token_seconds',
+                documentation='Histogram of time per output token in seconds.',
+                buckets=[
+                    0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
+                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
+                ],
+                labelnames=labelnames)
+
+        request_latency_buckets = [
+            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0,
+            960.0, 1920.0, 7680.0
+        ]
+        self.histogram_e2e_time_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:e2e_request_latency_seconds',
+                documentation='Histogram of e2e request latency in seconds.',
+                buckets=request_latency_buckets,
+                labelnames=labelnames)
+        self.histogram_queue_time_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_queue_time_seconds',
+                documentation='Histogram of time spent in WAITING phase for request.',
+                buckets=request_latency_buckets,
+                labelnames=labelnames)
+        self.histogram_inference_time_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_inference_time_seconds',
+                documentation='Histogram of time spent in RUNNING phase for request.',
+                buckets=request_latency_buckets,
+                labelnames=labelnames)
+        self.histogram_prefill_time_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_prefill_time_seconds',
+                documentation='Histogram of time spent in PREFILL phase for request.',
+                buckets=request_latency_buckets,
+                labelnames=labelnames)
+        self.histogram_decode_time_request = \
+            prometheus_client.Histogram(
+                name='lmdeploy:request_decode_time_seconds',
+                documentation='Histogram of time spent in DECODE phase for request.',
+                buckets=request_latency_buckets,
+                labelnames=labelnames)
+
+    # def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
+    #     metrics_info = config_obj.metrics_info()
+
+    #     name, documentation = None, None
+    #     if type == "cache_config":
+    #         name = "lmdeploy:cache_config_info"
+    #         documentation = "Information of the LLMEngine CacheConfig"
+    #     assert name is not None, f"Unknown metrics info type {type}"
+
+    #     # Info type metrics are syntactic sugar for a gauge permanently set to 1
+    #     # Since prometheus multiprocessing mode does not support Info, emulate
+    #     # info here with a gauge.
+    #     info_gauge = prometheus_client.Gauge(
+    #         name=name,
+    #         documentation=documentation,
+    #         labelnames=metrics_info.keys()).labels(**metrics_info)
+    #     info_gauge.set(1)
+
+    def record(self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]):
+        """Log to prometheus."""
+
+        self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
+        self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
+
+        self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
+
+        if iteration_stats is None:
+            return
+
+        self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
+        self.counter_generation_tokens.inc(iteration_stats.num_generation_tokens)
+        # self.histogram_iteration_tokens.observe(
+        #     iteration_stats.num_prompt_tokens + \
+        #     iteration_stats.num_generation_tokens)
+
+        for ttft in iteration_stats.time_to_first_tokens_iter:
+            self.histogram_time_to_first_token.observe(ttft)
+
+        for tpot in iteration_stats.time_per_output_tokens_iter:
+            self.histogram_time_per_output_token.observe(tpot)
+
+        for finished_request in iteration_stats.finished_requests:
+            # self.counter_request_success[finished_request.finish_reason].inc()
+            self.histogram_e2e_time_request.observe(finished_request.e2e_latency)
+            self.histogram_queue_time_request.observe(finished_request.queued_time)
+            self.histogram_prefill_time_request.observe(finished_request.prefill_time)
+            self.histogram_inference_time_request.observe(finished_request.inference_time)
+            self.histogram_decode_time_request.observe(finished_request.decode_time)
+            self.histogram_num_prompt_tokens_request.observe(finished_request.num_prompt_tokens)
+            self.histogram_num_generation_tokens_request.observe(finished_request.num_generation_tokens)
+            # self.histogram_max_tokens_request.observe(
+            #     finished_request.max_tokens_param)
+
+
+def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
+    """Builds a list of buckets with increasing powers of 10 multiplied by
+    mantissa values until the value exceeds the specified maximum."""
+    exponent = 0
+    buckets: list[int] = []
+    while True:
+        for m in mantissa_lst:
+            value = m * 10**exponent
+            if value <= max_value:
+                buckets.append(value)
+            else:
+                return buckets
+        exponent += 1
+
+
+def build_1_2_5_buckets(max_value: int) -> list[int]:
+    """
+    Example:
+    >>> build_1_2_5_buckets(100)
+    [1, 2, 5, 10, 20, 50, 100]
+    """
+    return build_buckets([1, 2, 5], max_value)
+
+
+def setup_loggers(enable_metrics: bool, engine_num: int):
+    if not enable_metrics:
+        return []
+
+    stat_loggers: list[list[StatLoggerBase]] = []
+    # independent set for each DP rank
+    for i in range(engine_num):
+        stat_loggers.append([LoggingStatLogger(), PrometheusStatLogger()])
+
+    return stat_loggers
diff --git a/lmdeploy/metrics/stats.py b/lmdeploy/metrics/stats.py
new file mode 100644
index 0000000000..2f8761207f
--- /dev/null
+++ b/lmdeploy/metrics/stats.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/stats.py
+
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from lmdeploy.messages import EngineCoreEvent, EngineOutput, ResponseType
+
+
+@dataclass
+class SchedulerStats:
+    """Stats associated with the scheduler."""
+
+    num_running_reqs: int = 0
+    num_waiting_reqs: int = 0
+
+    gpu_cache_usage: float = 0.0
+
+
+@dataclass
+class RequestStateStats:
+    """Stats that need to be tracked across delta updates."""
+
+    num_generation_tokens: int = 0
+
+    # This is an engine frontend timestamp (wall-clock)
+    arrival_time: float = 0.0
+
+    # These are engine core timestamps (monotonic)
+    queued_ts: float = 0.0
+    scheduled_ts: float = 0.0
+    first_token_ts: float = 0.0
+    last_token_ts: float = 0.0
+
+
+@dataclass
+class FinishedRequestStats:
+    """Stats associated with a finished request."""
+
+    finish_reason: 'ResponseType'
+    e2e_latency: float = 0.0
+    num_prompt_tokens: int = 0
+    num_generation_tokens: int = 0
+    # max_tokens_param: Optional[int] = None
+    queued_time: float = 0.0
+    prefill_time: float = 0.0
+    inference_time: float = 0.0
+    decode_time: float = 0.0
+
+
+class IterationStats:
+    """Stats associated with a single set of EngineCoreOutputs."""
+
+    def __init__(self):
+        self.iteration_timestamp = time.time()
+        self.num_generation_tokens = 0
+        self.num_prompt_tokens = 0
+        self.num_preempted_reqs = 0
+        self.finished_requests: list[FinishedRequestStats] = []
+        # self.max_num_generation_tokens_iter: list[int] = []
+        # self.n_params_iter: list[int] = []
+        self.time_to_first_tokens_iter: list[float] = []
+        self.time_per_output_tokens_iter: list[float] = []
+
+    def _time_since(self, start: float) -> float:
+        """Calculate an interval relative to this iteration's timestamp."""
+        return self.iteration_timestamp - start
+
+    def update_from_output(self, output: 'EngineOutput', engine_core_timestamp: float, is_prefilling: bool,
+                           prompt_len: int, req_stats: RequestStateStats):
+        num_new_generation_tokens = len(output.token_ids)
+
+        self.num_generation_tokens += num_new_generation_tokens
+        if is_prefilling:
+            assert num_new_generation_tokens > 0
+            self.num_prompt_tokens += prompt_len
+
+            first_token_latency = self._time_since(req_stats.arrival_time)
+            self.time_to_first_tokens_iter.append(first_token_latency)
+
+        req_stats.num_generation_tokens += num_new_generation_tokens
+
+        # Process request-level engine core events
+        if output.events is not None:
+            self.update_from_events(output.events, req_stats)
+
+        # Process the batch-level "new tokens" engine core event
+        if is_prefilling:
+            req_stats.first_token_ts = engine_core_timestamp
+        else:
+            tpot = engine_core_timestamp - req_stats.last_token_ts
+            self.time_per_output_tokens_iter.append(tpot)
+
+        req_stats.last_token_ts = engine_core_timestamp
+
+    def update_from_events(self, events: list['EngineCoreEvent'], req_stats: RequestStateStats):
+        # Avoid circular dependency
+        from lmdeploy.messages import EngineCoreEventType
+
+        for event in events:
+            if event.type == EngineCoreEventType.QUEUED:
+                req_stats.queued_ts = event.timestamp
+            elif event.type == EngineCoreEventType.SCHEDULED:
+                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
+                    req_stats.scheduled_ts = event.timestamp
+            # FIXME: deal with preempted case
+            # elif event.type == EngineCoreEventType.PREEMPTED:
+            #     self.num_preempted_reqs += 1
+
+    def update_from_finished_request(
+            self,
+            finish_reason: 'ResponseType',
+            num_prompt_tokens: int,
+            # max_tokens_param: Optional[int],
+            req_stats: RequestStateStats):
+
+        e2e_latency = self._time_since(req_stats.arrival_time)
+
+        # Queued interval is from first QUEUED event to first SCHEDULED
+        queued_time = req_stats.scheduled_ts - req_stats.queued_ts
+
+        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
+        # Any preemptions during prefill are included in the interval
+        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts
+
+        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
+        # Any preemptions during decode are included
+        decode_time = req_stats.last_token_ts - req_stats.first_token_ts
+
+        # Inference interval is from first SCHEDULED to last NEW_TOKEN
+        # Any preemptions during prefill or decode are included
+        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
+
+        finished_req = \
+            FinishedRequestStats(finish_reason=finish_reason,
+                                 e2e_latency=e2e_latency,
+                                 num_prompt_tokens=num_prompt_tokens,
+                                 num_generation_tokens=req_stats.num_generation_tokens,
+                                 # max_tokens_param=max_tokens_param,
+                                 queued_time=queued_time,
+                                 prefill_time=prefill_time,
+                                 inference_time=inference_time,
+                                 decode_time=decode_time)
+        self.finished_requests.append(finished_req)
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index dea2a18770..d5724bb692 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -66,6 +66,8 @@ class SchedulerConfig:
     prefill_interval: int = 16
     max_active_adapters: int = 64

+    enable_metrics: bool = False
+

 @dataclass
 class CacheConfig:
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index caa314ec60..28502b8740 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -3,13 +3,15 @@
 import copy
 import logging
 import os
-from dataclasses import dataclass
+import time
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Tuple

 import numpy as np
 import torch

-from lmdeploy.messages import PytorchEngineConfig, ResponseType
+from lmdeploy.messages import EngineCoreEvent, PytorchEngineConfig, ResponseType
+from lmdeploy.metrics.stats import IterationStats, SchedulerStats
 from lmdeploy.pytorch.disagg.config import EngineRole
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
 from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer
@@ -46,6 +48,9 @@ class InferOutput:
     # when Prefill Engine is Done.
     cache_block_ids: List[int] = None

+    # events for logging
+    events: List[EngineCoreEvent] = field(default_factory=list)
+

 def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
     """tensorlize block_offsets."""
@@ -59,7 +64,8 @@ def _build_scheduler_config(engine_config: PytorchEngineConfig):
     """build scheduler config."""
     scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size,
                                        max_session_len=engine_config.session_len,
-                                       prefill_interval=engine_config.prefill_interval)
+                                       prefill_interval=engine_config.prefill_interval,
+                                       enable_metrics=engine_config.enable_metrics)
     return scheduler_config

@@ -427,13 +433,24 @@ def _start_loop(self):
         """start loop."""
         return self.req_manager.start_loop(self.async_loop)

-    def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, err_msg: str = ''):
+    def _response(self,
+                  resp: Response,
+                  resp_type: ResponseType,
+                  scheduler_stats: SchedulerStats = None,
+                  iteration_stats: IterationStats = None,
+                  events: List[EngineCoreEvent] = None,
+                  data: Any = None,
+                  err_msg: str = ''):
         """response."""
         if resp.type == ResponseType.FINISH:
             return
         resp.type = resp_type
         resp.data = data
         resp.err_msg = err_msg
+
+        resp.scheduler_stats = scheduler_stats
+        resp.iteration_stats = iteration_stats
+        resp.events = events
         self.req_manager.response(resp)

     def _get_max_session_len(self):
@@ -498,6 +515,10 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
     def _on_add_message(self, reqs: List[Request], **kwargs):
         """on add message callback."""
         for req in reqs:
+            # record arrival time, when requests arrive engine side
+            if req.arrival_time is None:
+                req.arrival_time = time.time()
+
             req_data = req.data
             if req_data.get('input_multimodals', None) is None:
                 continue
@@ -769,7 +790,8 @@ def _make_infer_outputs(self, next_token_ids: torch.LongTensor, running: SeqList
                               resp=msg.resp,
                               finish=finish,
                               token_ids=token_ids,
-                              cache_block_ids=cache_block_ids)
+                              cache_block_ids=cache_block_ids,
+                              events=msg.events)
             outputs[session_id] = out

             if msg.return_logits:
@@ -926,8 +948,16 @@ def __log_resps(outputs: List[InferOutput]):
         def __send_resp(out: InferOutput):
             """send response."""
             resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
+
+            scheduler_stats = SchedulerStats(num_running_reqs=self.scheduler.num_running(),
+                                             num_waiting_reqs=self.scheduler.num_waiting(),
+                                             gpu_cache_usage=self.scheduler.usage)
+
             self._response(out.resp,
+                           resp_type,
+                           scheduler_stats=scheduler_stats,
+                           iteration_stats=None,
+                           events=out.events,
                            data=dict(token_ids=out.token_ids, logits=out.logits, cache_block_ids=out.cache_block_ids))

         def __send_resps(step_outputs: List[InferOutput]):
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index 985d648310..ae8c7127cc 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -125,7 +125,11 @@ async def async_stream_infer(self,
             int: The number of the output tokens.
""" if len(input_ids) > self.max_input_len: - yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [], 0) + yield EngineOutput(status=ResponseType.INPUT_LENGTH_ERROR, + token_ids=[], + num_token=0, + scheduler_stats=None, + iteration_stats=None) return gen_config = gen_config or GenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) @@ -146,24 +150,44 @@ async def async_stream_infer(self, while True: resp = await self.req_sender.async_recv(resp) + print(f'=> finally get resp {resp}') cache_block_ids = resp.data.get('cache_block_ids', None) if resp.type == ResponseType.SUCCESS: token_ids = resp.data['token_ids'].tolist() num_ids = len(token_ids) logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.') - yield EngineOutput(resp.type, token_ids, num_ids, cache_block_ids=cache_block_ids) + yield EngineOutput(status=resp.type, + token_ids=token_ids, + num_token=num_ids, + scheduler_stats=resp.scheduler_stats, + iteration_stats=resp.iteration_stats, + events=resp.events, + cache_block_ids=cache_block_ids) elif resp.type == ResponseType.FINISH: resp_data = resp.data token_ids = resp_data['token_ids'].tolist() logits = resp_data['logits'] num_ids = len(token_ids) logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.') - yield EngineOutput(resp.type, token_ids, num_ids, logits=logits, cache_block_ids=cache_block_ids) + yield EngineOutput(status=resp.type, + token_ids=token_ids, + num_token=num_ids, + logits=logits, + scheduler_stats=resp.scheduler_stats, + iteration_stats=resp.iteration_stats, + events=resp.events, + cache_block_ids=cache_block_ids) break else: logger.debug(f'session[{session_id}] failed.') - yield EngineOutput(resp.type, [], 0) + yield EngineOutput(status=resp.type, + token_ids=[], + num_token=0, + scheduler_stats=resp.scheduler_stats, + iteration_stats=resp.iteration_stats, + events=resp.events) + # FIXME, should be None for scheduler_stats etc ? 
                 break

     async def async_infer(self,
diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py
index c1c81912cc..0280442dbb 100644
--- a/lmdeploy/pytorch/engine/request.py
+++ b/lmdeploy/pytorch/engine/request.py
@@ -5,7 +5,8 @@
 from dataclasses import dataclass, field
 from typing import Any, Awaitable, Callable, Dict, List

-from lmdeploy.messages import ResponseType
+from lmdeploy.messages import EngineCoreEvent, ResponseType
+from lmdeploy.metrics.stats import IterationStats, SchedulerStats
 from lmdeploy.utils import get_logger

 logger = get_logger('lmdeploy')
@@ -32,6 +33,11 @@ class Response:
     data: Any = None
     err_msg: str = ''

+    # for logging
+    scheduler_stats: SchedulerStats = None
+    iteration_stats: IterationStats = None
+    events: List[EngineCoreEvent] = None
+

 @dataclass
 class Request:
@@ -42,6 +48,9 @@ class Request:
     data: Any = None
     resp: Response = None

+    # engine-side request arrival time
+    arrival_time: float = None
+

 ReqList = List[Request]

diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 831f7b3139..04a0db5a39 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -7,7 +7,7 @@
 import numpy as np
 from torch import Tensor

-from lmdeploy.messages import GenerationConfig, LogitsProcessor
+from lmdeploy.messages import EngineCoreEvent, EngineCoreEventType, GenerationConfig, LogitsProcessor
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
 from lmdeploy.pytorch.disagg.request import MigrationRequest
 from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
@@ -468,6 +468,9 @@ class SchedulerSequence:
     preserve_cache: bool = False
     migration_inputs: Optional[MigrationExecutionBatch] = None

+    # events for logging
+    events: List[EngineCoreEvent] = field(default_factory=list)
+
     def __post_init__(self):
         """post init."""
         self._num_history_ids: int = 0
@@ -657,3 +660,11 @@ def set_step(self, step: int):
         if self.history_multimodals is not None:
             self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids)
         self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)
+
+    def record_event(
+        self,
+        event_type: EngineCoreEventType,
+        timestamp: Optional[float] = None,
+    ) -> None:
+        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))
diff --git a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py
index d8116e08f3..7f6b9ee693 100644
--- a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py
+++ b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py
@@ -283,6 +283,13 @@ def get_num_free_cpu_blocks(self) -> int:
         """Get number of free cpu blocks."""
         return self.allocator.get_phy_allocator('cpu').get_num_free_blocks()

+    def get_usage(self) -> float:
+        """Get the KV cache usage (between 0.0 and 1.0)."""
+        return 1.0 - (self.get_num_free_gpu_blocks() / self.num_gpu_blocks)
+
     def on_device(self, msg: SchedulerSequence, device: str):
         allocator = self.allocator
         logical_blocks = msg.logical_blocks
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index bc62c14918..d168cf057e 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # modify from: https://github.com/vllm-project/vllm
+import time
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Dict, List, Tuple

+from lmdeploy.messages import EngineCoreEventType
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
 from lmdeploy.utils import get_logger, logging_timer
@@ -52,6 +54,14 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig)

         self.seq_manager = SequenceManager()

+    @property
+    def usage(self) -> float:
+        """Get the KV cache usage (between 0.0 and 1.0)."""
+        return self.block_manager.get_usage()
+
     @property
     def waiting(self):
         """get waiting sequence."""
@@ -136,6 +146,9 @@ def add_sequence(self, seq: SchedulerSequence):
         # push message to waiting queue
         self._set_message_status(seq, MessageStatus.WAITING)

+        if self.scheduler_config.enable_metrics:
+            seq.record_event(EngineCoreEventType.QUEUED)
+
     @logging_timer('ScheduleMigration', logger)
     def _schedule_migration(self):

@@ -227,6 +240,9 @@ def _reorder_waiting():
         waiting = _reorder_waiting()

         while len(waiting) > 0 and len(running) < max_batches:
+            # for logging
+            scheduled_timestamp = time.monotonic()
+
             seq = waiting.pop(0)
             if (len(running) > 0 and token_count + seq.num_token_ids > self.cache_config.max_prefill_token_num):
@@ -241,12 +257,19 @@ def _reorder_waiting():
             self.block_manager.allocate(seq)
             _to_running(seq)

+            if self.scheduler_config.enable_metrics:
+                seq.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
+
         return running, swap_in_map, swap_out_map, copy_map

     @logging_timer('ScheduleDecoding', logger)
     def _schedule_decoding(self, prealloc_size: int = 0):
         """schedule decoding."""
+        # for logging
+        # FIXME, record request scheduled event
+        # scheduled_timestamp = time.monotonic()
+
         running = self.running
         assert len(running) != 0
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 19dfc0617f..028f29e57a 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -20,7 +20,10 @@
 from lmdeploy import Tokenizer
 from lmdeploy.archs import get_model_arch
 from lmdeploy.logger import RequestLogger
-from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig
+from lmdeploy.messages import (EngineOutput, GenerationConfig, PytorchEngineConfig, RequestState, Response,
+                               ResponseType, TurbomindEngineConfig)
+from lmdeploy.metrics.loggers import StatLoggerBase, setup_loggers
+from lmdeploy.metrics.stats import IterationStats, SchedulerStats
 from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model
 from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 from lmdeploy.serve.utils import LogitsMixin
@@ -250,6 +253,7 @@ class AsyncEngine(LogitsMixin):
             Default to None.
         max_log_len (int): Max number of prompt characters or prompt tokens
             being printed in log. Default: Unlimited
+        enable_metrics (bool): Whether to log stats to CLI / Prometheus.
     """

     def __init__(self,
@@ -259,7 +263,16 @@ def __init__(self,
                  backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None,
                  max_log_len: int = None,
+                 enable_metrics: Optional[bool] = True,
                  **kwargs) -> None:
+
+        # setup stat loggers
+        backend_config.enable_metrics = enable_metrics
+        self.enable_metrics = enable_metrics
+        self.stat_loggers: List[List[StatLoggerBase]] = setup_loggers(enable_metrics=enable_metrics,
+                                                                      engine_num=backend_config.dp)
+
+        # setup chat template
         logger.info(f'input backend={backend}, backend_config={backend_config}')
         logger.info(f'input chat_template_config={chat_template_config}')
@@ -377,6 +390,11 @@ def __call__(self,
                             use_tqdm=use_tqdm,
                             **kwargs)

+    async def do_log_stats(self) -> None:
+        for each_engine_loggers in self.stat_loggers:
+            for stat_logger in each_engine_loggers:
+                stat_logger.log()
+
     async def stop_session(self, session_id: int):
         """Stop a session by a session_id."""
         logger.info(f'stop session {session_id}')
@@ -601,6 +619,40 @@ async def safe_run(self, inst, session_id, **kwargs):
         finally:
             await generator.aclose()

+    def _update_stats_from_output(self, req_state: RequestState, engine_core_output: EngineOutput,
+                                  iteration_stats: Optional[IterationStats]):
+        if iteration_stats is None:
+            return
+
+        assert req_state.stats is not None
+        iteration_stats.update_from_output(engine_core_output, engine_core_output.timestamp, req_state.is_prefilling,
+                                           req_state.prompt_len, req_state.stats)
+
+    def _update_stats_from_finished(self, req_state: RequestState, finish_reason: Optional[ResponseType],
+                                    iteration_stats: Optional[IterationStats]):
+
+        if iteration_stats is None:
+            return
+
+        assert finish_reason is not None
+        assert req_state.stats is not None
+        iteration_stats.update_from_finished_request(
+            finish_reason=finish_reason,
+            num_prompt_tokens=req_state.prompt_len,
+            # max_tokens_param=req_state.max_tokens_param,
+            req_stats=req_state.stats)
+
+    @staticmethod
+    def _record_stats(
+        stat_loggers: List[List[StatLoggerBase]],
+        scheduler_stats: SchedulerStats,
+        iteration_stats: Optional[IterationStats],
+    ):
+        for each_engine_loggers in stat_loggers:
+            for stat_logger in each_engine_loggers:
+                stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats)
+
     async def generate(
         self,
         messages,
@@ -729,7 +781,10 @@ def is_error(status):
                                   step=history_len) as gen:
                 prev_len = 0
                 hit_stop_token = 0
+                iteration_stats = IterationStats() if self.enable_metrics else None
                 async for outputs in gen:
                     # decode res
                     if is_error(outputs.status):
                         break
@@ -782,9 +837,35 @@ def is_error(status):
                         if hit_stop_token:
                             out.logits = out.logits[:-hit_stop_token]

+                    # update stats from per iteration engine outputs (i.e. step output)
+                    req_state = RequestState(
+                        arrival_time=outputs.timestamp,
+                        prompt_len=input_len,
+                        is_prefilling=(output_len == 0),  # FIXME: is this logic correct ?
+                        enable_metrics=self.enable_metrics)
+                    self._update_stats_from_output(req_state=req_state,
+                                                   engine_core_output=outputs,
+                                                   iteration_stats=iteration_stats)

                     yield out
                 # end of generator loop

+                # update stats from per finished requests engine outputs
+                self._update_stats_from_finished(
+                    req_state=req_state,
+                    finish_reason=outputs.status,  # ResponseType
+                    iteration_stats=iteration_stats)
+
+                # perform logging
+                if self.stat_loggers:
+                    assert outputs.scheduler_stats is not None
+
+                    AsyncEngine._record_stats(
+                        self.stat_loggers,
+                        scheduler_stats=outputs.scheduler_stats,
+                        iteration_stats=iteration_stats,
+                    )
+
                 if not is_error(outputs.status):
                     finish_reason = 'length' \
                         if gen_len >= gen_config.max_new_tokens else 'stop'
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 78c607850a..c340d9efa6 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -9,6 +9,7 @@
 from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Literal, Optional, Union

+import prometheus_client
 import uvicorn
 from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -1137,6 +1138,7 @@ def serve(model_path: str,
           max_log_len: int = None,
           disable_fastapi_docs: bool = False,
           max_concurrent_requests: Optional[int] = None,
+          enable_metrics: Optional[bool] = True,
           reasoning_parser: Optional[str] = None,
           tool_call_parser: Optional[str] = None,
           allow_terminate_by_client: bool = False,
@@ -1189,6 +1191,7 @@ def serve(model_path: str,
             process the engine’s tasks once the maximum number of concurrent
             requests is reached, regardless of any additional requests sent by
             clients concurrently during that time. Default to None.
+        enable_metrics (bool): Whether to log stats to CLI / Prometheus.
         reasoning_parser (str): The reasoning parser name.
         tool_call_parser (str): The tool call parser name.
         allow_terminate_by_client (bool): Allow request from client to terminate server.
@@ -1197,30 +1200,6 @@ def serve(model_path: str,
     os.environ['TM_LOG_LEVEL'] = log_level
     logger.setLevel(log_level)

-    if disable_fastapi_docs:
-        app = FastAPI(
-            docs_url=None,
-            redoc_url=None,
-            openapi_url=None,
-        )
-    else:
-        app = FastAPI(docs_url='/')
-
-    app.include_router(router)
-
-    if allow_origins:
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=allow_origins,
-            allow_credentials=allow_credentials,
-            allow_methods=allow_methods,
-            allow_headers=allow_headers,
-        )
-
-    # Set the maximum number of concurrent requests
-    if max_concurrent_requests is not None:
-        app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)
-
     VariableInterface.allow_terminate_by_client = allow_terminate_by_client
     if api_keys is not None:
         if isinstance(api_keys, str):
@@ -1240,10 +1219,60 @@ def serve(model_path: str,
                                  backend_config=backend_config,
                                  chat_template_config=chat_template_config,
                                  max_log_len=max_log_len,
+                                 enable_metrics=enable_metrics,
                                  **kwargs)
     # set reasoning parser and tool parser
     set_parsers(reasoning_parser, tool_call_parser)

+    _running_tasks: set[asyncio.Task] = set()
+
+    async def lifespan(app: FastAPI):
+        async_engine = VariableInterface.async_engine
+        task = None
+        try:
+            if enable_metrics:
+                log_interval = 1.  # FIXME: change this
+
+                async def _force_log():
+                    while True:
+                        await asyncio.sleep(log_interval)
+
+                        await async_engine.do_log_stats()
+
+                task = asyncio.create_task(_force_log())
+                _running_tasks.add(task)
+                task.add_done_callback(_running_tasks.remove)
+
+            yield
+        finally:
+            if task:
+                task.cancel()
+
+    if disable_fastapi_docs:
+        app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None, lifespan=lifespan)
+    else:
+        app = FastAPI(docs_url='/', lifespan=lifespan)
+
+    app.include_router(router)
+
+    if enable_metrics:
+        # add prometheus asgi middleware to route '/metrics' requests
+        metrics_app = prometheus_client.make_asgi_app()
+        app.mount('/metrics', metrics_app)
+
+    if allow_origins:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=allow_origins,
+            allow_credentials=allow_credentials,
+            allow_methods=allow_methods,
+            allow_headers=allow_headers,
+        )
+
+    # Set the maximum number of concurrent requests
+    if max_concurrent_requests is not None:
+        app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)
+
     if proxy_url is not None:
         VariableInterface.proxy_url = proxy_url
         VariableInterface.api_server_url = f'{http_or_https}://{server_name}:{server_port}'  # noqa
diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
index 557a8ef2aa..d158919072 100644
--- a/requirements/runtime_cuda.txt
+++ b/requirements/runtime_cuda.txt
@@ -9,6 +9,7 @@ outlines
 partial_json_parser
 peft<=0.14.0
 pillow
+prometheus_client >= 0.18.0
 protobuf
 pydantic>2.0.0
 pynvml