Skip to content

feat!: use firework.bootstrap #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ jobs:
strategy:
matrix:
py_version:
- '3.8'
- '3.9'
- '3.10'
- '3.11'
- '3.12'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
4 changes: 4 additions & 0 deletions _bootstrap/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .context import ServiceContext as ServiceContext
from .core import Bootstrap as Bootstrap
from .core import UnhandledExit as UnhandledExit
from .service import Service as Service
132 changes: 132 additions & 0 deletions _bootstrap/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
from .core import Bootstrap


class _State(Enum):
PREPARE_PRE = auto()
PREPARE_POST = auto()
CLEANUP_PRE = auto()
CLEANUP_POST = auto()
READY = auto()


@dataclass
class ServiceContext:
State: ClassVar = _State

bootstrap: Bootstrap

def __post_init__(self):
self._state: _State | None = None
self._ready: bool | None = None
self._notify = asyncio.Event()
self._switch = asyncio.Event()
self._sigexit = asyncio.Event()

def _update(self):
self._notify.set()

def switch(self):
self._switch.set()

def enter(self):
if self._state is not _State.PREPARE_POST:
return

self._ready = True
self._update()

def skip(self):
if self._state is not _State.PREPARE_POST:
return

self._ready = False
self._update()

def exit(self):
"Call by the manager"
self._sigexit.set()

@property
def state(self):
if self._ready:
return self.State.READY
return self._state

async def wait_prepare_pre(self):
await self._switch.wait()
if self._state is not _State.PREPARE_PRE:
raise RuntimeError(f"expected {self.State.PREPARE_PRE}, got {self._state}")

self._switch.clear()
self._update()

async def wait_cleanup_pre(self):
await self._switch.wait()
if self._state is not _State.CLEANUP_PRE:
raise RuntimeError(f"expected {self.State.CLEANUP_PRE}, got {self._state}")

self._switch.clear()
self._update()

async def wait_prepare_post(self):
await self._switch.wait()
if self._state is not _State.PREPARE_POST:
raise RuntimeError(f"expected {self.State.PREPARE_POST}, got {self._state}")

self._switch.clear()

async def wait_cleanup_post(self):
await self._switch.wait()
if self._state is not _State.CLEANUP_POST:
raise RuntimeError(f"expected {self.State.CLEANUP_POST}, got {self._state}")

self._switch.clear()

async def wait_for_sigexit(self):
await self._sigexit.wait()

@property
def ready(self):
if self._ready is None:
raise RuntimeError("ServiceContext.ready is not available outside of prepare context")

return self._ready

@property
def should_exit(self):
return self._sigexit.is_set()

@asynccontextmanager
async def prepare(self):
self._state = _State.PREPARE_PRE
self.switch()
await self._notify.wait()
self._notify.clear()
yield
self._state = _State.PREPARE_POST
self.switch()
await self._notify.wait()
self._notify.clear()
self._state = None

@asynccontextmanager
async def cleanup(self):
self._state = _State.CLEANUP_PRE
self.switch()
await self._notify.wait()
self._notify.clear()
yield
self._state = _State.CLEANUP_POST
self.switch()
await self._notify.wait()
self._notify.clear()
self._state = None
218 changes: 218 additions & 0 deletions _bootstrap/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from __future__ import annotations

import asyncio
import signal
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, Iterable

from exceptiongroup import ExceptionGroup # noqa: A004
from loguru import logger

from .context import ServiceContext
from .graph import ServiceGraph
from .utiles import TaskGroup, cvar, oneof, cancel_alive_tasks


if TYPE_CHECKING:
from .service import Service


class UnhandledExit(Exception):
pass


BOOTSTRAP_CONTEXT: ContextVar[Bootstrap] = ContextVar("BOOTSTRAP_CONTEXT")


class Bootstrap:
graph: ServiceGraph

def __init__(self):
self.graph = ServiceGraph()

async def spawn(self, *services: Service):
service_bind, previous, nexts = self.graph.subgraph(*services)
tasks: dict[str, asyncio.Task] = self.graph.tasks

prepare_errors: list[Exception] = []
cleanup_errors: list[Exception] = []

done_prepare: dict[str, None] = {}
pending_prepare = TaskGroup()
pending_cleanup = TaskGroup()
queued_prepare = {k: v.copy() for k, v in previous.maps[0].items()}
queued_cleanup = {k: v.copy() for k, v in nexts.maps[0].items()}

spawn_forward_prepare: bool = True

def spawn_prepare(service: Service):
async def prepare_guard():
nonlocal spawn_forward_prepare

