Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from importlib.util import find_spec
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal

from reflex import constants
from reflex.constants.base import LogLevel
from reflex.environment import EnvironmentVariables as EnvironmentVariables
from reflex.environment import EnvVar as EnvVar
from reflex.environment import (
ExistingPath,
SequenceOptions,
_load_dotenv_from_files,
_paths_from_env_files,
interpret_env_var_value,
Expand Down Expand Up @@ -207,8 +208,10 @@ class BaseConfig:
# Timeout to do a production build of a frontend page.
static_page_generation_timeout: int = 60

# List of origins that are allowed to connect to the backend API.
cors_allowed_origins: Sequence[str] = dataclasses.field(default=("*",))
# Comma separated list of origins that are allowed to connect to the backend API.
cors_allowed_origins: Annotated[Sequence[str], SequenceOptions(delimiter=",")] = (
dataclasses.field(default=("*",))
)

# Whether to use React strict mode.
react_strict_mode: bool = True
Expand Down
25 changes: 24 additions & 1 deletion reflex/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ def interpret_enum_env(value: str, field_type: GenericType, field_name: str) ->
raise EnvironmentVarValueError(msg) from ve


@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class SequenceOptions:
"""Options for interpreting Sequence environment variables."""

delimiter: str = ":"
strip: bool = False


DEFAULT_SEQUENCE_OPTIONS = SequenceOptions()


def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str
) -> Any:
Expand Down Expand Up @@ -278,14 +289,26 @@ def interpret_env_var_value(
continue
msg = f"Invalid literal value: {value!r} for {field_name}, expected one of {literal_values}"
raise EnvironmentVarValueError(msg)
# If the field is Annotated with SequenceOptions, extract the options
sequence_options = DEFAULT_SEQUENCE_OPTIONS
if get_origin(field_type) is Annotated:
annotated_args = get_args(field_type)
field_type = annotated_args[0]
for arg in annotated_args[1:]:
if isinstance(arg, SequenceOptions):
sequence_options = arg
break
if get_origin(field_type) in (list, Sequence):
items = value.split(sequence_options.delimiter)
if sequence_options.strip:
items = [item.strip() for item in items]
return [
interpret_env_var_value(
v,
get_args(field_type)[0],
f"{field_name}[{i}]",
)
for i, v in enumerate(value.split(":"))
for i, v in enumerate(items)
]
if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
return interpret_enum_env(value, field_type, field_name)
Expand Down
35 changes: 35 additions & 0 deletions tests/units/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,41 @@ def test_update_from_env_path(
assert config.bun_path == tmp_path


def test_update_from_env_cors(
base_config_values: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
"""Test that environment variables override config values.

Args:
base_config_values: Config values.
monkeypatch: The pytest monkeypatch object.
tmp_path: The pytest tmp_path fixture object.
"""
config = rx.Config(**base_config_values)
assert config.cors_allowed_origins == ("*",)

monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "")
config = rx.Config(**base_config_values)
assert config.cors_allowed_origins == ("*",)

monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "https://foo.example.com")
config = rx.Config(**base_config_values)
assert config.cors_allowed_origins == [
"https://foo.example.com",
]

monkeypatch.setenv(
"REFLEX_CORS_ALLOWED_ORIGINS", "http://example.com, http://another.com "
)
config = rx.Config(**base_config_values)
assert config.cors_allowed_origins == [
"http://example.com",
"http://another.com",
]


@pytest.mark.parametrize(
("kwargs", "expected"),
[
Expand Down
10 changes: 10 additions & 0 deletions tests/units/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import tempfile
from pathlib import Path
from typing import Annotated
from unittest.mock import patch

import pytest
Expand All @@ -15,6 +16,7 @@
ExecutorType,
ExistingPath,
PerformanceMode,
SequenceOptions,
_load_dotenv_from_files,
_paths_from_env_files,
_paths_from_environment,
Expand Down Expand Up @@ -175,6 +177,14 @@ def test_interpret_list(self):
result = interpret_env_var_value("1:2:3", list[int], "TEST_FIELD")
assert result == [1, 2, 3]

def test_interpret_annotated_sequence(self):
"""Test annotated sequence interpretation."""
annotated_type = Annotated[
list[str], SequenceOptions(delimiter=",", strip=True)
]
result = interpret_env_var_value("a, b, c ", annotated_type, "TEST_FIELD")
assert result == ["a", "b", "c"]

def test_interpret_enum(self):
"""Test enum interpretation."""
result = interpret_env_var_value("value1", _TestEnum, "TEST_FIELD")
Expand Down
Loading