diff --git a/reflex/config.py b/reflex/config.py index 81441ff4c5..6977d74563 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -10,7 +10,7 @@ from importlib.util import find_spec from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from reflex import constants from reflex.constants.base import LogLevel @@ -18,6 +18,7 @@ from reflex.environment import EnvVar as EnvVar from reflex.environment import ( ExistingPath, + SequenceOptions, _load_dotenv_from_files, _paths_from_env_files, interpret_env_var_value, @@ -207,8 +208,10 @@ class BaseConfig: # Timeout to do a production build of a frontend page. static_page_generation_timeout: int = 60 - # List of origins that are allowed to connect to the backend API. - cors_allowed_origins: Sequence[str] = dataclasses.field(default=("*",)) + # Comma separated list of origins that are allowed to connect to the backend API. + cors_allowed_origins: Annotated[Sequence[str], SequenceOptions(delimiter=",")] = ( + dataclasses.field(default=("*",)) + ) # Whether to use React strict mode. react_strict_mode: bool = True diff --git a/reflex/environment.py b/reflex/environment.py index 2ec12a32fa..279fc5f60c 100644 --- a/reflex/environment.py +++ b/reflex/environment.py @@ -212,6 +212,17 @@ def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> raise EnvironmentVarValueError(msg) from ve +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class SequenceOptions: + """Options for interpreting Sequence environment variables.""" + + delimiter: str = ":" + strip: bool = False + + +DEFAULT_SEQUENCE_OPTIONS = SequenceOptions() + + def interpret_env_var_value( value: str, field_type: GenericType, field_name: str ) -> Any: @@ -278,14 +289,26 @@ def interpret_env_var_value( continue msg = f"Invalid literal value: {value!r} for {field_name}, expected one of {literal_values}" raise EnvironmentVarValueError(msg) + # If the field is Annotated with SequenceOptions, extract the options + sequence_options = DEFAULT_SEQUENCE_OPTIONS + if get_origin(field_type) is Annotated: + annotated_args = get_args(field_type) + field_type = annotated_args[0] + for arg in annotated_args[1:]: + if isinstance(arg, SequenceOptions): + sequence_options = arg + break if get_origin(field_type) in (list, Sequence): + items = value.split(sequence_options.delimiter) + if sequence_options.strip: + items = [item.strip() for item in items] return [ interpret_env_var_value( v, get_args(field_type)[0], f"{field_name}[{i}]", ) - for i, v in enumerate(value.split(":")) + for i, v in enumerate(items) ] if isinstance(field_type, type) and issubclass(field_type, enum.Enum): return interpret_enum_env(value, field_type, field_name) diff --git a/tests/units/test_config.py b/tests/units/test_config.py index 67e752f00a..5b5ad71d71 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -96,6 +96,41 @@ def test_update_from_env_path( assert config.bun_path == tmp_path +def test_update_from_env_cors( + base_config_values: dict[str, Any], + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + """Test that environment variables override config values. + + Args: + base_config_values: Config values. + monkeypatch: The pytest monkeypatch object. + tmp_path: The pytest tmp_path fixture object. + """ + config = rx.Config(**base_config_values) + assert config.cors_allowed_origins == ("*",) + + monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "") + config = rx.Config(**base_config_values) + assert config.cors_allowed_origins == ("*",) + + monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "https://foo.example.com") + config = rx.Config(**base_config_values) + assert config.cors_allowed_origins == [ + "https://foo.example.com", + ] + + monkeypatch.setenv( + "REFLEX_CORS_ALLOWED_ORIGINS", "http://example.com, http://another.com " + ) + config = rx.Config(**base_config_values) + assert config.cors_allowed_origins == [ + "http://example.com", + "http://another.com", + ] + + @pytest.mark.parametrize( ("kwargs", "expected"), [ diff --git a/tests/units/test_environment.py b/tests/units/test_environment.py index cc8f9bd5df..839fea4545 100644 --- a/tests/units/test_environment.py +++ b/tests/units/test_environment.py @@ -4,6 +4,7 @@ import os import tempfile from pathlib import Path +from typing import Annotated from unittest.mock import patch import pytest @@ -15,6 +16,7 @@ ExecutorType, ExistingPath, PerformanceMode, + SequenceOptions, _load_dotenv_from_files, _paths_from_env_files, _paths_from_environment, @@ -175,6 +177,14 @@ def test_interpret_list(self): result = interpret_env_var_value("1:2:3", list[int], "TEST_FIELD") assert result == [1, 2, 3] + def test_interpret_annotated_sequence(self): + """Test annotated sequence interpretation.""" + annotated_type = Annotated[ + list[str], SequenceOptions(delimiter=",", strip=True) + ] + result = interpret_env_var_value("a, b, c ", annotated_type, "TEST_FIELD") + assert result == ["a", "b", "c"] + def test_interpret_enum(self): """Test enum interpretation.""" result = interpret_env_var_value("value1", _TestEnum, "TEST_FIELD")