Skip to content

Commit 1d44e80

Browse files
amito and anfredette
authored and committed
Merge pull request llm-d-incubation#115 from amito/refactor/fastapi-state
refactor: use FastAPI app.state for dependency injection
2 parents 37cfd27 + 35474ce commit 1d44e80

File tree

6 files changed

+179
-145
lines changed

6 files changed

+179
-145
lines changed

src/neuralnav/api/app.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""FastAPI application factory for NeuralNav API."""
22

3+
import asyncio
34
import logging
45
import os
6+
from contextlib import asynccontextmanager
57

68
from fastapi import FastAPI
79
from fastapi.middleware.cors import CORSMiddleware
810

9-
from .routes import (
11+
from neuralnav.api.routes import (
1012
configuration_router,
1113
database_router,
1214
health_router,
@@ -27,12 +29,30 @@
2729
logger = logging.getLogger(__name__)
2830

2931

32+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``app.state`` before the application starts serving.

    ``init_app_state`` is executed in a worker thread via
    ``asyncio.to_thread``. If it raises, the exception is logged with its
    traceback and re-raised. On success an ``asyncio.Lock`` is stored as
    ``app.state.cluster_manager_lock``. Nothing runs after the ``yield``,
    so there is no shutdown work.

    Args:
        app: The FastAPI application whose ``state`` is initialized.
    """
    # Function-scope import.
    # NOTE(review): presumably deferred to avoid an import cycle or
    # import-time side effects — confirm.
    from neuralnav.api.dependencies import init_app_state

    logger.info("Initializing app state...")
    try:
        await asyncio.to_thread(init_app_state, app)
    except Exception:
        logger.exception("App state initialization failed during startup")
        raise
    else:
        # Create the asyncio.Lock here, in the event loop thread, rather
        # than in the worker thread where init_app_state runs, to avoid
        # cross-loop binding issues.
        app.state.cluster_manager_lock = asyncio.Lock()

    yield
47+
48+
3049
def create_app() -> FastAPI:
3150
"""Create and configure the FastAPI application."""
3251
app = FastAPI(
3352
title="NeuralNav API",
3453
description="API for LLM deployment recommendations",
3554
version="0.1.0",
55+
lifespan=lifespan,
3656
)
3757

3858
# Add CORS middleware

src/neuralnav/api/dependencies.py

Lines changed: 65 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
"""Shared dependencies for API routes.
22
3-
This module provides singleton instances and dependency injection
4-
for the API routes. All shared state is initialized here.
3+
This module provides singleton instances via FastAPI's app.state and
4+
dependency injection via Depends(). All shared state is initialized
5+
during the application lifespan in init_app_state().
56
"""
67

8+
import asyncio
79
import logging
810
import os
11+
from typing import cast
12+
13+
from fastapi import FastAPI, HTTPException, Request, status
14+
from starlette.concurrency import run_in_threadpool
915

1016
from neuralnav.cluster import KubernetesClusterManager, KubernetesDeploymentError
1117
from neuralnav.configuration import DeploymentGenerator, YAMLValidator
1218
from neuralnav.knowledge_base.model_catalog import ModelCatalog
1319
from neuralnav.knowledge_base.slo_templates import SLOTemplateRepository
1420
from neuralnav.orchestration.workflow import RecommendationWorkflow
15-
from neuralnav.shared.schemas import DeploymentMode
1621

1722
# Configure logging
1823
debug_mode = os.getenv("NEURALNAV_DEBUG", "false").lower() == "true"
@@ -24,97 +29,80 @@
2429
)
2530
logger = logging.getLogger(__name__)
2631

27-
# Singleton instances
28-
_workflow: RecommendationWorkflow | None = None
29-
_model_catalog: ModelCatalog | None = None
30-
_slo_repo: SLOTemplateRepository | None = None
31-
_deployment_generator: DeploymentGenerator | None = None
32-
_yaml_validator: YAMLValidator | None = None
33-
_cluster_manager: KubernetesClusterManager | None = None
3432

