Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ profile.json
profile.html
.vscode/settings.json
*.jsonl
coverage.xml
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ test-verbose: #? run the tests using pytest-xdist with DEBUG logging.
$(activate_venv) && pytest -n auto -v -s --log-cli-level DEBUG

# Run the test suite once and emit html, xml and terminal coverage reports.
# Extra pytest arguments can be supplied through $(args).
coverage: #? run the tests and generate an html coverage report.
	$(activate_venv) && pytest -n auto --cov=aiperf --cov-branch --cov-report=html --cov-report=xml --cov-report=term $(args)

install: #? install the project in editable mode.
$(activate_venv) && uv pip install -e ".[dev]" $(args)
Expand Down
1 change: 1 addition & 0 deletions aiperf/common/config/dev_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def print_developer_mode_warning() -> None:
title_align="left",
)
console.print(panel)
console.file.flush()


if AIPERF_DEV_MODE:
Expand Down
41 changes: 37 additions & 4 deletions aiperf/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,32 @@ def raw_str(self) -> str:

def __str__(self) -> str:
    """Return the exception message as produced by the parent class.

    No class-name prefix is added here; consumers that need the exception
    type are expected to read it from the exception object itself.
    """
    # Only one return can ever execute; the message-only form is kept so the
    # type is not repeated when it is displayed separately from the message.
    return super().__str__()


class AIPerfMultiError(AIPerfError):
    """Exception raised when running multiple tasks and one or more fail.

    Attributes:
        exceptions: The individual exceptions that were collected.
    """

    def __init__(self, message: str | None, exceptions: list[Exception]) -> None:
        """Build one combined message from ``message`` and the collected errors.

        Args:
            message: Optional summary placed in front of the joined error
                strings. When it is ``None`` or empty, the message consists of
                the joined error strings only (no ``"None: "`` prefix).
            exceptions: The exceptions raised by the individual tasks.
        """
        self.exceptions = exceptions

        # AIPerfError instances contribute raw_str(); everything else str().
        err_strings = [
            e.raw_str() if isinstance(e, AIPerfError) else str(e) for e in exceptions
        ]
        joined = ",".join(err_strings)

        # Initialise the base exception exactly once with the final message.
        super().__init__(f"{message}: {joined}" if message else joined)


class HookError(AIPerfError):
    """Wraps an exception that was raised while running a hook.

    Attributes:
        hook_class_name: Name of the class the hook belongs to.
        hook_func_name: Name of the hook function that raised.
        exception: The exception raised by the hook.
    """

    def __init__(self, hook_class_name: str, hook_func_name: str, e: Exception) -> None:
        self.exception = e
        self.hook_func_name = hook_func_name
        self.hook_class_name = hook_class_name

        # Message format: "<class>.<function>: <original error text>"
        hook_location = f"{hook_class_name}.{hook_func_name}"
        super().__init__(f"{hook_location}: {e}")


