Skip to content

Commit e078851

Browse files
authored
Merge pull request #10 from SocAIty/dev
Dev
2 parents 842cb01 + df5455c commit e078851

1 file changed

Lines changed: 84 additions & 50 deletions

File tree

apipod/engine/backend/fastapi/router.py

Lines changed: 84 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import threading
44
import logging
5+
from contextlib import asynccontextmanager
56
from typing import Union, Callable, get_type_hints, Generator, AsyncGenerator, Iterator, AsyncIterator
67
from fastapi import APIRouter, FastAPI, Response
78
from fastapi.responses import JSONResponse
@@ -34,6 +35,7 @@ def __init__(
3435
prefix: str = "", # "/api",
3536
max_upload_file_size_mb: float = None,
3637
job_queue=None,
38+
lifespan=None,
3739
*args,
3840
**kwargs):
3941
"""
@@ -46,12 +48,17 @@ def __init__(
4648
prefix: The API route prefix
4749
max_upload_file_size_mb: Maximum file size in MB for uploads
4850
job_queue: Optional custom JobQueue implementation
51+
lifespan: Optional async context manager for custom startup/shutdown logic
4952
args: Additional arguments
5053
kwargs: Additional keyword arguments
5154
"""
55+
# Extract user-provided lifespan (explicit param or kwarg) before parent init
56+
user_lifespan = lifespan or kwargs.pop('lifespan', None)
57+
5258
# Initialize parent classes
5359
api_router_params = inspect.signature(APIRouter.__init__).parameters
5460
api_router_kwargs = {k: kwargs.get(k) for k in api_router_params if k in kwargs}
61+
api_router_kwargs.pop('lifespan', None) # handled via composed lifespan below
5562