context = ServiceContext(self)
self.graph.contexts[service.id] = context
task = tasks[service.id] = asyncio.create_task(service.launch(context))

await oneof(context.wait_prepare_pre(), task)
await oneof(context.wait_prepare_post(), task)

if task.done():
spawn_forward_prepare = False

prepare_errors.append(task.exception() or UnhandledExit()) # type: ignore
self.graph.drop(service)
return

done_prepare[service.id] = None

if not spawn_forward_prepare:
return

for next_service, barriers in list(queued_prepare.items()):
if service.id in barriers:
barriers.pop(service.id)

if not barriers:
spawn_prepare(service_bind[next_service])
queued_prepare.pop(next_service)

pending_prepare.spawn(prepare_guard())

def spawn_cleanup(service: Service):
async def cleanup_guard():
context = self.graph.contexts[service.id]
task = tasks[service.id]

context.exit()
await oneof(context.wait_cleanup_pre(), task)
await oneof(context.wait_cleanup_post(), task)

self.graph.drop(service)

if task.done():
cleanup_errors.append(task.exception() or UnhandledExit()) # type: ignore
return

for previous_service, barriers in list(queued_cleanup.items()):
if service.id in barriers:
barriers.pop(service.id)

if not barriers:
spawn_cleanup(service_bind[previous_service])
queued_cleanup.pop(previous_service)

pending_cleanup.spawn(cleanup_guard())

def toggle_enter():
for i in done_prepare:
self.graph.contexts[i].enter()

def toggle_skip():
for i in done_prepare:
self.graph.contexts[i].skip()

def rollback():
spawned = False

for i in done_prepare:
if not (nexts[i] & done_prepare.keys()):
spawned = True
spawn_cleanup(service_bind[i])

if not spawned:
raise RuntimeError("Unsatisfied dependencies, rollback failed")

return pending_cleanup.wait()

for i, v in previous.maps[0].items():
if not v:
spawn_prepare(service_bind[i])
queued_prepare.pop(i)

await pending_prepare

if queued_prepare:
toggle_skip()
await rollback()

if cleanup_errors:
raise RuntimeError("Unsatisfied dependencies") from ExceptionGroup("", cleanup_errors)

raise RuntimeError("Unsatisfied dependencies")

if prepare_errors:
toggle_skip()
await rollback()

if cleanup_errors:
raise ExceptionGroup("", cleanup_errors) from ExceptionGroup("", prepare_errors)

raise ExceptionGroup("", prepare_errors)

self.graph.apply(dict(service_bind), previous, nexts)
toggle_enter()

return rollback

async def launch(self, *services: Service):
rollback = await self.spawn(*services)
try:
await asyncio.gather(*[self.graph.contexts[i.id]._switch.wait() for i in services])
except asyncio.CancelledError:
pass
finally:
await rollback()

def launch_blocking(
self,
*services: Service,
loop: asyncio.AbstractEventLoop | None = None,
stop_signal: Iterable[signal.Signals] = (signal.SIGINT,),
):
import contextlib
import threading

loop = asyncio.new_event_loop()

logger.info("Starting launart main task...", style="green bold")

with cvar(BOOTSTRAP_CONTEXT, self):
launch_task = loop.create_task(self.launch(*services), name="amnesia-launch")

handled_signals: dict[signal.Signals, Any] = {}

def signal_handler(*args, **kwargs): # noqa: ARG001
for service in self.graph.services:
self.graph.contexts[service].exit()

if not launch_task.done():
launch_task.cancel()
# wakeup loop if it is blocked by select() with long timeout
launch_task.get_loop().call_soon_threadsafe(lambda: None)
logger.warning("Ctrl-C triggered by user.", style="dark_orange bold")

if threading.current_thread() is threading.main_thread(): # pragma: worst case
try:
for sig in stop_signal:
handled_signals[sig] = signal.getsignal(sig)
signal.signal(sig, signal_handler)
except ValueError: # pragma: no cover
# `signal.signal` may throw if `threading.main_thread` does
# not support signals
handled_signals.clear()

loop.run_until_complete(launch_task)

for sig, handler in handled_signals.items():
if signal.getsignal(sig) is signal_handler:
signal.signal(sig, handler)

try:
cancel_alive_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
with contextlib.suppress(RuntimeError, AttributeError):
# LINK: https://docs.python.org/3.10/library/asyncio-eventloop.html#asyncio.loop.shutdown_default_executor
loop.run_until_complete(loop.shutdown_default_executor()) # type: ignore
finally:
asyncio.set_event_loop(None)
logger.success("asyncio shutdown complete.", style="green bold")
Loading
Loading