-
Notifications
You must be signed in to change notification settings - Fork 23.2k
Expand file tree
/
Copy pathexecution_context.py
More file actions
263 lines (202 loc) · 7.53 KB
/
Copy pathexecution_context.py
File metadata and controls
263 lines (202 loc) · 7.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""
Application-layer execution context adapters.
Concrete context capture lives outside `graphon` so the graph package only
consumes injected context managers when it needs to preserve thread-local state.
"""
import contextvars
import threading
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, override, runtime_checkable
from pydantic import BaseModel
class AppContext(Protocol):
    """
    Structural interface for an application context.

    Framework adapters implement this so that framework-specific state (for
    example a Flask app context) can be re-established around worker
    execution.
    """

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the configuration value stored under ``key``, or ``default``."""
        ...

    def get_extension(self, name: str) -> Any:
        """Return the application extension registered as ``name``."""
        ...

    def enter(self) -> AbstractContextManager[None]:
        """Return a context manager that enters the application context."""
        ...
@runtime_checkable
class IExecutionContext(Protocol):
    """
    Structural interface for an enterable execution context.

    Callers rely only on context-manager behavior, a ``user`` property and
    ``refresh_context_vars``; implementations are free to carry additional
    framework state.
    """

    def __enter__(self) -> "IExecutionContext":
        """Enter the execution context and return it."""
        ...

    def __exit__(self, *args: Any) -> None:
        """Leave the execution context."""
        ...

    @property
    def user(self) -> Any:
        """User object associated with this context."""
        ...

    def refresh_context_vars(self) -> None:
        """Capture the current context variables again so later worker threads see them."""
        ...
