Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6958.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor Agent's kernel registry recovery by separating loader and writer
7 changes: 0 additions & 7 deletions src/ai/backend/agent/kernel_registry/kernel_registry.py

This file was deleted.

Empty file.
23 changes: 5 additions & 18 deletions src/ai/backend/agent/kernel_registry/loader/abc.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import logging
from abc import ABC, abstractmethod

from ai.backend.common.types import KernelId
from ai.backend.logging import BraceStyleAdapter

from ..kernel_registry import KernelRegistry
from .types import KernelRegistrySaveMetadata
from ....agent.kernel import AbstractKernel

# Module-level logger, named after this module's import spec.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class AbstractKernelRegistryRecovery(ABC):
class AbstractKernelRegistryLoader(ABC):
"""
Recovery interface for loading and saving KernelRegistry
Loader interface for loading KernelRegistry
"""

@abstractmethod
async def load_kernel_registry(self) -> KernelRegistry:
async def load_kernel_registry(self) -> dict[KernelId, AbstractKernel]:
"""
Load the KernelRegistry from persistent storage.
Raises:
Expand All @@ -24,16 +24,3 @@ async def load_kernel_registry(self) -> KernelRegistry:
Returns: The loaded KernelRegistry.
"""
pass
Comment thread
fregataa marked this conversation as resolved.

@abstractmethod
async def save_kernel_registry(
self, registry: KernelRegistry, metadata: KernelRegistrySaveMetadata
) -> None:
"""
Save the KernelRegistry to persistent storage.
args:
registry: The KernelRegistry to save.
metadata: Additional metadata for saving.
Returns: None
"""
pass
59 changes: 59 additions & 0 deletions src/ai/backend/agent/kernel_registry/loader/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import logging
import os
import pickle
import shutil
from pathlib import Path
from typing import override

from ai.backend.common.types import KernelId
from ai.backend.logging import BraceStyleAdapter

from ....agent.kernel import AbstractKernel
from ...exception import KernelRegistryLoadError, KernelRegistryNotFound
from .abc import AbstractKernelRegistryLoader

# Module-level logger, named after this module's import spec; messages in this
# module use "{}" placeholders.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class PickleBasedKernelRegistryLoader(AbstractKernelRegistryLoader):
    """
    Loader that restores the kernel registry from a pickle file on disk.

    Three paths are involved:
    - the "last" registry file, which is preferred when it exists;
    - the fallback registry file, used when the "last" file is missing;
    - the legacy registry file, which is moved onto the selected path
      before loading whenever it is present.
    """

    def __init__(
        self,
        last_registry_file_path: Path,
        fallback_registry_file_path: Path,
        legacy_registry_file_path: Path,
    ) -> None:
        self._last_registry_file_path = last_registry_file_path
        self._fallback_registry_file_path = fallback_registry_file_path
        self._legacy_registry_file_path = legacy_registry_file_path

    def _move_legacy_registry(self, destination: Path) -> None:
        """
        Move the legacy registry file to `destination` if the legacy file exists.

        This is best-effort: any failure is logged as a warning and swallowed
        so that loading can still be attempted afterwards.
        """
        legacy_path = self._legacy_registry_file_path
        try:
            if os.path.isfile(legacy_path):
                shutil.move(legacy_path, destination)
        except Exception as e:
            log.warning(
                "Failed to move legacy kernel registry file {} to {} (err: {})",
                str(legacy_path),
                str(destination),
                str(e),
            )

    @override
    async def load_kernel_registry(self) -> dict[KernelId, AbstractKernel]:
        """
        Load the pickled kernel registry.

        Raises:
            KernelRegistryLoadError: unpickling hit an unexpected end of file.
            KernelRegistryNotFound: the selected registry file does not exist.
        Returns: The unpickled object read from the selected registry file.
        """
        source_path = self._last_registry_file_path
        if not source_path.is_file():
            fallback_path = self._fallback_registry_file_path
            log.warning(
                "Registry file with name {} not found. "
                "Falling back to path with local instance id: {}",
                source_path,
                fallback_path,
            )
            source_path = fallback_path

        # The legacy file, when present, is moved onto the path chosen above.
        self._move_legacy_registry(source_path)

        try:
            with open(source_path, "rb") as f:
                # NOTE: unpickling can execute arbitrary code, so the registry
                # file must never come from an untrusted source.
                return pickle.load(f)
        except EOFError as e:
            log.warning("Failed to load the last kernel registry: {}", str(source_path))
            raise KernelRegistryLoadError from e
        except FileNotFoundError as e:
            raise KernelRegistryNotFound from e
