|
1 | 1 | """Shared dependencies for API routes. |
2 | 2 |
|
3 | | -This module provides singleton instances and dependency injection |
4 | | -for the API routes. All shared state is initialized here. |
| 3 | +This module provides singleton instances via FastAPI's app.state and |
| 4 | +dependency injection via Depends(). All shared state is initialized |
| 5 | +during the application lifespan in init_app_state(). |
5 | 6 | """ |
6 | 7 |
|
| 8 | +import asyncio |
7 | 9 | import logging |
8 | 10 | import os |
| 11 | +from typing import cast |
| 12 | + |
| 13 | +from fastapi import FastAPI, HTTPException, Request, status |
| 14 | +from starlette.concurrency import run_in_threadpool |
9 | 15 |
|
10 | 16 | from neuralnav.cluster import KubernetesClusterManager, KubernetesDeploymentError |
11 | 17 | from neuralnav.configuration import DeploymentGenerator, YAMLValidator |
12 | 18 | from neuralnav.knowledge_base.model_catalog import ModelCatalog |
13 | 19 | from neuralnav.knowledge_base.slo_templates import SLOTemplateRepository |
14 | 20 | from neuralnav.orchestration.workflow import RecommendationWorkflow |
15 | | -from neuralnav.shared.schemas import DeploymentMode |
16 | 21 |
|
17 | 22 | # Configure logging |
18 | 23 | debug_mode = os.getenv("NEURALNAV_DEBUG", "false").lower() == "true" |
|
24 | 29 | ) |
25 | 30 | logger = logging.getLogger(__name__) |
26 | 31 |
|
27 | | -# Singleton instances |
28 | | -_workflow: RecommendationWorkflow | None = None |
29 | | -_model_catalog: ModelCatalog | None = None |
30 | | -_slo_repo: SLOTemplateRepository | None = None |
31 | | -_deployment_generator: DeploymentGenerator | None = None |
32 | | -_yaml_validator: YAMLValidator | None = None |
33 | | -_cluster_manager: KubernetesClusterManager | None = None |
34 | 32 |
|
| 33 | +# --------------------------------------------------------------------------- |
| 34 | +# Lifespan: initialize all singletons on app.state |
| 35 | +# --------------------------------------------------------------------------- |
35 | 36 |
|
36 | | -def get_workflow() -> RecommendationWorkflow: |
37 | | - """Get the recommendation workflow singleton.""" |
38 | | - global _workflow |
39 | | - if _workflow is None: |
40 | | - _workflow = RecommendationWorkflow() |
41 | | - return _workflow |
42 | 37 |
|
def init_app_state(app: FastAPI) -> None:
    """Create the shared singletons and attach them to ``app.state``.

    Meant to be called during lifespan startup. The ``Depends()`` providers
    in this module read these attributes back from ``request.app.state``.

    Args:
        app: The FastAPI application whose ``state`` receives the singletons.
    """
    state = app.state
    state.model_catalog = ModelCatalog()
    state.slo_repo = SLOTemplateRepository()
    state.deployment_generator = DeploymentGenerator(simulator_mode=False)
    state.yaml_validator = YAMLValidator()
    # Per-namespace cache (dict[str, KubernetesClusterManager]); entries are
    # added lazily by get_cluster_manager_or_raise().
    state.cluster_managers = {}
    # get_cluster_manager_or_raise() reads this lock from app.state before
    # creating a manager; without it that lookup raises AttributeError.
    state.cluster_manager_lock = asyncio.Lock()
    state.workflow = RecommendationWorkflow()
43 | 46 |
|
44 | | -def get_model_catalog() -> ModelCatalog: |
45 | | - """Get the model catalog singleton.""" |
46 | | - global _model_catalog |
47 | | - if _model_catalog is None: |
48 | | - _model_catalog = ModelCatalog() |
49 | | - return _model_catalog |
50 | 47 |
|
| 48 | +# --------------------------------------------------------------------------- |
| 49 | +# Depends() providers — read from request.app.state |
| 50 | +# --------------------------------------------------------------------------- |
51 | 51 |
|
52 | | -def get_slo_repo() -> SLOTemplateRepository: |
53 | | - """Get the SLO template repository singleton.""" |
54 | | - global _slo_repo |
55 | | - if _slo_repo is None: |
56 | | - _slo_repo = SLOTemplateRepository() |
57 | | - return _slo_repo |
58 | 52 |
|
def get_workflow(request: Request) -> RecommendationWorkflow:
    """Return the RecommendationWorkflow stored on ``request.app.state``."""
    workflow: RecommendationWorkflow = request.app.state.workflow
    return workflow