class ServiceError(AIPerfError):
Expand All @@ -46,6 +60,25 @@ def __init__(
self.service_id = service_id


class LifecycleOperationError(AIPerfError):
    """Raised when a lifecycle operation fails and the lifecycle should stop gracefully.

    Attributes:
        operation: Description of the operation that failed.
        original_exception: The exception that caused the failure, if any.
        lifecycle_id: ID of the lifecycle the operation belonged to.
    """

    def __init__(
        self,
        operation: str,
        original_exception: Exception | None,
        lifecycle_id: str,
    ) -> None:
        self.operation = operation
        self.original_exception = original_exception
        self.lifecycle_id = lifecycle_id

        # Reuse the original error text when there is one; otherwise fall back
        # to a generic message naming the operation.
        if original_exception:
            message = str(original_exception)
        else:
            message = f"Failed to perform operation '{operation}'"
        super().__init__(message)


class CommunicationError(AIPerfError):
"""Generic communication error."""

Expand Down
69 changes: 63 additions & 6 deletions aiperf/common/mixins/aiperf_lifecycle_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@

import asyncio
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from aiperf.common.decorators import implements_protocol
from aiperf.common.enums import LifecycleState
from aiperf.common.exceptions import InvalidStateError
from aiperf.common.exceptions import (
InvalidStateError,
LifecycleOperationError,
)
from aiperf.common.hooks import (
AIPerfHook,
BackgroundTaskParams,
Expand All @@ -17,6 +22,7 @@
)
from aiperf.common.mixins.hooks_mixin import HooksMixin
from aiperf.common.mixins.task_manager_mixin import TaskManagerMixin
from aiperf.common.models import ErrorDetails, ExitErrorInfo
from aiperf.common.protocols import AIPerfLifecycleProtocol


Expand Down Expand Up @@ -49,6 +55,7 @@ def __init__(self, id: str | None = None, **kwargs) -> None:
self._stop_requested_event = asyncio.Event()
self.stopped_event = asyncio.Event() # set on stop or failure
self._children: list[AIPerfLifecycleProtocol] = []
self._exit_errors: list[ExitErrorInfo] = []
if "logger_name" not in kwargs:
kwargs["logger_name"] = self.id
super().__init__(**kwargs)
Expand Down Expand Up @@ -179,10 +186,50 @@ async def start(self) -> None:
self.started_event,
)

@asynccontextmanager
async def try_operation_or_stop(self, operation: str) -> AsyncIterator[None]:
    """Run the wrapped block, recording and re-raising any failure.

    Every exception raised inside the ``async with`` body is logged and
    appended to ``self._exit_errors``. A ``LifecycleOperationError`` is
    re-raised as is; any other ``Exception`` is wrapped in a new
    ``LifecycleOperationError`` chained to the original.

    Args:
        operation: Description of the operation being performed.

    Raises:
        LifecycleOperationError: When the wrapped block raises.
    """
    try:
        yield
    except LifecycleOperationError as lifecycle_error:
        # Already the right exception type: record it against this lifecycle
        # and propagate it unchanged instead of wrapping it a second time.
        self.error(f"Failed to {operation.lower()}: {lifecycle_error}")
        exit_info = ExitErrorInfo(
            error_details=ErrorDetails.from_exception(lifecycle_error),
            operation=operation,
            service_id=self.id,
        )
        self._exit_errors.append(exit_info)
        raise
    except Exception as cause:
        self.error(f"Failed to {operation.lower()}: {cause}")
        wrapped = LifecycleOperationError(
            operation=operation,
            original_exception=cause,
            lifecycle_id=self.id,
        )
        exit_info = ExitErrorInfo.from_lifecycle_operation_error(wrapped)
        self._exit_errors.append(exit_info)
        raise wrapped from cause

async def initialize_and_start(self) -> None:
    """Initialize and then start the lifecycle.

    Convenience method that awaits ``initialize`` and ``start`` in sequence,
    exactly once each. Each call runs inside ``try_operation_or_stop`` so a
    failure is handled by that context manager under the operation names
    "Initialize" and "Start" respectively.
    """
    async with self.try_operation_or_stop("Initialize"):
        await self.initialize()
    async with self.try_operation_or_stop("Start"):
        await self.start()