89 changes: 0 additions & 89 deletions src/ai/backend/agent/kernel_registry/loader/pickle_based.py

This file was deleted.

Empty file.
75 changes: 75 additions & 0 deletions src/ai/backend/agent/kernel_registry/recovery/recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Self

from ai.backend.common.types import AgentId, KernelId
from ai.backend.logging import BraceStyleAdapter

from ....agent.kernel import AbstractKernel
from ..loader.abc import AbstractKernelRegistryLoader
from ..loader.pickle import PickleBasedKernelRegistryLoader
from ..writer.abc import AbstractKernelRegistryWriter
from ..writer.pickle import PickleBasedKernelRegistryWriter
from ..writer.types import KernelRegistrySaveMetadata

# Module-level logger, named after this module's import spec; messages in this
# module use "{}" placeholders.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


@dataclass
class KernelRegistryRecoveryArgs:
    """
    Inputs consumed by KernelRegistryRecovery.create() to derive registry file paths.
    """

    # Base directory of the legacy registry file.
    ipc_base_path: Path
    # Base directory of the last and fallback registry files.
    var_base_path: Path
    # Embedded in the file name of the last and legacy registry files.
    agent_id: AgentId
    # Embedded in the file name of the fallback registry file.
    local_instance_id: str


class KernelRegistryRecovery:
    """
    Combines a list of kernel registry loaders with a single writer.

    Saving is delegated to the writer; loading merges the results of all
    loaders in order.
    """

    def __init__(
        self,
        loaders: list[AbstractKernelRegistryLoader],
        writer: AbstractKernelRegistryWriter,
    ) -> None:
        self._loaders = loaders
        self._writer = writer

    @classmethod
    def create(cls, args: KernelRegistryRecoveryArgs) -> Self:
        """
        Build an instance backed by one pickle-based loader and a pickle-based writer.

        The writer targets the same "last" registry path that the loader
        tries first.
        """
        agent_file_name = f"kernel_registry.{args.agent_id}.dat"
        instance_file_name = f"kernel_registry.{args.local_instance_id}.dat"

        var_dir = args.var_base_path
        last_path = var_dir / agent_file_name
        fallback_path = var_dir / instance_file_name
        legacy_path = args.ipc_base_path / agent_file_name

        pickle_loader = PickleBasedKernelRegistryLoader(
            last_path,
            fallback_path,
            legacy_path,
        )
        pickle_writer = PickleBasedKernelRegistryWriter(last_path)
        return cls(loaders=[pickle_loader], writer=pickle_writer)

    async def save_kernel_registry(
        self, registry: Mapping[KernelId, AbstractKernel], metadata: KernelRegistrySaveMetadata
    ) -> None:
        """
        Forward the registry and save metadata to the configured writer.
        """
        await self._writer.save_kernel_registry(registry, metadata)

    async def load_kernel_registry(self) -> dict[KernelId, AbstractKernel]:
        """
        Merge the registries produced by every loader, in order.

        A loader that raises is logged and skipped. When several loaders
        return the same kernel id, the entry from the later loader wins.
        Returns: The merged mapping (empty if every loader failed).
        """
        merged: dict[KernelId, AbstractKernel] = {}
        for loader in self._loaders:
            try:
                loaded = await loader.load_kernel_registry()
                merged.update(loaded.items())
            except Exception as e:
                log.warning(
                    "Failed to load kernel registry using loader {}, skip (err: {})",
                    loader.__class__.__name__,
                    str(e),
                )
        return merged
Empty file.
30 changes: 30 additions & 0 deletions src/ai/backend/agent/kernel_registry/writer/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping

from ai.backend.common.types import KernelId
from ai.backend.logging import BraceStyleAdapter

from ....agent.kernel import AbstractKernel
from .types import KernelRegistrySaveMetadata

