diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE.md similarity index 100% rename from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md rename to .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/workflows/build_and_deploy.yml b/.github/workflows/build_and_deploy.yml index 66a2a915a..04df626bc 100644 --- a/.github/workflows/build_and_deploy.yml +++ b/.github/workflows/build_and_deploy.yml @@ -4,6 +4,7 @@ on: push: branches: - main + - 624-add-llm-oriented-metrics-to-grafana-dashboard release: types: - published diff --git a/api/app.py b/api/app.py index 0d2bcb9d9..afd4c310a 100644 --- a/api/app.py +++ b/api/app.py @@ -4,7 +4,6 @@ from fastapi import FastAPI, Request import sentry_sdk from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import JSONResponse from api.endpoints.monitoring import setup_prometheus from api.schemas.core.context import RequestContext @@ -89,7 +88,3 @@ def _setup_monitoring(app: FastAPI, configuration: Configuration) -> None: if configuration.settings.monitoring_prometheus_enabled: setup_prometheus(app, include_in_schema=include_in_schema) - - @app.get(path="/health", tags=[RouterName.MONITORING.title()], include_in_schema=include_in_schema) - def health() -> JSONResponse: - return JSONResponse(content={"status": "ok"}, status_code=200) diff --git a/api/clients/model/_basemodelprovider.py b/api/clients/model/_basemodelprovider.py index 9f43e0a78..e65175151 100644 --- a/api/clients/model/_basemodelprovider.py +++ b/api/clients/model/_basemodelprovider.py @@ -118,7 +118,6 @@ def _get_usage(self, request_content: RequestContent, response_data: dict | list tokenizer = getattr(global_context, "tokenizer", None) if tokenizer and request_content.endpoint in tokenizer.USAGE_ENDPOINTS: try: - completion_tokens = 0 prompt_tokens = tokenizer.get_prompt_tokens(endpoint=request_content.endpoint, body=request_content.json) completion_tokens = tokenizer.get_completion_tokens(endpoint=request_content.endpoint, response_data=response_data) total_tokens = prompt_tokens + completion_tokens @@ -154,7 +153,7 @@ def _format_request(self, request_content: RequestContent) -> RequestContent: Format a request to a provider model. This method can be overridden by a subclass to add additional headers or parameters. This method format the requested endpoint thanks the ENDPOINT_TABLE attribute. Args: - content(RequestContent): The request content to format. + request_content(RequestContent): The request content to format. Returns: content(RequestContent): The formatted request content. @@ -226,7 +225,8 @@ def _format_response(self, request_content: RequestContent, response: httpx.Resp return response - async def _ensure_timeseries_exists(self, redis_client: AsyncRedis, key: str) -> None: + @staticmethod + async def _ensure_timeseries_exists(redis_client: AsyncRedis, key: str) -> None: """ Ensure a time series exists with proper retention configuration. diff --git a/api/endpoints/health.py b/api/endpoints/health.py new file mode 100644 index 000000000..9ea6196d8 --- /dev/null +++ b/api/endpoints/health.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter +from starlette.responses import JSONResponse + +from api.utils.variables import RouterName + +router = APIRouter(tags=[RouterName.HEALTH.title()]) + + +@router.get(path="/health") +def health() -> JSONResponse: + return JSONResponse(content={"status": "ok"}, status_code=200) diff --git a/api/endpoints/monitoring.py b/api/endpoints/monitoring.py index ed80cbc61..c325509ef 100644 --- a/api/endpoints/monitoring.py +++ b/api/endpoints/monitoring.py @@ -3,16 +3,39 @@ from fastapi import Depends, FastAPI import prometheus_client from prometheus_client import CollectorRegistry, multiprocess -from prometheus_fastapi_instrumentator import Instrumentator +from prometheus_fastapi_instrumentator import Instrumentator, metrics from starlette.responses import Response from api.helpers._accesscontroller import AccessController +from api.helpers._metricsmiddleware import ( + inference_output_tokens_per_second, + inference_requests_duration_seconds, + inference_requests_total, + inference_tokens_total, + inference_ttft_milliseconds, +) from api.schemas.admin.roles import PermissionType from api.utils.variables import RouterName -def setup_prometheus(app: FastAPI, include_in_schema: bool = True) -> None: - app.instrumentator = Instrumentator().instrument(app=app) +def setup_prometheus(app: FastAPI, metric_namespace: str = "ogl", include_in_schema: bool = True) -> None: + app.instrumentator = ( + Instrumentator() + .instrument( + app=app, + ) + .add( + metrics.default( + metric_namespace=metric_namespace, + ), + inference_output_tokens_per_second(metric_namespace=metric_namespace), + inference_requests_total(metric_namespace=metric_namespace), + inference_requests_duration_seconds(metric_namespace=metric_namespace), + inference_ttft_milliseconds(metric_namespace=metric_namespace), + inference_tokens_total(metric_namespace=metric_namespace), + ) + .expose(app) + ) @app.get( path="/metrics", @@ -20,7 +43,7 @@ def setup_prometheus(app: FastAPI, include_in_schema: bool = True) -> None: dependencies=[Depends(dependency=AccessController(permissions=[PermissionType.READ_METRIC]))], include_in_schema=include_in_schema, ) - def metrics() -> Response: + def get_metrics() -> Response: if os.environ.get("PROMETHEUS_MULTIPROC_DIR"): registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) diff --git a/api/helpers/_metricsmiddleware.py b/api/helpers/_metricsmiddleware.py new file mode 100644 index 000000000..0e619fbed --- /dev/null +++ b/api/helpers/_metricsmiddleware.py @@ -0,0 +1,214 @@ +from collections.abc import Callable + +from prometheus_client import Counter, Histogram +from prometheus_fastapi_instrumentator.metrics import Info + +from api.utils.context import request_context + + +def _build_metric_name(namespace: str, name: str) -> str: + return f"{namespace}_{name}" if namespace else name + + +def inference_requests_total(metric_namespace: str = "") -> Callable[[Info], None]: + metric_name = _build_metric_name(metric_namespace, "inference_requests_total") + metric = Counter( + metric_name, + "Total number of LLM requests.", + labelnames=("endpoint", "model", "status_code"), + ) + + def instrumentation(info: Info) -> None: + try: + context = request_context.get() + model = context.router_name + endpoint = context.endpoint + if model and endpoint: + metric.labels( + endpoint=endpoint, + model=model, + status_code=info.modified_status, + ).inc() + except Exception: + pass + + return instrumentation + + +def inference_requests_duration_seconds(metric_namespace: str = "") -> Callable[[Info], None]: + metric_name = _build_metric_name(metric_namespace, "inference_requests_duration_seconds") + metric = Histogram( + metric_name, + "Duration of LLM requests in seconds.", + labelnames=("endpoint", "model", "status_code"), + buckets=( + 0.05, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1, + 1.5, + 2, + 2.5, + 3, + 3.5, + 4, + 4.5, + 5, + 6, + 7, + 8, + 9, + 10, + 15, + 20, + 25, + 30, + 45, + 60, + 75, + 90, + 105, + 120, + 150, + 180, + 210, + 240, + 270, + 300, + ), + ) + + def instrumentation(info: Info) -> None: + try: + context = request_context.get() + model = context.router_name + endpoint = context.endpoint + latency = context.latency + if model and endpoint and latency is not None: + metric.labels( + endpoint=endpoint, + model=model, + status_code=info.modified_status, + ).observe(latency / 1000) + except Exception: + pass + + return instrumentation + + +def inference_ttft_milliseconds(metric_namespace: str = "") -> Callable[[Info], None]: + metric_name = _build_metric_name(metric_namespace, "inference_ttft_milliseconds") + metric = Histogram( + metric_name, + "Time to first token for streaming LLM responses in milliseconds.", + labelnames=("endpoint", "model", "status_code"), + buckets=( + 5, + 10, + 20, + 30, + 50, + 75, + 100, + 150, + 200, + 300, + 500, + 750, + 1000, + 1500, + 2000, + 3000, + 5000, + 7500, + 10000, + 15000, + 20000, + 25000, + 30000, + 45000, + 60000, + 75000, + 90000, + 105000, + 120000, + 135000, + 150000, + 165000, + 180000, + 210000, + 240000, + 270000, + 300000, + ), + ) + + def instrumentation(info: Info) -> None: + try: + context = request_context.get() + model = context.router_name + endpoint = context.endpoint + ttft = context.ttft + if model and endpoint and ttft is not None: + metric.labels( + endpoint=endpoint, + model=model, + status_code=info.modified_status, + ).observe(ttft) + except Exception: + pass + + return instrumentation + + +def inference_output_tokens_per_second(metric_namespace: str = "") -> Callable[[Info], None]: + metric_name = _build_metric_name(metric_namespace, "inference_output_tokens_per_second") + metric = Histogram( + metric_name, + "Output generation speed in tokens per second (completion tokens / request duration, TTFT included).", + labelnames=("endpoint", "model"), + buckets=(5, 10, 20, 30, 50, 75, 85, 90, 95, 100, 105, 110, 115, 125, 150, 175, 200, 250, 300, 400, 500, 750, 1000), + ) + + def instrumentation(info: Info) -> None: + try: + context = request_context.get() + model = context.router_name + endpoint = context.endpoint + usage = context.usage + latency = context.latency + if model and endpoint and usage and latency and usage.completion_tokens: + metric.labels(endpoint=endpoint, model=model).observe(usage.completion_tokens / (latency / 1000)) + except Exception: + pass + + return instrumentation + + +def inference_tokens_total(metric_namespace: str = "") -> Callable[[Info], None]: + metric_name = _build_metric_name(metric_namespace, "inference_tokens_total") + metric = Counter( + metric_name, + "Total number of tokens consumed (prompt and completion).", + labelnames=("endpoint", "model", "type"), + ) + + def instrumentation(info: Info) -> None: + try: + context = request_context.get() + model = context.router_name + endpoint = context.endpoint + usage = context.usage + if model and endpoint and usage is not None: + if usage.prompt_tokens: + metric.labels(endpoint=endpoint, model=model, type="prompt").inc(usage.prompt_tokens) + if usage.completion_tokens: + metric.labels(endpoint=endpoint, model=model, type="completion").inc(usage.completion_tokens) + except Exception: + pass + + return instrumentation diff --git a/api/helpers/load_balancing/_leastbusyloadbalancingstrategy.py b/api/helpers/load_balancing/_leastbusyloadbalancingstrategy.py index 28896fe76..f03241895 100644 --- a/api/helpers/load_balancing/_leastbusyloadbalancingstrategy.py +++ b/api/helpers/load_balancing/_leastbusyloadbalancingstrategy.py @@ -20,7 +20,6 @@ def __init__(self, redis_client: AsyncRedis | Redis, load_balancing_metric: Metr Get a provider to handle the request based on the specified routing strategy. Args: - candidates (list[int]): The list of provider candidates (provider IDs) to choose from redis_client (AsyncRedis): Redis client instance, required for least busy strategy load_balancing_metric (Metric): The type of metric to use for performance evaluation diff --git a/api/utils/hooks_decorator.py b/api/utils/hooks_decorator.py index 32de46636..214f118a1 100644 --- a/api/utils/hooks_decorator.py +++ b/api/utils/hooks_decorator.py @@ -189,7 +189,7 @@ async def update_budget(usage: Usage): # Update the budget update_stmt = update(User).where(User.id == user_id).values(budget=new_budget, updated=func.now()).returning(User.budget) - result = await postgres_session.execute(update_stmt) + await postgres_session.execute(update_stmt) except Exception as e: logger.exception(f"Failed to update budget for user {user_id}: {e}") diff --git a/api/utils/variables.py b/api/utils/variables.py index 81617f416..0457b0516 100644 --- a/api/utils/variables.py +++ b/api/utils/variables.py @@ -19,6 +19,7 @@ class RouterName(StrEnum): COLLECTIONS = ("collections", "api.endpoints.collections") DOCUMENTS = ("documents", "api.endpoints.documents") EMBEDDINGS = ("embeddings", "api.endpoints.embeddings") + HEALTH = ("health", "api.endpoints.health") ME = ("me", "api.endpoints.me") MODELS = ("models", "api.infrastructure.fastapi.endpoints.models") MONITORING = ("monitoring", None) diff --git a/pyproject.toml b/pyproject.toml index c1323ed50..af7e74308 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "opengatellm" -version = "0.4.0post1" +version = "0.4.1" description = "OpenGateLLM project" requires-python = ">=3.12" license = { text = "MIT" }