async def stop(self) -> None:
"""Stop the lifecycle and run the @on_stop hooks.
Expand Down Expand Up @@ -232,10 +279,20 @@ async def _fail(self, e: Exception) -> None:
"""Set the state to FAILED and raise an asyncio.CancelledError.
This is used when the transition from one state to another fails.
"""
self.error(f"Failed for {self}: {e}")
if not isinstance(e, LifecycleOperationError):
# Only add to exit errors if the exception has not already been added
self._exit_errors.append(
ExitErrorInfo(
error_details=ErrorDetails.from_exception(e),
operation=self.state.title(),
service_id=self.id,
)
)
if self.state != LifecycleState.STOPPING:
self.debug(f"Stopping {self} due to failure")
await self.stop()
await self._set_state(LifecycleState.FAILED)
self.exception(f"Failed for {self}: {e}")
self.stop_requested = True
self.stopped_event.set()
raise asyncio.CancelledError(f"Failed for {self}: {e}") from e

def attach_child_lifecycle(self, child: AIPerfLifecycleProtocol) -> None:
Expand Down
9 changes: 3 additions & 6 deletions aiperf/common/mixins/hooks_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aiperf.common import aiperf_logger
from aiperf.common.decorators import implements_protocol
from aiperf.common.exceptions import AIPerfMultiError, UnsupportedHookError
from aiperf.common.exceptions import AIPerfMultiError, HookError, UnsupportedHookError
from aiperf.common.hooks import Hook, HookAttrs, HookType
from aiperf.common.mixins.aiperf_logger_mixin import AIPerfLoggerMixin
from aiperf.common.protocols import HooksProtocol
Expand Down Expand Up @@ -184,15 +184,12 @@ async def run_hooks(
try:
await hook(**kwargs)
except Exception as e:
exceptions.append(e)
exceptions.append(HookError(self.__class__.__name__, hook.func_name, e))
self.exception(
f"Error running {hook!r} hook for {self.__class__.__name__}: {e}"
)
if exceptions:
raise AIPerfMultiError(
f"Errors running {hook_types} hooks for {self.__class__.__name__}",
exceptions,
)
raise AIPerfMultiError(None, exceptions)


# Add this file as one to be ignored when finding the caller of aiperf_logger.
Expand Down
2 changes: 2 additions & 0 deletions aiperf/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from aiperf.common.models.error_models import (
ErrorDetails,
ErrorDetailsCount,
ExitErrorInfo,
)
from aiperf.common.models.health_models import (
CPUTimes,
Expand Down Expand Up @@ -80,6 +81,7 @@
"EmbeddingResponseData",
"ErrorDetails",
"ErrorDetailsCount",
"ExitErrorInfo",
"FullPhaseProgress",
"IOCounters",
"Image",
Expand Down
25 changes: 25 additions & 0 deletions aiperf/common/models/error_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import Field

from aiperf.common.exceptions import LifecycleOperationError
from aiperf.common.models.base_models import AIPerfBaseModel


Expand Down Expand Up @@ -46,6 +47,30 @@ def from_exception(cls, e: BaseException) -> "ErrorDetails":
)


class ExitErrorInfo(AIPerfBaseModel):
    """Describes an error that should cause the process to exit."""

    # Details of the underlying error.
    error_details: ErrorDetails
    operation: str = Field(
        ...,
        description="The operation that caused the error.",
    )
    service_id: str | None = Field(
        default=None,
        description="The ID of the service that caused the error. If None, the error is not specific to a service.",
    )

    @classmethod
    def from_lifecycle_operation_error(
        cls, e: LifecycleOperationError
    ) -> "ExitErrorInfo":
        """Build an ``ExitErrorInfo`` from a ``LifecycleOperationError``.

        The error details come from the wrapped ``original_exception`` when it
        is set, otherwise from ``e`` itself.
        """
        root_cause = e.original_exception or e
        details = ErrorDetails.from_exception(root_cause)
        return cls(
            error_details=details,
            operation=e.operation,
            service_id=e.lifecycle_id,
        )


class ErrorDetailsCount(AIPerfBaseModel):
"""Count of error details."""

Expand Down
4 changes: 4 additions & 0 deletions aiperf/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from aiperf.controller.base_service_manager import (
BaseServiceManager,
)
from aiperf.controller.controller_utils import (
print_exit_errors,
)
from aiperf.controller.kubernetes_service_manager import (
KubernetesServiceManager,
ServiceKubernetesRunInfo,
Expand Down Expand Up @@ -32,4 +35,5 @@
"SignalHandlerMixin",
"SystemController",
"main",
"print_exit_errors",
]
96 changes: 96 additions & 0 deletions aiperf/controller/controller_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import textwrap
from collections import defaultdict

from rich.console import Console
from rich.panel import Panel
from rich.text import Text

from aiperf.common.models import ErrorDetails, ExitErrorInfo


def _group_errors_by_details(
exit_errors: list[ExitErrorInfo],
) -> dict[ErrorDetails, list[ExitErrorInfo]]:
"""Group exit errors by their error details to deduplicate similar errors."""
grouped_errors: dict[ErrorDetails, list[ExitErrorInfo]] = defaultdict(list)
for error in exit_errors:
grouped_errors[error.error_details].append(error)
return dict(grouped_errors)


def print_exit_errors(
    exit_errors: list[ExitErrorInfo] | None = None, console: Console | None = None
) -> None:
    """Display exit errors to the user, collapsing errors with identical details.

    Errors are grouped by their ``error_details``. Each group is rendered as a
    Service(s) / Operation / Error / Reason entry, and all entries are printed
    inside a single red panel.

    Args:
        exit_errors: Errors to display. Nothing is printed when this is
            ``None`` or empty.
        console: Rich console to print to. A new ``Console`` is created when
            omitted.
    """
    if not exit_errors:
        return
    console = console or Console()

    # Every field starts with a prefix of the same width so that the labels
    # line up; the bullet that opens an entry is padded to that width too.
    field_prefix = " " * 3
    bullet_prefix = "•".ljust(len(field_prefix))

    # Continuation lines of a wrapped reason are indented to the column where
    # the reason text starts. Deriving it from the prefix and label keeps the
    # two from drifting apart.
    reason_label = "Reason"
    reason_indent = " " * len(f"{field_prefix}{reason_label}: ")

    # Leave room for the panel frame (4 columns, assuming Panel's default
    # border and horizontal padding) plus the reason indent, and keep a
    # minimum width for very narrow consoles.
    wrap_width = max(console.size.width - 4 - len(reason_indent), 20)

    def _create_field(
        label: str, value: str, prefix: str = field_prefix, end: str = "\n"
    ) -> Text:
        """Helper to create a formatted "label: value" field for error display."""
        return Text.assemble(
            Text(f"{prefix}{label}: ", style="bold yellow"),
            Text(f"{value}{end}", style="bold"),
        )

    grouped_errors = _group_errors_by_details(exit_errors)
    last_index = len(grouped_errors) - 1

    summary: list[Text] = []
    for i, (error_details, error_list) in enumerate(grouped_errors.items()):
        operations = {error.operation for error in error_list}
        operation_display = (
            next(iter(operations)) if len(operations) == 1 else "Multiple Operations"
        )

        affected_services = sorted({error.service_id or "N/A" for error in error_list})
        service_count = len(affected_services)

        # Name up to three services in full; beyond that show two and "etc.".
        if service_count == 1:
            service_display = affected_services[0]
        elif service_count <= 3:
            service_display = (
                f"{service_count} services: {', '.join(affected_services)}"
            )
        else:
            shown_services = affected_services[:2]
            service_display = (
                f"{service_count} services: {', '.join(shown_services)}, etc."
            )

        summary.append(
            _create_field(
                "Services" if service_count > 1 else "Service",
                service_display,
                prefix=bullet_prefix,
            )
        )
        summary.append(_create_field("Operation", operation_display))
        summary.append(_create_field("Error", error_details.type or "Unknown"))

        # NOTE(review): the message is guarded like ``type`` above in case it
        # can be empty/None - confirm against the ErrorDetails model.
        wrapped_text = textwrap.fill(
            error_details.message or "",
            width=wrap_width,
            subsequent_indent=reason_indent,
        )

        # Separate entries with a blank line, but do not pad after the last one.
        end = "\n\n" if i < last_index else ""
        summary.append(_create_field(reason_label, wrapped_text, end=end))

    console.print()
    console.print(
        Panel(
            Text.assemble(*summary),
            border_style="bold red",
            title="AIPerf System Exit Errors",
            title_align="left",
        )
    )
    console.file.flush()
Loading