|
4 | 4 | import typing |
5 | 5 | import warnings |
6 | 6 | from contextlib import suppress |
7 | | -from functools import partial, wraps |
| 7 | +from functools import lru_cache, wraps |
8 | 8 |
|
| 9 | +import jsonpath_rfc9535 |
9 | 10 | import semver |
10 | 11 |
|
11 | 12 | from flag_engine.context.mappers import map_environment_identity_to_context |
|
18 | 19 | ) |
19 | 20 | from flag_engine.environments.models import EnvironmentModel |
20 | 21 | from flag_engine.identities.models import IdentityModel |
21 | | -from flag_engine.identities.traits.types import ContextValue |
| 22 | +from flag_engine.identities.traits.types import ContextValue, is_trait_value |
22 | 23 | from flag_engine.result.types import EvaluationResult, FlagResult, SegmentResult |
23 | 24 | from flag_engine.segments import constants |
24 | 25 | from flag_engine.segments.models import SegmentModel |
25 | 26 | from flag_engine.segments.types import ConditionOperator |
26 | | -from flag_engine.segments.utils import get_matching_function |
| 27 | +from flag_engine.segments.utils import escape_double_quotes, get_matching_function |
27 | 28 | from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids |
28 | 29 | from flag_engine.utils.semver import is_semver |
29 | 30 | from flag_engine.utils.types import SupportsStr, get_casting_function |
@@ -256,26 +257,16 @@ def context_matches_condition( |
256 | 257 | ) |
257 | 258 |
|
258 | 259 |
|
259 | | -def _get_trait(context: EvaluationContext, trait_key: str) -> ContextValue: |
260 | | - return ( |
261 | | - identity_context["traits"][trait_key] |
262 | | - if (identity_context := context["identity"]) |
263 | | - else None |
264 | | - ) |
265 | | - |
266 | | - |
267 | 260 | def get_context_value( |
268 | 261 | context: EvaluationContext, |
269 | 262 | property: str, |
270 | 263 | ) -> ContextValue: |
271 | | - getter = CONTEXT_VALUE_GETTERS_BY_PROPERTY.get(property) or partial( |
272 | | - _get_trait, |
273 | | - trait_key=property, |
274 | | - ) |
275 | | - try: |
276 | | - return getter(context) |
277 | | - except KeyError: |
278 | | - return None |
| 264 | + if property.startswith("$."): |
| 265 | + return _get_context_value_getter(property)(context) |
| 266 | + if identity_context := context.get("identity"): |
| 267 | + if traits := identity_context.get("traits"): |
| 268 | + return traits.get(property) |
| 269 | + return None |
279 | 270 |
|
280 | 271 |
|
281 | 272 | def _matches_context_value( |
@@ -385,8 +376,44 @@ def inner( |
385 | 376 | } |
386 | 377 |
|
387 | 378 |
|
388 | | -CONTEXT_VALUE_GETTERS_BY_PROPERTY = { |
389 | | - "$.identity.identifier": lambda context: context["identity"]["identifier"], |
390 | | - "$.identity.key": lambda context: context["identity"]["key"], |
391 | | - "$.environment.name": lambda context: context["environment"]["name"], |
392 | | -} |
| 379 | +@lru_cache |
| 380 | +def _get_context_value_getter( |
| 381 | + property: str, |
| 382 | +) -> typing.Callable[[EvaluationContext], ContextValue]: |
| 383 | + """ |
| 384 | + Get a function to retrieve a context value based on property value, |
| 385 | + assumed to be either a JSONPath string or a trait key. |
| 386 | +
|
| 387 | + :param property: The property to retrieve the value for. |
| 388 | + :return: A function that takes an EvaluationContext and returns the value. |
| 389 | + """ |
| 390 | + try: |
| 391 | + compiled_query = jsonpath_rfc9535.compile(property) |
| 392 | + except jsonpath_rfc9535.JSONPathSyntaxError: |
| 393 | + # This covers a rare case when a trait starting with "$.", |
| 394 | + # but not a valid JSONPath, is used. |
| 395 | + compiled_query = jsonpath_rfc9535.compile( |
| 396 | + f'$.identity.traits["{escape_double_quotes(property)}"]', |
| 397 | + ) |
| 398 | + |
| 399 | + def getter(context: EvaluationContext) -> ContextValue: |
| 400 | + if typing.TYPE_CHECKING: # pragma: no cover |
| 401 | + # Ugly hack to satisfy mypy :( |
| 402 | + data = dict(context) |
| 403 | + else: |
| 404 | + data = context |
| 405 | + try: |
| 406 | + if result := compiled_query.find_one(data): |
| 407 | + if is_trait_value(value := result.value): |
| 408 | + return value |
| 409 | + return None |
| 410 | + except jsonpath_rfc9535.JSONPathError: # pragma: no cover |
| 411 | + # This is supposed to be unreachable, but if it happens, |
| 412 | + # we log a warning and return None. |
| 413 | + warnings.warn( |
| 414 | + f"Failed to evaluate JSONPath query '{property}' in context: {context}", |
| 415 | + RuntimeWarning, |
| 416 | + ) |
| 417 | + return None |
| 418 | + |
| 419 | + return getter |
0 commit comments