Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/edge_proxy/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

logger = structlog.get_logger(__name__)

SERVER_API_KEY_PREFIX = "ser."


class EnvironmentService:
def __init__(
Expand Down Expand Up @@ -77,11 +79,12 @@ def get_flags_response_data(
) -> dict[str, typing.Any]:
environment_document = self.get_environment(environment_key)
environment = EnvironmentModel.model_validate(environment_document)
is_server_key = environment_key.startswith(SERVER_API_KEY_PREFIX)

if feature:
feature_state = get_environment_feature_state(environment, feature)

if not filter_out_server_key_only_feature_states(
if not is_server_key and not filter_out_server_key_only_feature_states(
feature_states=[feature_state],
environment=environment,
):
Expand All @@ -90,10 +93,12 @@ def get_flags_response_data(
data = map_feature_state_to_response_data(feature_state)

else:
feature_states = filter_out_server_key_only_feature_states(
feature_states=get_environment_feature_states(environment),
environment=environment,
)
feature_states = get_environment_feature_states(environment)
if not is_server_key:
feature_states = filter_out_server_key_only_feature_states(
feature_states=feature_states,
environment=environment,
)
data = map_feature_states_to_response_data(feature_states)

return data
Expand All @@ -103,21 +108,26 @@ def get_identity_response_data(
) -> dict[str, typing.Any]:
environment_document = self.get_environment(environment_key)
environment = EnvironmentModel.model_validate(environment_document)
is_server_key = environment_key.startswith(SERVER_API_KEY_PREFIX)

identity = IdentityModel.model_validate(
self.cache.get_identity(
environment_api_key=environment_key,
identifier=input_data.identifier,
)
)
trait_models = input_data.traits
flags = filter_out_server_key_only_feature_states(
feature_states=get_identity_feature_states(
environment,
identity,
override_traits=trait_models,
),
environment=environment,
flags = get_identity_feature_states(
environment,
identity,
override_traits=trait_models,
)

if not is_server_key:
flags = filter_out_server_key_only_feature_states(
feature_states=flags,
environment=environment,
)
data = {
"traits": map_traits_to_response_data(trait_models),
"flags": map_feature_states_to_response_data(
Expand Down
103 changes: 102 additions & 1 deletion tests/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from pytest_mock import MockerFixture

from edge_proxy.environments import EnvironmentService
from edge_proxy.exceptions import FlagsmithUnknownKeyError
from edge_proxy.exceptions import (
FeatureNotFoundError,
FlagsmithUnknownKeyError,
)
from edge_proxy.models import IdentityWithTraits
from edge_proxy.settings import (
EndpointCacheSettings,
Expand Down Expand Up @@ -230,3 +233,101 @@ async def test_get_identity_flags_response_skips_cache_for_different_identity(
assert environment_service.get_identity_response_data.cache_info().currsize == 2
assert environment_service.get_identity_response_data.cache_info().misses == 2
assert environment_service.get_identity_response_data.cache_info().hits == 0


@pytest.mark.asyncio
async def test_get_flags_response_data_skips_filter_for_server_key(
mocker: MockerFixture,
) -> None:
# Given
# We create a new settings object that contains a server key as a client_side_key
api_key = "ser." + environment_1_api_key
_settings = AppSettings(
environment_key_pairs=[
{"client_side_key": api_key, "server_side_key": "ser.key"}
]
)

mocked_client = mocker.AsyncMock()
mocked_client.get.return_value = mocker.MagicMock(
text=orjson.dumps(environment_1), raise_for_status=lambda: None
)

environment_service = EnvironmentService(settings=_settings, client=mocked_client)
await environment_service.refresh_environment_caches()

# When
# We retrieve the flag response data
flags = environment_service.get_flags_response_data(api_key)
specific_flag = environment_service.get_flags_response_data(api_key, "feature_3")

# Then
# we get the server-side only flag
assert len(flags) == 3
assert flags[2].get("feature").get("name") == "feature_3"
assert specific_flag.get("feature").get("name") == "feature_3"


@pytest.mark.asyncio
async def test_get_flags_response_data_filters_server_side_features_for_client_key(
mocker: MockerFixture,
) -> None:
# Given
# We create a new settings object that contains a client side key
_settings = AppSettings(
environment_key_pairs=[
{"client_side_key": environment_1_api_key, "server_side_key": "ser.key"}
]
)

mocked_client = mocker.AsyncMock()
mocked_client.get.return_value = mocker.MagicMock(
text=orjson.dumps(environment_1), raise_for_status=lambda: None
)

environment_service = EnvironmentService(settings=_settings, client=mocked_client)
await environment_service.refresh_environment_caches()

# When
# We retrieve the flag response data
flags = environment_service.get_flags_response_data(environment_1_api_key)
with pytest.raises(FeatureNotFoundError):
environment_service.get_flags_response_data(environment_1_api_key, "feature_3")

# Then
# we only get the two client side flags
assert len(flags) == 2


@pytest.mark.asyncio
async def test_get_identity_flags_response_skips_filter_for_server_key(
mocker: MockerFixture,
) -> None:
# Given
# We create a new settings object that contains a server key as a client_side_key
api_key = "ser." + environment_1_api_key
_settings = AppSettings(
environment_key_pairs=[
{"client_side_key": api_key, "server_side_key": "ser.key"}
]
)

mocked_client = mocker.AsyncMock()
mocked_client.get.return_value = mocker.MagicMock(
text=orjson.dumps(environment_1), raise_for_status=lambda: None
)

environment_service = EnvironmentService(settings=_settings, client=mocked_client)
await environment_service.refresh_environment_caches()

# When
# We retrieve the flags for an identity
result = environment_service.get_identity_response_data(
IdentityWithTraits(identifier="foo"), api_key
)

# Then
# we get the server-side only flag
flags = result.get("flags")
assert len(flags) == 3
assert flags[2].get("feature").get("name") == "feature_3"