@final
class ExecutionContext:
    """
    Generic execution context used by application-layer adapters.

    Entering the context writes every captured `contextvars` value into the
    current context and then, when an application context was supplied,
    enters it for the duration of the block.

    Active context managers are tracked on a per-thread stack, so the same
    instance can be entered from several threads and re-entered (nested) on
    one thread; each ``__exit__`` closes the manager opened by the matching
    ``__enter__``.
    """

    def __init__(
        self,
        app_context: AppContext | None = None,
        context_vars: contextvars.Context | None = None,
        user: Any = None,
    ) -> None:
        """
        Args:
            app_context: Optional application context entered alongside this one.
            context_vars: Optional captured context whose values are restored on enter.
            user: Optional user object exposed through ``user``.
        """
        self._app_context = app_context
        self._context_vars = context_vars
        self._user = user
        # Per-thread storage for the stack of context managers opened by __enter__.
        self._local = threading.local()

    @property
    def app_context(self) -> AppContext | None:
        """Get application context."""
        return self._app_context

    @property
    def context_vars(self) -> contextvars.Context | None:
        """Get captured context variables."""
        return self._context_vars

    @property
    def user(self) -> Any:
        """Get captured user object."""
        return self._user

    def refresh_context_vars(self) -> None:
        """Re-capture current context variables.

        Call this after ContextVars have been updated in the current thread
        (e.g. by a GraphEngine layer's ``on_graph_start``) so that worker
        threads created afterwards receive the updated values.
        """
        self._context_vars = contextvars.copy_context()

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """
        Enter this execution context.

        Captured variable values are applied with ``ContextVar.set`` and are
        not reset when the block exits. The application context, if any, is
        entered after the variables are applied and exited when the block ends.
        """
        if self._context_vars:
            for var, val in self._context_vars.items():
                var.set(val)
        if self._app_context is not None:
            with self._app_context.enter():
                yield
        else:
            yield

    def __enter__(self) -> "ExecutionContext":
        """Enter the execution context and return ``self``."""
        cm = self.enter()
        cm.__enter__()
        # Record the manager only after it entered successfully, so a failed
        # enter never leaves a stale entry for __exit__ to close.
        stack = getattr(self._local, "stack", None)
        if stack is None:
            stack = []
            self._local.stack = stack
        stack.append(cm)
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the most recently entered context on the current thread, if any."""
        stack = getattr(self._local, "stack", None)
        if stack:
            # Pop before closing so each manager is closed exactly once,
            # even if closing it raises.
            stack.pop().__exit__(*args)
class NullAppContext(AppContext):
    """
    Stand-in application context for environments without a framework.

    Configuration comes from an optional dict supplied at construction and
    extensions are kept in an in-memory mapping; entering it does nothing.
    """

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Store ``config`` (or a fresh empty dict when it is missing or empty)."""
        self._config = config if config else {}
        self._extensions: dict[str, Any] = {}

    @override
    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the configured value for ``key``, falling back to ``default``."""
        settings = self._config
        return settings.get(key, default)

    @override
    def get_extension(self, name: str) -> Any:
        """Return the extension registered as ``name``, or ``None`` if absent."""
        registry = self._extensions
        return registry.get(name)

    def set_extension(self, name: str, extension: Any) -> None:
        """Register ``extension`` under ``name`` (for tests or standalone execution)."""
        self._extensions[name] = extension

    @contextmanager
    @override
    def enter(self) -> Generator[None, None, None]:
        """Yield once without setting anything up or tearing anything down."""
        yield
class ExecutionContextBuilder:
    """
    Fluent builder that assembles an `ExecutionContext`.

    Each ``with_*`` method records one component and returns the builder so
    calls can be chained; ``build`` produces the context from what was set.
    """

    def __init__(self) -> None:
        """Start with no application context, context variables or user."""
        self._app_context: AppContext | None = None
        self._context_vars: contextvars.Context | None = None
        self._user: Any = None

    def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
        """Record the application context and return this builder."""
        self._app_context = app_context
        return self

    def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
        """Record the captured context variables and return this builder."""
        self._context_vars = context_vars
        return self

    def with_user(self, user: Any) -> "ExecutionContextBuilder":
        """Record the user object and return this builder."""
        self._user = user
        return self

    def build(self) -> ExecutionContext:
        """Create an `ExecutionContext` from the components recorded so far."""
        built = ExecutionContext(
            app_context=self._app_context,
            context_vars=self._context_vars,
            user=self._user,
        )
        return built
# Module-wide capturer set by register_context_capturer(); None means
# capture_current_context() falls back to a plain ExecutionContext.
_capturer: Callable[[], IExecutionContext] | None = None
# Registry of context providers keyed by (context name, tenant id);
# written by register_context() and read by read_context().
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
class ContextProviderNotFoundError(KeyError):
    """Signals that no context provider is registered for the requested name and tenant."""
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """
    Install ``capturer`` as the module-wide execution context capturer.

    Any previously registered capturer is replaced.
    """
    # Rebind the module-level slot rather than creating a local.
    global _capturer
    _capturer = capturer
def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
    """
    Register ``provider`` as the source of the context ``name`` for ``tenant_id``.

    A provider already registered for the same name and tenant is replaced.
    """
    registry_key = (name, tenant_id)
    _tenant_context_providers[registry_key] = provider
def read_context(name: str, *, tenant_id: str) -> BaseModel:
    """
    Return the context value ``name`` for the given tenant.

    The provider registered for ``(name, tenant_id)`` is called on every read
    and its result is returned.

    Raises:
        ContextProviderNotFoundError: If no provider is registered for the
            name and tenant.
    """
    registry_key = (name, tenant_id)
    provider = _tenant_context_providers.get(registry_key)
    if provider is not None:
        return provider()
    raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
def capture_current_context() -> IExecutionContext:
    """
    Capture the execution context of the calling environment.

    A registered capturer, when present, is called and its result returned.
    Otherwise the result is an `ExecutionContext` holding a copy of the
    current `contextvars` and a `NullAppContext`.
    """
    if _capturer is not None:
        return _capturer()

    # No framework adapter registered: fall back to a contextvars-only snapshot.
    snapshot = contextvars.copy_context()
    return ExecutionContext(
        app_context=NullAppContext(),
        context_vars=snapshot,
    )
def reset_context_provider() -> None:
    """Drop every tenant-scoped provider and unregister the capturer."""
    global _capturer
    # Clear in place so existing references to the registry stay valid.
    _tenant_context_providers.clear()
    _capturer = None
# Names exported by `from <module> import *`; the module-level registries
# (`_capturer`, `_tenant_context_providers`) are not listed.
__all__ = [
    "AppContext",
    "ContextProviderNotFoundError",
    "ExecutionContext",
    "ExecutionContextBuilder",
    "IExecutionContext",
    "NullAppContext",
    "capture_current_context",
    "read_context",
    "register_context",
    "register_context_capturer",
    "reset_context_provider",
]