59 | 56 |
|
60 | | -def get_deployment_generator() -> DeploymentGenerator: |
61 | | - """Get the deployment generator singleton.""" |
62 | | - global _deployment_generator |
63 | | - if _deployment_generator is None: |
64 | | - _deployment_generator = DeploymentGenerator(simulator_mode=False) |
65 | | - logger.info("Deployment generator initialized (simulator_mode=False)") |
66 | | - return _deployment_generator |
67 | 57 |
|
def get_model_catalog(request: Request) -> ModelCatalog:
    """Return the ModelCatalog stored on ``request.app.state``."""
    catalog: ModelCatalog = request.app.state.model_catalog
    return catalog
68 | 61 |
|
69 | | -def get_deployment_mode() -> DeploymentMode: |
70 | | - """Return the current deployment mode.""" |
71 | | - gen = get_deployment_generator() |
72 | | - return DeploymentMode.SIMULATOR if gen.simulator_mode else DeploymentMode.PRODUCTION |
73 | 62 |
|
def get_slo_repo(request: Request) -> SLOTemplateRepository:
    """Return the SLOTemplateRepository stored on ``request.app.state``."""
    repo: SLOTemplateRepository = request.app.state.slo_repo
    return repo
74 | 66 |
|
75 | | -def set_deployment_mode(mode: DeploymentMode) -> DeploymentMode: |
76 | | - """Set the deployment mode and return the new mode.""" |
77 | | - gen = get_deployment_generator() |
78 | | - gen.simulator_mode = mode == DeploymentMode.SIMULATOR |
79 | | - logger.info(f"Deployment mode changed to: {mode.value}") |
80 | | - return mode |
81 | 67 |
|
def get_deployment_generator(request: Request) -> DeploymentGenerator:
    """Return the DeploymentGenerator stored on ``request.app.state``."""
    generator: DeploymentGenerator = request.app.state.deployment_generator
    return generator
82 | 71 |
|
83 | | -def get_yaml_validator() -> YAMLValidator: |
84 | | - """Get the YAML validator singleton.""" |
85 | | - global _yaml_validator |
86 | | - if _yaml_validator is None: |
87 | | - _yaml_validator = YAMLValidator() |
88 | | - return _yaml_validator |
89 | 72 |
|
def get_yaml_validator(request: Request) -> YAMLValidator:
    """Return the YAMLValidator stored on ``request.app.state``."""
    validator: YAMLValidator = request.app.state.yaml_validator
    return validator
90 | 76 |
|
91 | | -def get_cluster_manager(namespace: str = "default") -> KubernetesClusterManager | None: |
92 | | - """Get or create a cluster manager. |
93 | 77 |
|
94 | | - Returns None if cluster is not accessible. |
95 | | - """ |
96 | | - global _cluster_manager |
97 | | - if _cluster_manager is None: |
98 | | - try: |
99 | | - _cluster_manager = KubernetesClusterManager(namespace=namespace) |
100 | | - logger.info("Kubernetes cluster manager initialized successfully") |
101 | | - except KubernetesDeploymentError as e: |
102 | | - logger.info(f"Kubernetes cluster not accessible: {e}") |
103 | | - return None |
104 | | - return _cluster_manager |
# Maximum number of distinct namespaces kept in app.state.cluster_managers;
# once reached, requests for a new namespace are rejected with HTTP 400.
_MAX_CACHED_NAMESPACES = 32


async def get_cluster_manager_or_raise(
    request: Request, namespace: str = "default"
) -> KubernetesClusterManager:
    """Return the cluster manager for ``namespace``, creating it on first use.

    Managers are cached per namespace in ``request.app.state.cluster_managers``.
    Creation is serialized by ``request.app.state.cluster_manager_lock``, and
    the constructor is run through ``run_in_threadpool``.

    Args:
        request: Incoming request; its ``app.state`` holds the cache and lock.
        namespace: Kubernetes namespace the manager is created for.

    Raises:
        HTTPException: 400 when the cache already holds
            ``_MAX_CACHED_NAMESPACES`` entries and ``namespace`` is new;
            503 when the constructor raises ``KubernetesDeploymentError``.
    """
    cache: dict[str, KubernetesClusterManager] = request.app.state.cluster_managers

    # Fast path: already cached, no lock needed.
    if namespace in cache:
        return cache[namespace]

    creation_lock: asyncio.Lock = request.app.state.cluster_manager_lock
    async with creation_lock:
        # Another request may have filled the entry while we waited.
        if namespace in cache:
            return cache[namespace]

        if len(cache) >= _MAX_CACHED_NAMESPACES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Too many namespaces (limit {_MAX_CACHED_NAMESPACES})",
            )

        try:
            manager = await run_in_threadpool(
                KubernetesClusterManager, namespace=namespace
            )
        except KubernetesDeploymentError as exc:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Kubernetes cluster not accessible: {exc}",
            ) from exc

        cache[namespace] = manager
        logger.info(
            "Kubernetes cluster manager initialized for namespace=%s",
            namespace,
        )

    return cache[namespace]
0 commit comments