33+
# ---------------------------------------------------------------------------
34+
# Lifespan: initialize all singletons on app.state
35+
# ---------------------------------------------------------------------------
3536

36-
def get_workflow() -> RecommendationWorkflow:
37-
"""Get the recommendation workflow singleton."""
38-
global _workflow
39-
if _workflow is None:
40-
_workflow = RecommendationWorkflow()
41-
return _workflow
4237

38+
def init_app_state(app: FastAPI) -> None:
    """Create the shared service objects and attach them to ``app.state``.

    Sets the attributes ``model_catalog``, ``slo_repo``,
    ``deployment_generator`` (built with ``simulator_mode=False``),
    ``yaml_validator``, ``cluster_managers`` and ``workflow``.

    Args:
        app: The FastAPI application whose ``state`` receives the objects.
    """
    state = app.state

    state.model_catalog = ModelCatalog()
    state.slo_repo = SLOTemplateRepository()
    state.deployment_generator = DeploymentGenerator(simulator_mode=False)
    state.yaml_validator = YAMLValidator()

    # Starts empty; get_cluster_manager_or_raise adds one entry per namespace.
    namespace_managers: dict[str, KubernetesClusterManager] = {}
    state.cluster_managers = namespace_managers

    state.workflow = RecommendationWorkflow()
4346

44-
def get_model_catalog() -> ModelCatalog:
45-
"""Get the model catalog singleton."""
46-
global _model_catalog
47-
if _model_catalog is None:
48-
_model_catalog = ModelCatalog()
49-
return _model_catalog
5047

48+
# ---------------------------------------------------------------------------
49+
# Depends() providers — read from request.app.state
50+
# ---------------------------------------------------------------------------
5151

52-
def get_slo_repo() -> SLOTemplateRepository:
53-
"""Get the SLO template repository singleton."""
54-
global _slo_repo
55-
if _slo_repo is None:
56-
_slo_repo = SLOTemplateRepository()
57-
return _slo_repo
5852

53+
def get_workflow(request: Request) -> RecommendationWorkflow:
    """Return the shared ``RecommendationWorkflow`` held on ``app.state``.

    Reads ``request.app.state.workflow``; the ``cast`` only informs the type
    checker and performs no runtime validation.
    """
    workflow = request.app.state.workflow
    return cast(RecommendationWorkflow, workflow)
5956

60-
def get_deployment_generator() -> DeploymentGenerator:
61-
"""Get the deployment generator singleton."""
62-
global _deployment_generator
63-
if _deployment_generator is None:
64-
_deployment_generator = DeploymentGenerator(simulator_mode=False)
65-
logger.info("Deployment generator initialized (simulator_mode=False)")
66-
return _deployment_generator
6757

58+
def get_model_catalog(request: Request) -> ModelCatalog:
    """Return the shared ``ModelCatalog`` held on ``app.state``.

    Reads ``request.app.state.model_catalog``; the ``cast`` only informs the
    type checker and performs no runtime validation.
    """
    catalog = request.app.state.model_catalog
    return cast(ModelCatalog, catalog)
6861

69-
def get_deployment_mode() -> DeploymentMode:
70-
"""Return the current deployment mode."""
71-
gen = get_deployment_generator()
72-
return DeploymentMode.SIMULATOR if gen.simulator_mode else DeploymentMode.PRODUCTION
7362

63+
def get_slo_repo(request: Request) -> SLOTemplateRepository:
    """Return the shared ``SLOTemplateRepository`` held on ``app.state``.

    Reads ``request.app.state.slo_repo``; the ``cast`` only informs the type
    checker and performs no runtime validation.
    """
    repository = request.app.state.slo_repo
    return cast(SLOTemplateRepository, repository)
7466

75-
def set_deployment_mode(mode: DeploymentMode) -> DeploymentMode:
76-
"""Set the deployment mode and return the new mode."""
77-
gen = get_deployment_generator()
78-
gen.simulator_mode = mode == DeploymentMode.SIMULATOR
79-
logger.info(f"Deployment mode changed to: {mode.value}")
80-
return mode
8167

