Skip to content

Commit f4a90b4

Browse files
authored
Merge pull request #189 from awslabs/fix/sdk_init_config_handlers
fix(config): config_dict path now applies full normalisation pipeline
2 parents 89bb5d9 + 3d98f2f commit f4a90b4

3 files changed

Lines changed: 202 additions & 12 deletions

File tree

src/orb/config/loader.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,40 @@ def load(
151151

152152
return config
153153

154+
@classmethod
155+
def _build_raw_config_from_dict(
156+
cls,
157+
config_dict: dict[str, Any],
158+
config_manager: Optional["ConfigurationManager"] = None,
159+
) -> dict[str, Any]:
160+
"""
161+
Apply the full normalisation pipeline to an in-memory dict.
162+
Mirrors ConfigurationLoader.load() but starts from a dict instead of a file.
163+
Pipeline: package defaults -> strategy defaults -> user dict -> env vars -> env expansion.
164+
"""
165+
from orb.config.utils.env_expansion import expand_config_env_vars
166+
167+
base = cls._load_default_config()
168+
169+
strategy_defaults = cls._load_strategy_defaults(config_manager)
170+
if strategy_defaults:
171+
cls._merge_config(base, strategy_defaults)
172+
173+
cls._merge_config(base, config_dict)
174+
175+
cls._load_from_env(base, config_manager)
176+
result = expand_config_env_vars(base)
177+
178+
# Hoist provider.provider_defaults to the top level so callers can
179+
# inspect strategy defaults without navigating the nested provider key.
180+
# Only hoist if not already present at the top level (e.g. from a mock or explicit dict).
181+
if "provider_defaults" not in result:
182+
provider_section = result.get("provider", {})
183+
if "provider_defaults" in provider_section:
184+
result["provider_defaults"] = provider_section["provider_defaults"]
185+
186+
return result
187+
154188
@classmethod
155189
def _load_strategy_defaults(cls, config_manager=None) -> dict[str, Any]:
156190
merged: dict[str, Any] = {}

src/orb/config/managers/configuration_manager.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ def app_config(self) -> AppConfig:
9696
def _load_app_config(self) -> AppConfig:
9797
"""Load application configuration from loader or in-memory dict."""
9898
try:
99-
if self._config_dict is not None:
100-
return self.loader.create_app_config(self._config_dict)
101-
raw_config = self.loader.load(self._config_file, config_manager=self)
99+
raw_config = self._ensure_raw_config()
102100
return self.loader.create_app_config(raw_config)
103101
except Exception as e:
104102
logger.error("Failed to load app config: %s", e, exc_info=True)
@@ -108,13 +106,11 @@ def _ensure_raw_config(self) -> dict[str, Any]:
108106
"""Ensure raw configuration is loaded."""
109107
if self._raw_config is None:
110108
if self._config_dict is not None:
111-
# Merge provided dict on top of package defaults so that
112-
# provider_defaults (supports_spot etc.) are always present.
113109
from orb.config.loader import ConfigurationLoader
114110

115-
base = ConfigurationLoader._load_default_config()
116-
ConfigurationLoader._merge_config(base, self._config_dict)
117-
self._raw_config = base
111+
self._raw_config = ConfigurationLoader._build_raw_config_from_dict(
112+
self._config_dict, config_manager=self
113+
)
118114
else:
119115
self._raw_config = self.loader.load(self._config_file, config_manager=self)
120116
return self._raw_config
@@ -169,10 +165,8 @@ def get_typed_with_defaults(self, config_type: type[T]) -> T:
169165
def reload(self) -> None:
170166
"""Reload configuration from sources."""
171167
try:
172-
# Re-derive config file path in case ORB_CONFIG_DIR changed between tests
173-
from orb.config.platform_dirs import get_config_location
174-
175-
self._config_file = str(get_config_location() / "config.json")
168+
# Do NOT re-derive _config_file — preserve construction parameters (_config_dict, _config_file)
169+
# Only reset cached derived state
176170

177171
# Clear all caches
178172
self._cache_manager.clear_cache()
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""Tests proving bug #185: config_dict path skips _load_strategy_defaults().
2+
3+
Tests 1-2 are baseline/resilience checks (expected to PASS).
4+
Tests 3-6 prove the bug (expected to FAIL against unmodified code).
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from unittest.mock import MagicMock, patch
10+
11+
from orb.config.loader import ConfigurationLoader
12+
from orb.config.managers.configuration_manager import ConfigurationManager
13+
14+
_MINIMAL_AWS_DICT = {"provider": {"type": "aws"}}
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# Test 1 — baseline: _load_strategy_defaults works in isolation (PASS)
19+
# ---------------------------------------------------------------------------
20+
21+
22+
class TestLoadStrategyDefaultsBaseline:
23+
def test_load_strategy_defaults_returns_non_empty_dict(self):
24+
"""_load_strategy_defaults() must return a non-empty dict when the
25+
provider registry is available and the aws provider is registered."""
26+
result = ConfigurationLoader._load_strategy_defaults()
27+
assert isinstance(result, dict)
28+
assert len(result) > 0, (
29+
"_load_strategy_defaults() returned an empty dict — "
30+
"expected provider handler defaults to be present"
31+
)
32+
33+
34+
# ---------------------------------------------------------------------------
35+
# Test 2 — resilience: registry failure must not propagate (PASS)
36+
# ---------------------------------------------------------------------------
37+
38+
39+
class TestLoadStrategyDefaultsResilience:
40+
def test_load_strategy_defaults_survives_registry_failure(self):
41+
"""If get_provider_registry raises, _load_strategy_defaults must
42+
swallow the error and return a dict (possibly empty).
43+
44+
The import is local inside the method body, so we patch the symbol
45+
on the orb.providers.registry module directly.
46+
"""
47+
with patch(
48+
"orb.providers.registry.get_provider_registry",
49+
side_effect=RuntimeError("registry unavailable"),
50+
):
51+
result = ConfigurationLoader._load_strategy_defaults()
52+
assert isinstance(result, dict)
53+
54+
55+
# ---------------------------------------------------------------------------
56+
# Test 3 — bug: dict path never calls _load_strategy_defaults (FAIL)
57+
# ---------------------------------------------------------------------------
58+
59+
60+
class TestEnsureRawConfigCallsLoadStrategyDefaults:
61+
def test_ensure_raw_config_with_config_dict_calls_load_strategy_defaults(self):
62+
"""_ensure_raw_config() must call _load_strategy_defaults() when
63+
config_dict= is supplied, just as the file path does.
64+
65+
FAILS against current code because the dict branch only calls
66+
_load_default_config() and skips _load_strategy_defaults().
67+
"""
68+
mock_load_strategy = MagicMock(return_value={})
69+
70+
with patch.object(ConfigurationLoader, "_load_strategy_defaults", mock_load_strategy):
71+
cm = ConfigurationManager(config_dict=_MINIMAL_AWS_DICT)
72+
cm._ensure_raw_config()
73+
74+
mock_load_strategy.assert_called()
75+
76+
77+
# ---------------------------------------------------------------------------
78+
# Test 4 — bug: strategy defaults not merged into result (FAIL)
79+
# ---------------------------------------------------------------------------
80+
81+
82+
class TestEnsureRawConfigMergesStrategyDefaults:
83+
def test_ensure_raw_config_with_config_dict_merges_strategy_defaults(self):
84+
"""Strategy defaults returned by _load_strategy_defaults() must be
85+
present in the dict produced by _ensure_raw_config().
86+
87+
FAILS against current code because the dict branch never merges them.
88+
"""
89+
fake_defaults = {
90+
"provider_defaults": {"aws": {"handlers": {"RunInstances": {"enabled": True}}}}
91+
}
92+
93+
with patch.object(
94+
ConfigurationLoader, "_load_strategy_defaults", return_value=fake_defaults
95+
):
96+
cm = ConfigurationManager(config_dict=_MINIMAL_AWS_DICT)
97+
result = cm._ensure_raw_config()
98+
99+
assert "provider_defaults" in result, (
100+
"provider_defaults key missing — strategy defaults were not merged"
101+
)
102+
assert result["provider_defaults"]["aws"]["handlers"]["RunInstances"]["enabled"] is True
103+
104+
105+
# ---------------------------------------------------------------------------
106+
# Test 5 — bug: real defaults absent from config_dict path (FAIL)
107+
# ---------------------------------------------------------------------------
108+
109+
110+
class TestConfigurationManagerConfigDictHasAwsHandlerDefaults:
111+
def test_configuration_manager_config_dict_has_aws_handler_defaults(self):
112+
"""Without any mocking, constructing ConfigurationManager(config_dict=...)
113+
and calling _ensure_raw_config() must produce a result that contains
114+
evidence of strategy defaults (e.g. a 'provider_defaults' key or any
115+
handler-related key contributed by _load_strategy_defaults()).
116+
117+
FAILS against current code because the dict branch skips
118+
_load_strategy_defaults() entirely.
119+
"""
120+
cm = ConfigurationManager(config_dict=_MINIMAL_AWS_DICT)
121+
result = cm._ensure_raw_config()
122+
123+
# _load_strategy_defaults merges provider_defaults from the AWS strategy.
124+
# If the dict path called it, this key would be present.
125+
assert "provider_defaults" in result, (
126+
"provider_defaults missing from _ensure_raw_config() result when "
127+
"config_dict= is used — _load_strategy_defaults() was not called"
128+
)
129+
130+
131+
# ---------------------------------------------------------------------------
132+
# Test 6 — bug: call-count confirms dict path skips the method (FAIL)
133+
# ---------------------------------------------------------------------------
134+
135+
136+
class TestFileAndDictPathsProduceEquivalentStrategyDefaults:
137+
def test_file_and_dict_paths_produce_equivalent_strategy_defaults(self):
138+
"""_load_strategy_defaults must be called at least once when the
139+
config_dict= path is taken, matching the behaviour of the file path.
140+
141+
FAILS against current code because the dict branch in
142+
_ensure_raw_config() never invokes _load_strategy_defaults().
143+
"""
144+
call_count: list[int] = [0]
145+
original = ConfigurationLoader._load_strategy_defaults
146+
147+
def counting_load_strategy_defaults(*args, **kwargs):
148+
call_count[0] += 1
149+
return original(*args, **kwargs)
150+
151+
with patch.object(
152+
ConfigurationLoader,
153+
"_load_strategy_defaults",
154+
side_effect=counting_load_strategy_defaults,
155+
):
156+
cm = ConfigurationManager(config_dict=_MINIMAL_AWS_DICT)
157+
cm._ensure_raw_config()
158+
159+
assert call_count[0] >= 1, (
160+
f"_load_strategy_defaults was called {call_count[0]} time(s) via the "
161+
"config_dict path — expected at least 1 call"
162+
)

0 commit comments

Comments
 (0)