Skip to content

Commit 8ea7f68

Browse files
authored
feat: Generic segment metadata (#265)
1 parent fa183fb commit 8ea7f68

File tree

6 files changed

+128
-26
lines changed

6 files changed

+128
-26
lines changed

flag_engine/context/types.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
7+
from typing import Any, Dict, Generic, List, Literal, Optional, Union
88

9-
from typing_extensions import NotRequired
9+
from typing_extensions import NotRequired, TypedDict
1010

11-
from flag_engine.segments.types import ConditionOperator, ContextValue, RuleType
11+
from flag_engine.segments.types import (
12+
ConditionOperator,
13+
ContextValue,
14+
RuleType,
15+
SegmentMetadataT,
16+
)
1217

1318

1419
class EnvironmentContext(TypedDict):
@@ -58,16 +63,16 @@ class FeatureContext(TypedDict):
5863
priority: NotRequired[float]
5964

6065

61-
class SegmentContext(TypedDict):
66+
class SegmentContext(TypedDict, Generic[SegmentMetadataT]):
6267
key: str
6368
name: str
6469
rules: List[SegmentRule]
6570
overrides: NotRequired[List[FeatureContext]]
66-
metadata: NotRequired[Dict[str, Any]]
71+
metadata: NotRequired[SegmentMetadataT]
6772

6873

69-
class EvaluationContext(TypedDict):
74+
class EvaluationContext(TypedDict, Generic[SegmentMetadataT]):
7075
environment: EnvironmentContext
7176
identity: NotRequired[Optional[IdentityContext]]
72-
segments: NotRequired[Dict[str, SegmentContext]]
77+
segments: NotRequired[Dict[str, SegmentContext[SegmentMetadataT]]]
7378
features: NotRequired[Dict[str, FeatureContext]]

flag_engine/result/types.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Dict, List, TypedDict
7+
from typing import Any, Dict, Generic, List
88

9-
from typing_extensions import NotRequired
9+
from typing_extensions import NotRequired, TypedDict
10+
11+
from flag_engine.segments.types import SegmentMetadataT
1012

1113

1214
class FlagResult(TypedDict):
@@ -17,12 +19,12 @@ class FlagResult(TypedDict):
1719
reason: str
1820

1921

20-
class SegmentResult(TypedDict):
22+
class SegmentResult(TypedDict, Generic[SegmentMetadataT]):
2123
key: str
2224
name: str
23-
metadata: NotRequired[Dict[str, Any]]
25+
metadata: NotRequired[SegmentMetadataT]
2426

2527

26-
class EvaluationResult(TypedDict):
28+
class EvaluationResult(TypedDict, Generic[SegmentMetadataT]):
2729
flags: Dict[str, FlagResult]
28-
segments: List[SegmentResult]
30+
segments: List[SegmentResult[SegmentMetadataT]]

flag_engine/segments/evaluator.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
)
2121
from flag_engine.result.types import EvaluationResult, FlagResult, SegmentResult
2222
from flag_engine.segments import constants
23-
from flag_engine.segments.types import ConditionOperator, ContextValue, is_context_value
23+
from flag_engine.segments.types import (
24+
ConditionOperator,
25+
ContextValue,
26+
SegmentMetadataT,
27+
is_context_value,
28+
)
2429
from flag_engine.segments.utils import escape_double_quotes, get_matching_function
2530
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
2631
from flag_engine.utils.semver import is_semver
@@ -32,14 +37,16 @@ class FeatureContextWithSegmentName(typing.TypedDict):
3237
segment_name: str
3338

3439

35-
def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
40+
def get_evaluation_result(
41+
context: EvaluationContext[SegmentMetadataT],
42+
) -> EvaluationResult[SegmentMetadataT]:
3643
"""
3744
Get the evaluation result for a given context.
3845
3946
:param context: the evaluation context
4047
:return: EvaluationResult containing the context, flags, and segments
4148
"""
42-
segments: list[SegmentResult] = []
49+
segments: list[SegmentResult[SegmentMetadataT]] = []
4350
flags: dict[str, FlagResult] = {}
4451

4552
segment_feature_contexts: dict[SupportsStr, FeatureContextWithSegmentName] = {}
@@ -48,7 +55,7 @@ def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
4855
if not is_context_in_segment(context, segment_context):
4956
continue
5057