68+
def get_deployment_generator(request: Request) -> DeploymentGenerator:
    """Return the shared ``DeploymentGenerator`` held on ``app.state``.

    Reads ``request.app.state.deployment_generator``; the ``cast`` only
    informs the type checker and performs no runtime validation.
    """
    generator = request.app.state.deployment_generator
    return cast(DeploymentGenerator, generator)
8271

83-
def get_yaml_validator() -> YAMLValidator:
84-
"""Get the YAML validator singleton."""
85-
global _yaml_validator
86-
if _yaml_validator is None:
87-
_yaml_validator = YAMLValidator()
88-
return _yaml_validator
8972

73+
def get_yaml_validator(request: Request) -> YAMLValidator:
    """Return the shared ``YAMLValidator`` held on ``app.state``.

    Reads ``request.app.state.yaml_validator``; the ``cast`` only informs the
    type checker and performs no runtime validation.
    """
    validator = request.app.state.yaml_validator
    return cast(YAMLValidator, validator)
9076

91-
def get_cluster_manager(namespace: str = "default") -> KubernetesClusterManager | None:
92-
"""Get or create a cluster manager.
9377

94-
Returns None if cluster is not accessible.
95-
"""
96-
global _cluster_manager
97-
if _cluster_manager is None:
98-
try:
99-
_cluster_manager = KubernetesClusterManager(namespace=namespace)
100-
logger.info("Kubernetes cluster manager initialized successfully")
101-
except KubernetesDeploymentError as e:
102-
logger.info(f"Kubernetes cluster not accessible: {e}")
103-
return None
104-
return _cluster_manager
78+
# Maximum number of namespaces that may have a cached cluster manager in
# app.state.cluster_managers. Once this many are cached,
# get_cluster_manager_or_raise rejects any further new namespace with HTTP 400.
_MAX_CACHED_NAMESPACES = 32
10579

10680

107-
def get_cluster_manager_or_raise(namespace: str = "default") -> KubernetesClusterManager:
81+
async def get_cluster_manager_or_raise(
    request: Request, namespace: str = "default"
) -> KubernetesClusterManager:
    """Return the cluster manager for *namespace*, creating it on first use.

    Managers are cached per namespace in
    ``request.app.state.cluster_managers``. Creating a missing manager is
    serialized by ``request.app.state.cluster_manager_lock``, and the
    ``KubernetesClusterManager`` constructor is run via
    ``run_in_threadpool``.

    Args:
        request: The current request; used only to reach ``app.state``.
        namespace: Namespace whose manager is wanted.

    Returns:
        The cached or newly created manager for *namespace*.

    Raises:
        HTTPException: 400 if *namespace* is not cached and the cache already
            holds ``_MAX_CACHED_NAMESPACES`` entries; 503 if constructing the
            manager raises ``KubernetesDeploymentError``.
    """
    cache: dict[str, KubernetesClusterManager] = request.app.state.cluster_managers
    if namespace in cache:
        return cache[namespace]

    creation_lock = cast(asyncio.Lock, request.app.state.cluster_manager_lock)
    async with creation_lock:
        # Re-check under the lock: another request may have created the
        # manager while this one was waiting to acquire it.
        if namespace in cache:
            return cache[namespace]

        if len(cache) >= _MAX_CACHED_NAMESPACES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Too many namespaces (limit {_MAX_CACHED_NAMESPACES})",
            )

        try:
            manager = await run_in_threadpool(
                KubernetesClusterManager, namespace=namespace
            )
        except KubernetesDeploymentError as e:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Kubernetes cluster not accessible: {e}",
            ) from e

        cache[namespace] = manager
        logger.info(
            "Kubernetes cluster manager initialized for namespace=%s",
            namespace,
        )
        return manager

0 commit comments

Comments
 (0)