diff --git a/.gitmodules b/.gitmodules index c740a219..cb2b8f6f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "tests/engine_tests/engine-test-data"] path = tests/engine_tests/engine-test-data url = https://github.com/flagsmith/engine-test-data.git + branch = feat/context-values diff --git a/flag_engine/context/__init__.py b/flag_engine/context/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flag_engine/context/mappers.py b/flag_engine/context/mappers.py new file mode 100644 index 00000000..f5a3525b --- /dev/null +++ b/flag_engine/context/mappers.py @@ -0,0 +1,39 @@ +import typing + +from flag_engine.context.types import EvaluationContext +from flag_engine.environments.models import EnvironmentModel +from flag_engine.identities.models import IdentityModel +from flag_engine.identities.traits.models import TraitModel + + +def map_environment_identity_to_context( + environment: EnvironmentModel, + identity: IdentityModel, + override_traits: typing.Optional[typing.List[TraitModel]], +) -> EvaluationContext: + """ + Maps an EnvironmentModel and IdentityModel to an EvaluationContext. + + :param environment: The environment model object. + :param identity: The identity model object. + :param override_traits: A list of TraitModel objects, to be used in place of `identity.identity_traits` if provided. + :return: An EvaluationContext containing the environment and identity. + """ + return { + "environment": { + "key": environment.api_key, + "name": environment.name or "", + }, + "identity": { + "identifier": identity.identifier, + "key": str(identity.django_id or identity.composite_key), + "traits": { + trait.trait_key: trait.trait_value + for trait in ( + override_traits + if override_traits is not None + else identity.identity_traits + ) + }, + }, + } diff --git a/flag_engine/context/types.py b/flag_engine/context/types.py new file mode 100644 index 00000000..3cc7f106 --- /dev/null +++ b/flag_engine/context/types.py @@ -0,0 +1,28 @@ +# generated by datamodel-codegen: +# filename: https://raw.githubusercontent.com/Flagsmith/flagsmith/chore/update-evaluation-context/sdk/evaluation-context.json # noqa: E501 +# timestamp: 2025-07-16T10:39:10+00:00 + +from __future__ import annotations + +from typing import Dict, Optional, TypedDict + +from typing_extensions import NotRequired + +from flag_engine.identities.traits.types import ContextValue +from flag_engine.utils.types import SupportsStr + + +class EnvironmentContext(TypedDict): + key: str + name: str + + +class IdentityContext(TypedDict): + identifier: str + key: SupportsStr + traits: NotRequired[Dict[str, ContextValue]] + + +class EvaluationContext(TypedDict): + environment: EnvironmentContext + identity: NotRequired[Optional[IdentityContext]] diff --git a/flag_engine/engine.py b/flag_engine/engine.py index 07334fd4..f43dc4c9 100644 --- a/flag_engine/engine.py +++ b/flag_engine/engine.py @@ -1,10 +1,12 @@ import typing +from flag_engine.context.mappers import map_environment_identity_to_context +from flag_engine.context.types import EvaluationContext from flag_engine.environments.models import EnvironmentModel from flag_engine.features.models import FeatureModel, FeatureStateModel from flag_engine.identities.models import IdentityModel from flag_engine.identities.traits.models import TraitModel -from flag_engine.segments.evaluator import get_identity_segments +from flag_engine.segments.evaluator import get_context_segments from flag_engine.utils.exceptions import FeatureStateNotFound @@ -53,9 +55,17 @@ def get_identity_feature_states( :return: list of feature state models based on the environment, any matching segments and any specific identity overrides """ + context = map_environment_identity_to_context( + environment=environment, + identity=identity, + override_traits=override_traits, + ) + feature_states = list( _get_identity_feature_states_dict( - environment, identity, override_traits + environment=environment, + identity=identity, + context=context, ).values() ) if environment.get_hide_disabled_flags(): @@ -79,8 +89,16 @@ def get_identity_feature_state( :return: feature state model based on the environment, any matching segments and any specific identity overrides """ + context = map_environment_identity_to_context( + environment=environment, + identity=identity, + override_traits=override_traits, + ) + feature_states = _get_identity_feature_states_dict( - environment, identity, override_traits + environment=environment, + identity=identity, + context=context, ) matching_feature = next( filter(lambda feature: feature.name == feature_name, feature_states.keys()), @@ -96,29 +114,33 @@ def get_identity_feature_state( def _get_identity_feature_states_dict( environment: EnvironmentModel, identity: IdentityModel, - override_traits: typing.Optional[typing.List[TraitModel]], + context: EvaluationContext, ) -> typing.Dict[FeatureModel, FeatureStateModel]: # Get feature states from the environment - feature_states = {fs.feature: fs for fs in environment.feature_states} + feature_states_by_feature = {fs.feature: fs for fs in environment.feature_states} # Override with any feature states defined by matching segments - identity_segments = get_identity_segments(environment, identity, override_traits) - for matching_segment in identity_segments: - for feature_state in matching_segment.feature_states: - if feature_state.feature in feature_states: - if feature_states[feature_state.feature].is_higher_segment_priority( - feature_state - ): - continue - feature_states[feature_state.feature] = feature_state + for context_segment in get_context_segments( + context=context, + segments=environment.project.segments, + ): + for segment_feature_state in context_segment.feature_states: + if ( + feature_state := feature_states_by_feature.get( + segment_feature := segment_feature_state.feature + ) + ) and feature_state.is_higher_segment_priority(segment_feature_state): + continue + feature_states_by_feature[segment_feature] = segment_feature_state # Override with any feature states defined directly the identity - feature_states.update( + feature_states_by_feature.update( { - fs.feature: fs - for fs in identity.identity_features - if fs.feature in feature_states + identity_feature: identity_feature_state + for identity_feature_state in identity.identity_features + if (identity_feature := identity_feature_state.feature) + in feature_states_by_feature } ) - return feature_states + return feature_states_by_feature diff --git a/flag_engine/identities/traits/models.py b/flag_engine/identities/traits/models.py index 83f3b26e..12ad914c 100644 --- a/flag_engine/identities/traits/models.py +++ b/flag_engine/identities/traits/models.py @@ -1,8 +1,8 @@ from pydantic import BaseModel, Field -from flag_engine.identities.traits.types import TraitValue +from flag_engine.identities.traits.types import ContextValue class TraitModel(BaseModel): trait_key: str - trait_value: TraitValue = Field(...) + trait_value: ContextValue = Field(...) diff --git a/flag_engine/identities/traits/types.py b/flag_engine/identities/traits/types.py index 4ec79b0e..5d213c9c 100644 --- a/flag_engine/identities/traits/types.py +++ b/flag_engine/identities/traits/types.py @@ -8,10 +8,10 @@ from flag_engine.identities.traits.constants import TRAIT_STRING_VALUE_MAX_LENGTH -_UnconstrainedTraitValue = Union[None, int, float, bool, str] +_UnconstrainedContextValue = Union[None, int, float, bool, str] -def map_any_value_to_trait_value(value: Any) -> _UnconstrainedTraitValue: +def map_any_value_to_trait_value(value: Any) -> _UnconstrainedContextValue: """ Try to coerce a value of arbitrary type to a trait value type. Union member-specific constraints, such as max string value length, are ignored here. @@ -36,7 +36,7 @@ def map_any_value_to_trait_value(value: Any) -> _UnconstrainedTraitValue: _float_pattern = re.compile(r"-?[0-9]+\.[0-9]+") -def _map_string_value_to_trait_value(value: str) -> _UnconstrainedTraitValue: +def _map_string_value_to_trait_value(value: str) -> _UnconstrainedContextValue: if _int_pattern.fullmatch(value): return int(value) if _float_pattern.fullmatch(value): @@ -44,11 +44,11 @@ def _map_string_value_to_trait_value(value: str) -> _UnconstrainedTraitValue: return value -def _is_trait_value(value: Any) -> TypeGuard[_UnconstrainedTraitValue]: - return isinstance(value, get_args(_UnconstrainedTraitValue)) +def _is_trait_value(value: Any) -> TypeGuard[_UnconstrainedContextValue]: + return isinstance(value, get_args(_UnconstrainedContextValue)) -TraitValue = Annotated[ +ContextValue = Annotated[ Union[ None, StrictBool, diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index 13f3a18a..49c4e4ef 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -2,14 +2,12 @@ import re import typing from contextlib import suppress -from functools import wraps +from functools import partial, wraps import semver -from flag_engine.environments.models import EnvironmentModel -from flag_engine.identities.models import IdentityModel -from flag_engine.identities.traits.models import TraitModel -from flag_engine.identities.traits.types import TraitValue +from flag_engine.context.types import EvaluationContext +from flag_engine.identities.traits.types import ContextValue from flag_engine.segments import constants from flag_engine.segments.models import ( SegmentConditionModel, @@ -19,131 +17,146 @@ from flag_engine.segments.types import ConditionOperator from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids from flag_engine.utils.semver import is_semver -from flag_engine.utils.types import get_casting_function +from flag_engine.utils.types import SupportsStr, get_casting_function -def get_identity_segments( - environment: EnvironmentModel, - identity: IdentityModel, - override_traits: typing.Optional[typing.List[TraitModel]] = None, +def get_context_segments( + context: EvaluationContext, + segments: typing.List[SegmentModel], ) -> typing.List[SegmentModel]: - return list( - filter( - lambda s: evaluate_identity_in_segment(identity, s, override_traits), - environment.project.segments, + return [ + segment + for segment in segments + if is_context_in_segment( + context=context, + segment=segment, ) - ) + ] -def evaluate_identity_in_segment( - identity: IdentityModel, +def is_context_in_segment( + context: EvaluationContext, segment: SegmentModel, - override_traits: typing.Optional[typing.List[TraitModel]] = None, ) -> bool: - """ - Evaluates whether a given identity is in the provided segment. - - :param identity: identity model object to evaluate - :param segment: segment model object to evaluate - :param override_traits: pass in a list of traits to use instead of those on the - identity model itself - :return: True if the identity is in the segment, False otherwise - """ - return len(segment.rules) > 0 and all( - _traits_match_segment_rule( - override_traits or identity.identity_traits, - rule, - segment.id, - identity.django_id or identity.composite_key, - ) - for rule in segment.rules + return bool(rules := segment.rules) and all( + context_matches_rule(context=context, rule=rule, segment_key=segment.id) + for rule in rules ) -def _traits_match_segment_rule( - identity_traits: typing.List[TraitModel], +def context_matches_rule( + context: EvaluationContext, rule: SegmentRuleModel, - segment_id: typing.Union[int, str], - identity_id: typing.Union[int, str], + segment_key: SupportsStr, ) -> bool: matches_conditions = ( rule.matching_function( [ - _traits_match_segment_condition( - identity_traits, condition, segment_id, identity_id + context_matches_condition( + context=context, + condition=condition, + segment_key=segment_key, ) - for condition in rule.conditions + for condition in conditions ] ) - if len(rule.conditions) > 0 + if (conditions := rule.conditions) else True ) return matches_conditions and all( - _traits_match_segment_rule(identity_traits, rule, segment_id, identity_id) + context_matches_rule( + context=context, + rule=rule, + segment_key=segment_key, + ) for rule in rule.rules ) -def _traits_match_segment_condition( - identity_traits: typing.List[TraitModel], +def context_matches_condition( + context: EvaluationContext, condition: SegmentConditionModel, - segment_id: typing.Union[int, str], - identity_id: typing.Union[int, str], + segment_key: SupportsStr, ) -> bool: + context_value = ( + get_context_value(context, condition.property_) if condition.property_ else None + ) + if condition.operator == constants.PERCENTAGE_SPLIT: assert condition.value - float_value = float(condition.value) - return ( - get_hashed_percentage_for_object_ids([segment_id, identity_id]) - <= float_value - ) - trait = next( - filter(lambda t: t.trait_key == condition.property_, identity_traits), None - ) + if context_value is not None: + object_ids = [segment_key, context_value] + else: + object_ids = [segment_key, get_context_value(context, "$.identity.key")] + + float_value = float(condition.value) + return get_hashed_percentage_for_object_ids(object_ids) <= float_value if condition.operator == constants.IS_NOT_SET: - return trait is None + return context_value is None if condition.operator == constants.IS_SET: - return trait is not None + return context_value is not None + + return _matches_context_value(condition, context_value) if context_value else False - return _matches_trait_value(condition, trait.trait_value) if trait else False +def _get_trait(context: EvaluationContext, trait_key: str) -> ContextValue: + return ( + identity_context["traits"][trait_key] + if (identity_context := context["identity"]) + else None + ) + + +def get_context_value( + context: EvaluationContext, + property: str, +) -> ContextValue: + getter = CONTEXT_VALUE_GETTERS_BY_PROPERTY.get(property) or partial( + _get_trait, + trait_key=property, + ) + try: + return getter(context) + except KeyError: + return None -def _matches_trait_value( + +def _matches_context_value( condition: SegmentConditionModel, - trait_value: TraitValue, + context_value: ContextValue, ) -> bool: - if match_func := MATCH_FUNCS_BY_OPERATOR.get(condition.operator): - return match_func(condition.value, trait_value) + if matcher := MATCHERS_BY_OPERATOR.get(condition.operator): + return matcher(condition.value, context_value) return False def _evaluate_not_contains( segment_value: typing.Optional[str], - trait_value: TraitValue, + context_value: ContextValue, ) -> bool: - return isinstance(trait_value, str) and str(segment_value) not in trait_value + return isinstance(context_value, str) and str(segment_value) not in context_value def _evaluate_regex( segment_value: typing.Optional[str], - trait_value: TraitValue, + context_value: ContextValue, ) -> bool: return ( - trait_value is not None - and re.compile(str(segment_value)).match(str(trait_value)) is not None + context_value is not None + and re.compile(str(segment_value)).match(str(context_value)) is not None ) def _evaluate_modulo( segment_value: typing.Optional[str], - trait_value: TraitValue, + context_value: ContextValue, ) -> bool: - if not isinstance(trait_value, (int, float)): + if not isinstance(context_value, (int, float)): return False if segment_value is None: @@ -156,52 +169,61 @@ def _evaluate_modulo( except ValueError: return False - return trait_value % divisor == remainder + return context_value % divisor == remainder -def _evaluate_in(segment_value: typing.Optional[str], trait_value: TraitValue) -> bool: +def _evaluate_in( + segment_value: typing.Optional[str], context_value: ContextValue +) -> bool: if segment_value: - if isinstance(trait_value, str): - return trait_value in segment_value.split(",") - if isinstance(trait_value, int) and not any( - trait_value is x for x in (False, True) + if isinstance(context_value, str): + return context_value in segment_value.split(",") + if isinstance(context_value, int) and not any( + context_value is x for x in (False, True) ): - return str(trait_value) in segment_value.split(",") + return str(context_value) in segment_value.split(",") return False -def _trait_value_typed( +def _context_value_typed( func: typing.Callable[..., bool], -) -> typing.Callable[[typing.Optional[str], TraitValue], bool]: +) -> typing.Callable[[typing.Optional[str], ContextValue], bool]: @wraps(func) def inner( segment_value: typing.Optional[str], - trait_value: typing.Union[TraitValue, semver.Version], + context_value: typing.Union[ContextValue, semver.Version], ) -> bool: with suppress(TypeError, ValueError): - if isinstance(trait_value, str) and is_semver(segment_value): - trait_value = semver.Version.parse( - trait_value, + if isinstance(context_value, str) and is_semver(segment_value): + context_value = semver.Version.parse( + context_value, ) - match_value = get_casting_function(trait_value)(segment_value) - return func(trait_value, match_value) + match_value = get_casting_function(context_value)(segment_value) + return func(context_value, match_value) return False return inner -MATCH_FUNCS_BY_OPERATOR: typing.Dict[ - ConditionOperator, typing.Callable[[typing.Optional[str], TraitValue], bool] +MATCHERS_BY_OPERATOR: typing.Dict[ + ConditionOperator, typing.Callable[[typing.Optional[str], ContextValue], bool] ] = { constants.NOT_CONTAINS: _evaluate_not_contains, constants.REGEX: _evaluate_regex, constants.MODULO: _evaluate_modulo, constants.IN: _evaluate_in, - constants.EQUAL: _trait_value_typed(operator.eq), - constants.GREATER_THAN: _trait_value_typed(operator.gt), - constants.GREATER_THAN_INCLUSIVE: _trait_value_typed(operator.ge), - constants.LESS_THAN: _trait_value_typed(operator.lt), - constants.LESS_THAN_INCLUSIVE: _trait_value_typed(operator.le), - constants.NOT_EQUAL: _trait_value_typed(operator.ne), - constants.CONTAINS: _trait_value_typed(operator.contains), + constants.EQUAL: _context_value_typed(operator.eq), + constants.GREATER_THAN: _context_value_typed(operator.gt), + constants.GREATER_THAN_INCLUSIVE: _context_value_typed(operator.ge), + constants.LESS_THAN: _context_value_typed(operator.lt), + constants.LESS_THAN_INCLUSIVE: _context_value_typed(operator.le), + constants.NOT_EQUAL: _context_value_typed(operator.ne), + constants.CONTAINS: _context_value_typed(operator.contains), +} + + +CONTEXT_VALUE_GETTERS_BY_PROPERTY = { + "$.identity.identifier": lambda context: context["identity"]["identifier"], + "$.identity.key": lambda context: context["identity"]["key"], + "$.environment.name": lambda context: context["environment"]["name"], } diff --git a/flag_engine/types/__init__.py b/flag_engine/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flag_engine/utils/types.py b/flag_engine/utils/types.py index f98d6cd9..f9e14b77 100644 --- a/flag_engine/utils/types.py +++ b/flag_engine/utils/types.py @@ -3,7 +3,7 @@ import semver -from flag_engine.identities.traits.types import TraitValue +from flag_engine.identities.traits.types import ContextValue from flag_engine.utils.semver import remove_semver_suffix @@ -15,7 +15,7 @@ def __str__(self) -> str: # pragma: no cover @singledispatch def get_casting_function( input_: object, -) -> typing.Callable[..., TraitValue]: +) -> typing.Callable[..., ContextValue]: """ This function returns a callable to cast a value to the same type as input_ >>> assert get_casting_function("a string") == str diff --git a/requirements-dev.in b/requirements-dev.in index 61f43754..38a47bff 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -9,4 +9,5 @@ pip-tools types-pytest-lazy-fixture types-setuptools mypy -absolufy-imports \ No newline at end of file +absolufy-imports +datamodel-code-generator diff --git a/requirements-dev.txt b/requirements-dev.txt index 7f36826e..ce7d4d53 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,75 +1,168 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile --output-file=requirements-dev.txt requirements-dev.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile requirements-dev.in --constraints requirements-dev.txt --constraints requirements.txt -o requirements-dev.txt absolufy-imports==0.3.1 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in +annotated-types==0.5.0 + # via + # -c requirements.txt + # pydantic +argcomplete==3.6.2 + # via datamodel-code-generator black==24.3.0 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in + # datamodel-code-generator build==0.10.0 - # via pip-tools + # via + # -c requirements-dev.txt + # pip-tools click==8.1.7 # via + # -c requirements-dev.txt # black # pip-tools -coverage[toml]==7.3.0 - # via pytest-cov -flake8==6.1.0 +coverage==7.3.0 + # via + # -c requirements-dev.txt + # pytest-cov +datamodel-code-generator==0.27.3 # via -r requirements-dev.in -iniconfig==2.0.0 +exceptiongroup==1.3.0 # via pytest +flake8==6.1.0 + # via + # -c requirements-dev.txt + # -r requirements-dev.in +genson==1.3.0 + # via datamodel-code-generator +inflect==5.6.2 + # via datamodel-code-generator +iniconfig==2.0.0 + # via + # -c requirements-dev.txt + # pytest isort==5.12.0 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in + # datamodel-code-generator +jinja2==3.1.6 + # via datamodel-code-generator +markupsafe==2.1.5 + # via jinja2 mccabe==0.7.0 - # via flake8 + # via + # -c requirements-dev.txt + # flake8 mypy==1.5.1 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in mypy-extensions==1.0.0 # via + # -c requirements-dev.txt # black # mypy packaging==23.1 # via + # -c requirements-dev.txt # black # build + # datamodel-code-generator # pytest pathspec==0.11.2 - # via black + # via + # -c requirements-dev.txt + # black +pip==25.0.1 + # via pip-tools pip-tools==7.3.0 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in platformdirs==3.10.0 - # via black + # via + # -c requirements-dev.txt + # black pluggy==1.2.0 - # via pytest + # via + # -c requirements-dev.txt + # pytest pycodestyle==2.11.0 - # via flake8 + # via + # -c requirements-dev.txt + # flake8 +pydantic==2.4.0 + # via + # -c requirements.txt + # datamodel-code-generator +pydantic-core==2.10.0 + # via + # -c requirements.txt + # pydantic pyflakes==3.1.0 - # via flake8 + # via + # -c requirements-dev.txt + # flake8 pyproject-hooks==1.0.0 - # via build + # via + # -c requirements-dev.txt + # build pytest==7.4.0 # via + # -c requirements-dev.txt # -r requirements-dev.in # pytest-cov # pytest-lazy-fixture # pytest-mock pytest-cov==4.1.0 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in pytest-lazy-fixture==0.6.3 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in pytest-mock==3.11.1 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in +pyyaml==6.0.2 + # via datamodel-code-generator +setuptools==75.3.2 + # via pip-tools +tomli==2.2.1 + # via + # black + # build + # coverage + # datamodel-code-generator + # mypy + # pip-tools + # pyproject-hooks + # pytest types-pytest-lazy-fixture==0.6.3.4 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in types-setuptools==68.2.0.0 - # via -r requirements-dev.in + # via + # -c requirements-dev.txt + # -r requirements-dev.in typing-extensions==4.8.0 - # via mypy + # via + # -c requirements-dev.txt + # -c requirements.txt + # annotated-types + # black + # exceptiongroup + # mypy + # pydantic + # pydantic-core wheel==0.41.2 - # via pip-tools - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools + # via + # -c requirements-dev.txt + # pip-tools diff --git a/tests/engine_tests/test_engine.py b/tests/engine_tests/test_engine.py index f6213537..4cc2b1d2 100644 --- a/tests/engine_tests/test_engine.py +++ b/tests/engine_tests/test_engine.py @@ -67,18 +67,15 @@ def test_engine( engine_response = get_identity_feature_states(environment_model, identity_model) # and we sort the feature states so we can iterate over them and compare - sorted_engine_flags = sorted(engine_response, key=lambda fs: fs.feature.name) api_flags = api_response["flags"] # Then # there are an equal number of flags and feature states - assert len(sorted_engine_flags) == len(api_flags) + assert len(engine_response) == len(api_flags) # and the values and enabled status of each of the feature states returned by the # engine are identical to those returned by the Django API (i.e. the test data). - for i, feature_state in enumerate(sorted_engine_flags): - assert ( - feature_state.get_value(identity_model.django_id) - == api_flags[i]["feature_state_value"] - ) - assert feature_state.enabled == api_flags[i]["enabled"] + assert { + fs.feature.name: fs.get_value(identity_model.django_id) + for fs in engine_response + } == {flag["feature"]["name"]: flag["feature_state_value"] for flag in api_flags} diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5476fe9e..e5f77dd7 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -2,6 +2,8 @@ import pytest +from flag_engine.context.mappers import map_environment_identity_to_context +from flag_engine.context.types import EvaluationContext from flag_engine.environments.models import EnvironmentModel from flag_engine.features.constants import STANDARD from flag_engine.features.models import ( @@ -125,6 +127,18 @@ def identity(environment: EnvironmentModel) -> IdentityModel: ) +@pytest.fixture +def context( + environment: EnvironmentModel, + identity: IdentityModel, +) -> EvaluationContext: + return map_environment_identity_to_context( + environment=environment, + identity=identity, + override_traits=None, + ) + + @pytest.fixture() def trait_matching_segment(segment_condition: SegmentConditionModel) -> TraitModel: return TraitModel( @@ -145,6 +159,18 @@ def identity_in_segment( ) +@pytest.fixture +def context_in_segment( + identity_in_segment: IdentityModel, + environment: EnvironmentModel, +) -> EvaluationContext: + return map_environment_identity_to_context( + environment=environment, + identity=identity_in_segment, + override_traits=None, + ) + + @pytest.fixture() def segment_override_fs( segment: SegmentModel, diff --git a/tests/unit/segments/test_segments_evaluator.py b/tests/unit/segments/test_segments_evaluator.py index 3097510d..6d1fbe2f 100644 --- a/tests/unit/segments/test_segments_evaluator.py +++ b/tests/unit/segments/test_segments_evaluator.py @@ -4,13 +4,9 @@ from pytest_lazyfixture import lazy_fixture from pytest_mock import MockerFixture -from flag_engine.identities.models import IdentityModel -from flag_engine.identities.traits.models import TraitModel +from flag_engine.context.types import EvaluationContext from flag_engine.segments import constants -from flag_engine.segments.evaluator import ( - _matches_trait_value, - evaluate_identity_in_segment, -) +from flag_engine.segments.evaluator import _matches_context_value, is_context_in_segment from flag_engine.segments.models import ( SegmentConditionModel, SegmentModel, @@ -34,101 +30,197 @@ @pytest.mark.parametrize( - "segment, identity_traits, expected_result", + "segment, context, expected_result", ( - (empty_segment, [], False), - (segment_single_condition, [], False), + (empty_segment, {"environment": {"key": "key", "name": "Environment"}}, False), + ( + segment_single_condition, + {"environment": {"key": "key", "name": "Environment"}}, + False, + ), ( segment_single_condition, - [TraitModel(trait_key=trait_key_1, trait_value=trait_value_1)], + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + }, + }, + }, True, ), - (segment_multiple_conditions_all, [], False), ( segment_multiple_conditions_all, - [TraitModel(trait_key=trait_key_1, trait_value=trait_value_1)], + {"environment": {"key": "key", "name": "Environment"}}, False, ), ( segment_multiple_conditions_all, - [ - TraitModel(trait_key=trait_key_1, trait_value=trait_value_1), - TraitModel(trait_key=trait_key_2, trait_value=trait_value_2), - ], + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + }, + }, + }, + False, + ), + ( + segment_multiple_conditions_all, + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + trait_key_2: trait_value_2, + }, + }, + }, True, ), - (segment_multiple_conditions_any, [], False), ( segment_multiple_conditions_any, - [TraitModel(trait_key=trait_key_1, trait_value=trait_value_1)], + {"environment": {"key": "key", "name": "Environment"}}, + False, + ), + ( + segment_multiple_conditions_any, + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + }, + }, + }, True, ), ( segment_multiple_conditions_any, - [TraitModel(trait_key=trait_key_2, trait_value=trait_value_2)], + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_2: trait_value_2, + }, + }, + }, True, ), ( segment_multiple_conditions_any, - [ - TraitModel(trait_key=trait_key_1, trait_value=trait_value_1), - TraitModel(trait_key=trait_key_2, trait_value=trait_value_2), - ], + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + trait_key_2: trait_value_2, + }, + }, + }, True, ), - (segment_nested_rules, [], False), ( segment_nested_rules, - [TraitModel(trait_key=trait_key_1, trait_value=trait_value_1)], + {"environment": {"key": "key", "name": "Environment"}}, + False, + ), + ( + segment_nested_rules, + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + }, + }, + }, False, ), ( segment_nested_rules, - [ - TraitModel(trait_key=trait_key_1, trait_value=trait_value_1), - TraitModel(trait_key=trait_key_2, trait_value=trait_value_2), - TraitModel(trait_key=trait_key_3, trait_value=trait_value_3), - ], + { + "environment": {"key": "key", "name": "Environment"}, + "identifier": "foo", + "key": "key_foo", + "identity": { + "traits": { + trait_key_1: trait_value_1, + trait_key_2: trait_value_2, + trait_key_3: trait_value_3, + } + }, + }, True, ), - (segment_conditions_and_nested_rules, [], False), ( segment_conditions_and_nested_rules, - [TraitModel(trait_key=trait_key_1, trait_value=trait_value_1)], + {"environment": {"key": "key", "name": "Environment"}}, False, ), ( segment_conditions_and_nested_rules, - [ - TraitModel(trait_key=trait_key_1, trait_value=trait_value_1), - TraitModel(trait_key=trait_key_2, trait_value=trait_value_2), - TraitModel(trait_key=trait_key_3, trait_value=trait_value_3), - ], + { + "environment": {"key": "key", "name": "Environment"}, + "identifier": "foo", + "key": "key_foo", + "identity": { + "traits": { + trait_key_1: trait_value_1, + } + }, + }, + False, + ), + ( + segment_conditions_and_nested_rules, + { + "environment": {"key": "key", "name": "Environment"}, + "identity": { + "identifier": "foo", + "key": "key_foo", + "traits": { + trait_key_1: trait_value_1, + trait_key_2: trait_value_2, + trait_key_3: trait_value_3, + }, + }, + }, True, ), ), ) -def test_identity_in_segment( +def test_context_in_segment( segment: SegmentModel, - identity_traits: typing.List[TraitModel], + context: EvaluationContext, expected_result: bool, ) -> None: - identity = IdentityModel( - identifier="foo", - identity_traits=identity_traits, - environment_api_key="api-key", - ) - - assert evaluate_identity_in_segment(identity, segment) == expected_result + assert is_context_in_segment(context, segment) == expected_result @pytest.mark.parametrize( "segment_split_value, identity_hashed_percentage, expected_result", ((10, 1, True), (100, 50, True), (0, 1, False), (10, 20, False)), ) -def test_identity_in_segment_percentage_split( +def test_context_in_segment_percentage_split( mocker: MockerFixture, - identity: IdentityModel, + context: EvaluationContext, segment_split_value: int, identity_hashed_percentage: int, expected_result: bool, @@ -140,7 +232,11 @@ def test_identity_in_segment_percentage_split( rule = SegmentRuleModel( type=constants.ALL_RULE, conditions=[percentage_split_condition] ) - segment = SegmentModel(id=1, name="% split", rules=[rule]) + segment = SegmentModel( + id=1, + name="% split", + rules=[SegmentRuleModel(type=constants.ALL_RULE, conditions=[], rules=[rule])], + ) mock_get_hashed_percentage = mocker.patch( "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids" @@ -148,12 +244,46 @@ def test_identity_in_segment_percentage_split( mock_get_hashed_percentage.return_value = identity_hashed_percentage # When - result = evaluate_identity_in_segment(identity=identity, segment=segment) + result = is_context_in_segment(context=context, segment=segment) # Then assert result == expected_result +def test_context_in_segment_percentage_split__trait_value__calls_expected( + mocker: MockerFixture, + context: EvaluationContext, +) -> None: + # Given + assert context["identity"] is not None + context["identity"]["traits"]["custom_trait"] = "custom_value" + percentage_split_condition = SegmentConditionModel( + operator=constants.PERCENTAGE_SPLIT, + value="10", + property_="custom_trait", + ) + rule = SegmentRuleModel( + type=constants.ALL_RULE, conditions=[percentage_split_condition] + ) + segment = SegmentModel( + id=1, + name="% split", + rules=[SegmentRuleModel(type=constants.ALL_RULE, conditions=[], rules=[rule])], + ) + + mock_get_hashed_percentage = mocker.patch( + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids" + ) + mock_get_hashed_percentage.return_value = 1 + + # When + result = is_context_in_segment(context=context, segment=segment) + + # Then + mock_get_hashed_percentage.assert_called_once_with([segment.id, "custom_value"]) + assert result + + @pytest.mark.parametrize( "operator, property_, expected_result", ( @@ -163,8 +293,8 @@ def test_identity_in_segment_percentage_split( (constants.IS_NOT_SET, "random_property", True), ), ) -def test_identity_in_segment_is_set_and_is_not_set( - identity_in_segment: IdentityModel, +def test_context_in_segment_is_set_and_is_not_set( + context_in_segment: EvaluationContext, operator: ConditionOperator, property_: str, expected_result: bool, @@ -181,7 +311,7 @@ def test_identity_in_segment_is_set_and_is_not_set( segment = SegmentModel(id=1, name="segment model", rules=[rule]) # When - result = evaluate_identity_in_segment(identity=identity_in_segment, segment=segment) + result = is_context_in_segment(context=context_in_segment, segment=segment) # Then assert result is expected_result @@ -265,7 +395,7 @@ def test_identity_in_segment_is_set_and_is_not_set( (constants.IN, 1, None, False), ), ) -def test_segment_condition_matches_trait_value( +def test_segment_condition_matches_context_value( operator: ConditionOperator, trait_value: typing.Union[None, int, str, float], condition_value: object, @@ -279,7 +409,7 @@ def test_segment_condition_matches_trait_value( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _matches_context_value(segment_condition, trait_value) # Then assert result == expected_result @@ -289,7 +419,7 @@ def test_segment_condition__unsupported_operator__return_false( mocker: MockerFixture, ) -> None: # Given - mocker.patch("flag_engine.segments.evaluator.MATCH_FUNCS_BY_OPERATOR", new={}) + mocker.patch("flag_engine.segments.evaluator.MATCHERS_BY_OPERATOR", new={}) segment_condition = SegmentConditionModel( operator=constants.EQUAL, property_="x", @@ -298,7 +428,7 @@ def test_segment_condition__unsupported_operator__return_false( trait_value = "foo" # When - result = _matches_trait_value(segment_condition, trait_value) + result = _matches_context_value(segment_condition, trait_value) # Then assert result is False @@ -329,7 +459,7 @@ def test_segment_condition__unsupported_operator__return_false( (constants.LESS_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", False), ], ) -def test_segment_condition_matches_trait_value_for_semver( +def test_segment_condition_matches_context_value_for_semver( operator: ConditionOperator, trait_value: str, condition_value: str, @@ -343,7 +473,7 @@ def test_segment_condition_matches_trait_value_for_semver( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _matches_context_value(segment_condition, trait_value) # Then assert result == expected_result @@ -364,7 +494,7 @@ def test_segment_condition_matches_trait_value_for_semver( (1, None, False), ], ) -def test_segment_condition_matches_trait_value_for_modulo( +def test_segment_condition_matches_context_value_for_modulo( trait_value: typing.Union[int, float, str, bool], condition_value: typing.Optional[str], expected_result: bool, @@ -377,7 +507,7 @@ def test_segment_condition_matches_trait_value_for_modulo( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _matches_context_value(segment_condition, trait_value) # Then assert result == expected_result