Skip to content

Commit b827f38

Browse files
committed
feat: Improved variable handling for config
This allows an environment variable in any position for headers as long as it is in format `$VAR` or `${VAR}` allowing alphanumeric characters and underscore `_`. This is now also applied to `remote_schema_url` in addition to headers. Fixes #328 Relates to #231
1 parent 11bfe35 commit b827f38

File tree

2 files changed

+67
-8
lines changed

2 files changed

+67
-8
lines changed

ariadne_codegen/settings.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22
import os
3+
import re
34
from dataclasses import dataclass, field
45
from keyword import iskeyword
56
from pathlib import Path
@@ -50,6 +51,7 @@ def __post_init__(self):
5051
assert_path_exists(self.schema_path)
5152

5253
self.remote_schema_headers = resolve_headers(self.remote_schema_headers)
54+
self.remote_schema_url = resolve_schema(self.remote_schema_url)
5355

5456

5557
@dataclass
@@ -276,20 +278,29 @@ def assert_string_is_valid_python_identifier(name: str):
276278
def resolve_headers(headers: Dict) -> Dict:
277279
return {key: get_header_value(value) for key, value in headers.items()}
278280

281+
def resolve_schema(value: str) -> str:
282+
return _replace_env_vars(value)
283+
279284

280285
def get_header_value(value: str) -> str:
281-
env_var_prefix = "$"
282-
if value.startswith(env_var_prefix):
283-
env_var_name = value.lstrip(env_var_prefix)
286+
return _replace_env_vars(value)
287+
288+
289+
def _replace_env_vars(value: str) -> str:
290+
pattern = re.compile(r"\${?([\w_]+)}?")
291+
292+
def replacer(match):
293+
env_var_name = match.group(1)
284294
var_value = os.environ.get(env_var_name)
295+
285296
if not var_value:
286297
raise InvalidConfiguration(
287298
f"Environment variable {env_var_name} not found."
288299
)
289-
return var_value
290300

291-
return value
301+
return var_value
292302

303+
return pattern.sub(replacer, value)
293304

294305
def assert_class_is_defined_in_file(file_path: Path, class_name: str):
295306
file_content = file_path.read_text()

tests/test_settings.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,24 @@ def test_client_settings_without_schema_path_or_remote_schema_url_raises_excepti
133133
ClientSettings(queries_path=queries_path)
134134

135135

136+
@pytest.mark.parametrize(
137+
"configured_header, expected_header",
138+
[
139+
("$TEST_VAR", "test_value"),
140+
("Bearer: $TEST_VAR", "Bearer: test_value"),
141+
("Bearer: ${TEST_VAR}", "Bearer: test_value"),
142+
pytest.param(
143+
"$NOT_SET_VAR",
144+
"",
145+
marks=pytest.mark.xfail(raises=InvalidConfiguration),
146+
),
147+
],
148+
)
136149
def test_client_settings_resolves_env_variable_for_remote_schema_header_with_prefix(
137-
tmp_path, mocker
150+
tmp_path,
151+
mocker,
152+
configured_header,
153+
expected_header,
138154
):
139155
queries_path = tmp_path / "queries.graphql"
140156
queries_path.touch()
@@ -143,10 +159,42 @@ def test_client_settings_resolves_env_variable_for_remote_schema_header_with_pre
143159
settings = ClientSettings(
144160
queries_path=queries_path,
145161
remote_schema_url="https://test",
146-
remote_schema_headers={"Authorization": "$TEST_VAR"},
162+
remote_schema_headers={"Authorization": configured_header},
163+
)
164+
165+
assert settings.remote_schema_headers["Authorization"] == expected_header
166+
167+
168+
@pytest.mark.parametrize(
169+
"configured_url, expected_url",
170+
[
171+
("$TEST_VAR", "test_value"),
172+
("https://${TEST_VAR}/graphql", "https://test_value/graphql"),
173+
("https://$TEST_VAR/graphql", "https://test_value/graphql"),
174+
("https://TEST_VAR/graphql", "https://TEST_VAR/graphql"),
175+
pytest.param(
176+
"https://${NOT_SET_VAR}/graphql",
177+
"",
178+
marks=pytest.mark.xfail(raises=InvalidConfiguration),
179+
),
180+
],
181+
)
182+
def test_client_settings_resolves_env_variable_for_remote_schema(
183+
tmp_path,
184+
mocker,
185+
configured_url,
186+
expected_url,
187+
):
188+
queries_path = tmp_path / "queries.graphql"
189+
queries_path.touch()
190+
mocker.patch.dict(os.environ, {"TEST_VAR": "test_value"})
191+
192+
settings = ClientSettings(
193+
queries_path=queries_path,
194+
remote_schema_url=configured_url,
147195
)
148196

149-
assert settings.remote_schema_headers["Authorization"] == "test_value"
197+
assert settings.remote_schema_url == expected_url
150198

151199

152200
def test_client_settings_doesnt_resolve_remote_schema_header_without_prefix(tmp_path):

0 commit comments

Comments
 (0)