# Module-level logger, named after this module's import spec.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class AbstractKernelRegistryWriter(ABC):
    """
    Interface for components that persist the kernel registry.
    """

    @abstractmethod
    async def save_kernel_registry(
        self, registry: Mapping[KernelId, AbstractKernel], metadata: KernelRegistrySaveMetadata
    ) -> None:
        """
        Persist the given kernel registry.

        Args:
            registry: Mapping of kernel id to kernel object to be saved.
            metadata: Additional metadata for saving.
        Returns: None
        """
Comment thread
fregataa marked this conversation as resolved.
24 changes: 24 additions & 0 deletions src/ai/backend/agent/kernel_registry/writer/noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import logging
from collections.abc import Mapping
from typing import override

from ai.backend.common.types import KernelId
from ai.backend.logging import BraceStyleAdapter

from ....agent.kernel import AbstractKernel
from .abc import AbstractKernelRegistryWriter
from .types import KernelRegistrySaveMetadata

# Module-level logger, named after this module's import spec.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class NoopKernelRegistryWriter(AbstractKernelRegistryWriter):
    """
    Writer that persists nothing; every save request is only logged at debug level.
    """

    def __init__(self) -> None:
        """Create the writer; there is no state to set up."""

    @override
    async def save_kernel_registry(
        self, registry: Mapping[KernelId, AbstractKernel], metadata: KernelRegistrySaveMetadata
    ) -> None:
        """
        Ignore `registry` and `metadata`, emitting a debug log entry only.
        """
        log.debug("NoopKernelRegistryWriter: skipping save_kernel_registry")
Comment thread
fregataa marked this conversation as resolved.
50 changes: 50 additions & 0 deletions src/ai/backend/agent/kernel_registry/writer/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
import os
import pickle
import time
from collections.abc import Mapping
from pathlib import Path
from typing import override

from ai.backend.common.types import KernelId
from ai.backend.logging import BraceStyleAdapter

from ....agent.kernel import AbstractKernel
from .abc import AbstractKernelRegistryWriter
from .types import KernelRegistrySaveMetadata

# Module-level logger, named after this module's import spec; messages in this
# module use "{}" placeholders.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


# Minimum interval, in seconds of monotonic time, between two non-forced saves.
SAVE_COOL_DOWN_SECONDS = 60


class PickleBasedKernelRegistryWriter(AbstractKernelRegistryWriter):
    """
    Writer that pickles the kernel registry into a single file.

    Non-forced saves are throttled: a save is skipped unless more than
    SAVE_COOL_DOWN_SECONDS have elapsed since the last successful save
    (or since construction, if nothing has been saved yet).
    """

    def __init__(self, last_registry_file_path: Path) -> None:
        self._last_registry_file_path = last_registry_file_path
        # The cool-down window starts counting from construction time.
        self._last_saved_time = time.monotonic()

    def _remove_registry_file(self) -> None:
        """
        Delete the registry file, tolerating the case where it does not exist.
        """
        try:
            os.remove(self._last_registry_file_path)
        except FileNotFoundError:
            pass

    @override
    async def save_kernel_registry(
        self, registry: Mapping[KernelId, AbstractKernel], metadata: KernelRegistrySaveMetadata
    ) -> None:
        """
        Pickle a dict copy of `registry` into the registry file.

        Args:
            registry: Mapping of kernel id to kernel object to be saved.
            metadata: Save options; `metadata.force` bypasses the cool-down.
        Returns: None. Failures are logged and not re-raised; the registry
            file is removed after a failed save.
        """
        now = time.monotonic()
        throttled = not metadata.force and now <= self._last_saved_time + SAVE_COOL_DOWN_SECONDS
        if throttled:
            return  # don't save too frequently

        target_path = self._last_registry_file_path
        try:
            with open(target_path, "wb") as f:
                pickle.dump(dict(registry), f)
            self._last_saved_time = now
            log.debug("Saved kernel registry to {}", str(target_path))
        except Exception as e:
            log.exception(
                "Failed to save kernel registry to {} (registry: {})",
                str(target_path),
                str(registry),
                exc_info=e,
            )
            # The file was opened for writing before the failure, so drop it.
            self._remove_registry_file()
Loading