51-
segment_result: SegmentResult = {
58+
segment_result: SegmentResult[SegmentMetadataT] = {
5259
"key": segment_context["key"],
5360
"name": segment_context["name"],
5461
}
@@ -152,8 +159,8 @@ def get_flag_result_from_feature_context(
152159

153160

154161
def is_context_in_segment(
155-
context: EvaluationContext,
156-
segment_context: SegmentContext,
162+
context: EvaluationContext[SegmentMetadataT],
163+
segment_context: SegmentContext[SegmentMetadataT],
157164
) -> bool:
158165
return bool(rules := segment_context["rules"]) and all(
159166
context_matches_rule(
@@ -164,7 +171,7 @@ def is_context_in_segment(
164171

165172

166173
def context_matches_rule(
167-
context: EvaluationContext,
174+
context: EvaluationContext[SegmentMetadataT],
168175
rule: SegmentRule,
169176
segment_key: SupportsStr,
170177
) -> bool:
@@ -194,7 +201,7 @@ def context_matches_rule(
194201

195202

196203
def context_matches_condition(
197-
context: EvaluationContext,
204+
context: EvaluationContext[SegmentMetadataT],
198205
condition: SegmentCondition,
199206
segment_key: SupportsStr,
200207
) -> bool:
@@ -255,7 +262,7 @@ def context_matches_condition(
255262

256263

257264
def get_context_value(
258-
context: EvaluationContext,
265+
context: EvaluationContext[SegmentMetadataT],
259266
property: str,
260267
) -> ContextValue:
261268
value = None
@@ -353,7 +360,7 @@ def inner(
353360
@lru_cache
354361
def _get_context_value_getter(
355362
property: str,
356-
) -> typing.Callable[[EvaluationContext], ContextValue]:
363+
) -> typing.Callable[[EvaluationContext[SegmentMetadataT]], ContextValue]:
357364
"""
358365
Get a function to retrieve a context value based on property value,
359366
assumed to be either a JSONPath string or a trait key.
@@ -370,7 +377,7 @@ def _get_context_value_getter(
370377
f'$.identity.traits["{escape_double_quotes(property)}"]',
371378
)
372379

373-
def getter(context: EvaluationContext) -> ContextValue:
380+
def getter(context: EvaluationContext[SegmentMetadataT]) -> ContextValue:
374381
if typing.TYPE_CHECKING: # pragma: no cover
375382
# Ugly hack to satisfy mypy :(
376383
data = dict(context)

flag_engine/segments/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from typing import Any, Literal, Union, get_args
1+
from __future__ import annotations
22

3-
from typing_extensions import TypeGuard
3+
from typing import Any, Dict, Literal, Union, get_args
4+
5+
from typing_extensions import TypeGuard, TypeVar
6+
7+
SegmentMetadataT = TypeVar("SegmentMetadataT", default=Dict[str, object])
48

59
ConditionOperator = Literal[
610
"EQUAL",

flag_engine/types/__init__.py

Whitespace-only changes.

tests/unit/test_engine.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
import json
2+
from typing import TYPE_CHECKING, TypedDict
3+
4+
if not TYPE_CHECKING:
5+
# `reveal_type` is a pseudo-builtin only available when type checking.
6+
# Define a no-op version here so that we can call it in the tests.
7+
def reveal_type(x: object) -> None: ...
8+
29

310
from flag_engine.context.types import EvaluationContext, IdentityContext, SegmentContext
411
from flag_engine.engine import get_evaluation_result
@@ -357,3 +364,80 @@ def test_get_evaluation_result__segment_override__no_priority__returns_expected(
357364
{"key": "3", "name": "another_segment"},
358365
],
359366
}
367+
368+
369+
def test_segment_metadata_generic_type__returns_expected() -> None:
370+
# Given
371+
class CustomMetadata(TypedDict):
372+
foo: str
373+
bar: int
374+
375+
segment_metadata = CustomMetadata(foo="hello", bar=123)
376+
377+
evaluation_context: EvaluationContext[CustomMetadata] = {
378+
"environment": {"key": "api-key", "name": ""},
379+
"segments": {
380+
"1": {
381+
"key": "1",
382+
"name": "my_segment",
383+
"rules": [
384+
{
385+
"type": "ALL",
386+
"conditions": [
387+
{
388+
"property": "$.environment.name",
389+
"operator": "EQUAL",
390+
"value": "",
391+
}
392+
],
393+
"rules": [],
394+
}
395+
],
396+
"metadata": segment_metadata,
397+
},
398+
},
399+
}
400+
401+
# When
402+
result = get_evaluation_result(evaluation_context)
403+
404+
# Then
405+
assert result["segments"][0]["metadata"] is segment_metadata
406+
reveal_type(result["segments"][0]["metadata"]) # CustomMetadata
407+
408+
409+
def test_segment_metadata_generic_type__default__returns_expected() -> None:
410+
# Given
411+
segment_metadata = {"hello": object()}
412+
413+
# we don't specify generic type, but mypy is happy with this
414+
evaluation_context: EvaluationContext = {
415+
"environment": {"key": "api-key", "name": ""},
416+
"segments": {
417+
"1": {
418+
"key": "1",
419+
"name": "my_segment",
420+
"rules": [
421+
{
422+
"type": "ALL",
423+
"conditions": [
424+
{
425+
"property": "$.environment.name",
426+
"operator": "EQUAL",
427+
"value": "",
428+
}
429+
],
430+
"rules": [],
431+
}
432+
],
433+
"metadata": segment_metadata,
434+
},
435+
},
436+
}
437+
438+
# When
439+
result = get_evaluation_result(evaluation_context)
440+
441+
# Then
442+
assert result["segments"][0]["metadata"] is segment_metadata
443+
reveal_type(result["segments"][0]["metadata"]) # Dict[str, object]

0 commit comments

Comments
 (0)