5663
APIRouter.__init__(self, **api_router_kwargs)
5764
_BaseBackend.__init__(self, title=title, summary=summary, *args, **kwargs)
@@ -61,27 +68,35 @@ def __init__(
6168

6269
self.status = SERVER_HEALTH.INITIALIZING
6370

71+
# Registry for functions that workers can execute. Keys are function names.
72+
self._job_func_registry: dict = {}
73+
# Stop event and thread handle for in-process worker (dev mode)
74+
self._worker_stop_event = threading.Event()
75+
self._worker_thread: threading.Thread | None = None
76+
self._logger = logging.getLogger(__name__)
77+
78+
# Build a composed lifespan that merges internal worker hooks with the user-provided lifespan
79+
combined_lifespan = self._build_lifespan(user_lifespan)
80+
6481
# Create or use provided FastAPI app
6582
if app is None:
6683
app = FastAPI(
6784
title=self.title,
6885
summary=self.summary,
69-
contact={"name": "SocAIty", "url": "https://www.socaity.ai"}
86+
contact={"name": "SocAIty", "url": "https://www.socaity.ai"},
87+
lifespan=combined_lifespan,
7088
)
89+
else:
90+
# Existing app: replace its lifespan with our composed version
91+
app.router.lifespan_context = combined_lifespan
7192

7293
self.app: FastAPI = app
7394
self.prefix = prefix
7495
self.add_standard_routes()
7596

76-
# Registry for functions that workers can execute. Keys are function names.
77-
self._job_func_registry: dict = {}
78-
# Stop event and thread handle for in-process worker (dev mode)
79-
self._worker_stop_event = threading.Event()
80-
self._worker_thread: threading.Thread | None = None
81-
self._logger = logging.getLogger(__name__)
8297
self._endpoint_configurator = FastApiEndpointConfigurator(self)
8398

84-
# excpetion handling
99+
# Exception handling
85100
_FastAPIExceptionHandler.__init__(self)
86101
if not getattr(self.app.state, "_socaity_exception_handler_added", False):
87102
self.app.add_exception_handler(Exception, self.global_exception_handler)
@@ -91,50 +106,69 @@ def __init__(
91106
self._orig_openapi_func = self.app.openapi
92107
self.app.openapi = self.custom_openapi
93108

94-
# Start in-process worker on FastAPI startup (dev convenience).
95-
# Only start if a job_queue with `start_worker` exists.
96-
if not getattr(self.app.state, "_socaity_worker_hooks_added", False):
97-
def _startup():
98-
try:
99-
if self.job_queue and hasattr(self.job_queue, "start_worker"):
100-
# Start worker in a daemon thread so it doesn't block uvicorn
101-
def _run():
102-
try:
103-
self.job_queue.start_worker(
104-
func_registry=self._job_func_registry,
105-
worker_name="api-worker",
106-
stop_event=self._worker_stop_event,
107-
)
108-
except Exception:
109-
self._logger.exception("Worker thread exited with exception")
110-
111-
t = threading.Thread(target=_run, daemon=True)
112-
t.start()
113-
self._worker_thread = t
114-
except Exception:
115-
self._logger.exception("Failed to start in-process worker on startup")
116-
117-
def _shutdown():
118-
try:
119-
# Signal local worker to stop
109+
# ------------------------------------------------------------------
110+
# Lifespan & worker lifecycle
111+
# ------------------------------------------------------------------
112+
113+
def _build_lifespan(self, user_lifespan=None):
114+
"""
115+
Build a composed lifespan context manager that runs:
116+
1. Internal worker startup
117+
2. User-provided lifespan (if any)
118+
3. Internal worker shutdown on exit
119+
"""
120+
router_self = self # capture for closure
121+
122+
@asynccontextmanager
123+
async def _combined_lifespan(app):
124+
router_self._start_background_worker()
125+
try:
126+
if user_lifespan:
127+
async with user_lifespan(app):
128+
yield
129+
else:
130+
yield
131+
finally:
132+
router_self._stop_background_worker()
133+
134+
return _combined_lifespan
135+
136+
def _start_background_worker(self):
137+
"""Start the in-process job queue worker in a daemon thread (dev convenience)."""
138+
try:
139+
if self.job_queue and hasattr(self.job_queue, "start_worker"):
140+
def _run():
120141
try:
121-
self._worker_stop_event.set()
142+
self.job_queue.start_worker(
143+
func_registry=self._job_func_registry,
144+
worker_name="api-worker",
145+
stop_event=self._worker_stop_event,
146+
)
122147
except Exception:
123-
pass
124-
125-
# Call job_queue.shutdown if available
126-
if self.job_queue and hasattr(self.job_queue, "shutdown"):
127-
try:
128-
self.job_queue.shutdown()
129-
except Exception:
130-
self._logger.exception("Error shutting down job queue")
131-
except Exception:
132-
self._logger.exception("Error during worker shutdown handler")
133-
134-
# Register handlers
135-
self.app.add_event_handler("startup", _startup)
136-
self.app.add_event_handler("shutdown", _shutdown)
137-
self.app.state._socaity_worker_hooks_added = True
148+
self._logger.exception("Worker thread exited with exception")
149+
150+
thread = threading.Thread(target=_run, daemon=True)
151+
thread.start()
152+
self._worker_thread = thread
153+
except Exception:
154+
self._logger.exception("Failed to start in-process worker on startup")
155+
156+
def _stop_background_worker(self):
157+
"""Signal the background worker to stop and shut down the job queue."""
158+
try:
159+
self._worker_stop_event.set()
160+
except Exception:
161+
pass
162+
163+
if self.job_queue and hasattr(self.job_queue, "shutdown"):
164+
try:
165+
self.job_queue.shutdown()
166+
except Exception:
167+
self._logger.exception("Error shutting down job queue")
168+
169+
# ------------------------------------------------------------------
170+
# Standard routes
171+
# ------------------------------------------------------------------
138172

139173
def add_standard_routes(self):
140174
"""Add standard API routes for status and health checks."""

0 commit comments

Comments
 (0)