Skip to content

Commit 11290c1

Browse files
committed
[dg] add --fix-env-requirements flag to dg check
1 parent 1cd036d commit 11290c1

File tree

5 files changed

+165
-25
lines changed

5 files changed

+165
-25
lines changed

python_modules/libraries/dagster-dg/dagster_dg/check.py

+44-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, NamedTuple, Optional
55

66
import click
7+
import yaml
78
from dagster_shared.serdes.objects import LibraryObjectKey
89
from dagster_shared.yaml_utils import parse_yaml_with_source_positions
910
from dagster_shared.yaml_utils.source_position import (
@@ -62,12 +63,13 @@ class ErrorInput(NamedTuple):
6263
def check_yaml(
6364
dg_context: DgContext,
6465
resolved_paths: Sequence[Path],
66+
fix_env_requirements: bool,
6567
) -> bool:
6668
top_level_component_validator = Draft202012Validator(schema=COMPONENT_FILE_SCHEMA)
6769

6870
validation_errors: list[ErrorInput] = []
6971
all_specified_env_var_deps = set()
70-
72+
updated_files = set()
7173
component_contents_by_key: dict[LibraryObjectKey, Any] = {}
7274
modules_to_fetch = set()
7375
for component_dir in dg_context.defs_path.iterdir():
@@ -103,21 +105,38 @@ def check_yaml(
103105
all_specified_env_var_deps.update(specified_env_var_deps)
104106

105107
if used_env_vars - specified_env_var_deps:
106-
msg = (
107-
"Component uses environment variables that are not specified in the component file: "
108-
+ ", ".join(used_env_vars - specified_env_var_deps)
109-
)
108+
if fix_env_requirements:
109+
component_doc_updated = component_doc_tree.value
110+
if "requires" not in component_doc_updated:
111+
component_doc_updated["requires"] = {}
112+
if "env" not in component_doc_updated["requires"]:
113+
component_doc_updated["requires"]["env"] = []
114+
115+
component_doc_updated["requires"]["env"].extend(
116+
used_env_vars - specified_env_var_deps
117+
)
110118

111-
validation_errors.append(
112-
ErrorInput(
113-
None,
114-
ValidationError(
115-
msg,
116-
path=["requires", "env"],
117-
),
118-
component_doc_tree,
119+
component_path.write_text(yaml.dump(component_doc_updated, sort_keys=False))
120+
click.echo(f"Updated {component_path}")
121+
updated_files.add(component_path)
122+
all_specified_env_var_deps.update(used_env_vars)
123+
else:
124+
msg = (
125+
"Component uses environment variables that are not specified in the component file: "
126+
+ ", ".join(used_env_vars - specified_env_var_deps)
127+
+ "\nTo automatically add these environment variables to the component file, run `dg check yaml --fix-env-requirements`"
128+
)
129+
130+
validation_errors.append(
131+
ErrorInput(
132+
None,
133+
ValidationError(
134+
msg,
135+
path=["requires", "env"],
136+
),
137+
component_doc_tree,
138+
)
119139
)
120-
)
121140

122141
# First, validate the top-level structure of the component file
123142
# (type and params keys) before we try to validate the params themselves.
@@ -175,11 +194,22 @@ def check_yaml(
175194
prefix=["attributes"] if key else [],
176195
)
177196
)
197+
if updated_files:
198+
click.echo(
199+
"The following component files were updated to fix environment variable requirements:\n"
200+
+ "\n".join(updated_files)
201+
)
178202
return False
179203
else:
180204
missing_env_vars = (
181205
all_specified_env_var_deps - ProjectEnvVars.from_ctx(dg_context).values.keys()
182206
) - os.environ.keys()
207+
if updated_files:
208+
click.echo(
209+
"The following component files were updated to fix environment variable requirements:\n"
210+
+ "\n".join(str(file) for file in updated_files)
211+
)
212+
183213
if missing_env_vars:
184214
click.echo(
185215
"The following environment variables are used in components but not specified in the .env file or the current shell environment:\n"

python_modules/libraries/dagster-dg/dagster_dg/cli/check.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,17 @@ def check_group():
3232
@click.option(
3333
"--watch", is_flag=True, help="Watch for changes to the component files and re-validate them."
3434
)
35+
@click.option(
36+
"--fix-env-requirements",
37+
is_flag=True,
38+
help="Automatically add environment variable requirements to component yaml files.",
39+
)
3540
@dg_global_options
3641
@cli_telemetry_wrapper
3742
def check_yaml_command(
3843
paths: Sequence[str],
3944
watch: bool,
45+
fix_env_requirements: bool,
4046
**global_options: object,
4147
) -> None:
4248
"""Check component.yaml files against their schemas, showing validation errors."""
@@ -45,7 +51,7 @@ def check_yaml_command(
4551
resolved_paths = [Path(path).absolute() for path in paths]
4652

4753
def run_check(_: Any = None) -> bool:
48-
return check_yaml_fn(dg_context, resolved_paths)
54+
return check_yaml_fn(dg_context, resolved_paths, fix_env_requirements=fix_env_requirements)
4955

5056
if watch:
5157
watched_paths = (
@@ -145,6 +151,7 @@ def check_definitions_command(
145151
check_result = check_yaml_fn(
146152
dg_context.for_project_environment(project_dir, cli_config),
147153
[],
154+
fix_env_requirements=False,
148155
)
149156
overall_check_result = overall_check_result and check_result
150157
if not overall_check_result:

python_modules/libraries/dagster-dg/dagster_dg/cli/env.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from collections.abc import Sequence
12
from pathlib import Path
23

34
import click
5+
import yaml
6+
from dagster_shared.yaml_utils import parse_yaml_with_source_positions
47
from rich.console import Console
58
from rich.table import Table
9+
from yaml.scanner import ScannerError
610

711
from dagster_dg.cli.shared_options import dg_global_options
12+
from dagster_dg.component import get_specified_env_var_deps, get_used_env_vars
813
from dagster_dg.config import normalize_cli_config
914
from dagster_dg.context import DgContext
1015
from dagster_dg.env import ProjectEnvVars
@@ -42,3 +47,61 @@ def list_env_command(**global_options: object) -> None:
4247
table.add_row(key, value)
4348
console = Console()
4449
console.print(table)
50+
51+
52+
# ########################
53+
# ##### FIX COMPONENT YAML REQUIREMENTS
54+
# ########################
55+
56+
57+
# NOTE: the docstring below doubles as the command's `--help` text (click uses
# it verbatim), so longer documentation lives in comments instead.
#
# Walks every entry directly under the project's defs path, optionally
# restricted to the given `paths`, and for each `component.yaml` that uses
# environment variables it does not declare under `requires.env`, appends the
# missing names and rewrites the file in place.
#
# paths: optional files/directories limiting which component directories are
#     processed; an entry is processed when it equals one of the paths or
#     lives underneath one of them. An empty sequence means "all".
# global_options: the shared dg CLI options, normalized via
#     `normalize_cli_config`.
@env_group.command(name="fix-component-requirements", cls=DgClickCommand)
@click.argument("paths", nargs=-1, type=click.Path(exists=True))
@dg_global_options
def fix_component_requirements(paths: Sequence[str], **global_options: object) -> None:
    """Automatically add environment variable requirements to component yaml files."""
    cli_config = normalize_cli_config(global_options, click.get_current_context())
    dg_context = DgContext.for_project_environment(Path.cwd(), cli_config)

    resolved_paths = [Path(path).absolute() for path in paths]

    updated_files = set()

    for component_dir in dg_context.defs_path.iterdir():
        if resolved_paths and not any(
            path == component_dir or path in component_dir.parents for path in resolved_paths
        ):
            continue

        component_path = component_dir / "component.yaml"
        if not component_path.exists():
            continue

        text = component_path.read_text()
        try:
            component_doc_tree = parse_yaml_with_source_positions(
                text, filename=str(component_path)
            )
        except ScannerError:
            # Best effort: an unparseable file cannot be fixed here, but tell
            # the user it was skipped instead of silently ignoring it.
            click.echo(f"Skipping {component_path}: unable to parse YAML.", err=True)
            continue

        component_doc = component_doc_tree.value
        specified_env_var_deps = get_specified_env_var_deps(component_doc)
        used_env_vars = get_used_env_vars(component_doc)

        missing_env_vars = used_env_vars - specified_env_var_deps
        if not missing_env_vars:
            continue

        # An empty `requires:` / `env:` key parses to None, so treat None the
        # same as a missing key. Re-assigning an existing key keeps its
        # position in the document; a new key is appended at the end.
        requires = component_doc.get("requires") or {}
        component_doc["requires"] = requires
        env_requirements = requires.get("env") or []
        requires["env"] = env_requirements

        # Sort so the rewritten file is deterministic across runs (set
        # iteration order of strings varies with hash randomization).
        env_requirements.extend(sorted(missing_env_vars))

        # NOTE(review): yaml.dump re-serializes the whole document, so any
        # comments or custom formatting in the original file are not kept.
        component_path.write_text(yaml.dump(component_doc, sort_keys=False))
        click.echo(f"Updated {component_path}")
        updated_files.add(component_path)

    if updated_files:
        click.echo(f"Updated {len(updated_files)} component yaml files.")
    else:
        click.echo("No component yaml files were updated.")

python_modules/libraries/dagster-dg/dagster_dg/component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,6 @@ def get_used_env_vars(data_structure: Union[Mapping[str, Any], Sequence[Any], An
166166
elif isinstance(data_structure, str):
167167
return set(env_var_regex.findall(data_structure))
168168
elif isinstance(data_structure, Sequence):
169-
return set.union(*(get_used_env_vars(item) for item in data_structure))
169+
return set.union(set(), *(get_used_env_vars(item) for item in data_structure))
170170
else:
171171
return set()

python_modules/libraries/dagster-dg/dagster_dg_tests/cli_tests/test_check_yaml_command.py

+49-9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@
2929
ProxyRunner,
3030
assert_runner_result,
3131
create_project_from_components,
32+
set_env_var,
33+
)
34+
35+
# Validation case for a component that references the A_STRING environment
# variable without declaring it in its component file. Defined at module level
# so it can be referenced by name from more than one place in this module.
NON_SPECIFIED_ENV_VAR_TEST_CASE = ComponentValidationTestCase(
    component_path="validation/basic_component_missing_declared_env",
    component_type_filepath=BASIC_COMPONENT_TYPE_FILEPATH,
    should_error=True,
    # The reported error must point at line 1 of component.yaml and name the
    # undeclared variable.
    check_error_msg=msg_includes_all_of(
        "component.yaml:1",
        "Component uses environment variables that are not specified in the component file: A_STRING",
    ),
)
3344

3445
CLI_TEST_CASES = [
@@ -51,15 +62,7 @@
5162
"'an_extra_top_level_value' was unexpected",
5263
),
5364
),
54-
ComponentValidationTestCase(
55-
component_path="validation/basic_component_missing_declared_env",
56-
component_type_filepath=BASIC_COMPONENT_TYPE_FILEPATH,
57-
should_error=True,
58-
check_error_msg=msg_includes_all_of(
59-
"component.yaml:1",
60-
"Component uses environment variables that are not specified in the component file: A_STRING",
61-
),
62-
),
65+
NON_SPECIFIED_ENV_VAR_TEST_CASE,
6366
ComponentValidationTestCase(
6467
component_path="validation/basic_component_with_env",
6568
component_type_filepath=BASIC_COMPONENT_TYPE_FILEPATH,
@@ -291,3 +294,40 @@ def test_check_yaml_local_component_cache() -> None:
291294
r"CACHE \[write\].*basic_component_invalid_value.*local_component_registry",
292295
result.stdout,
293296
)
297+
298+
299+
def test_check_yaml_fix_env_requirements() -> None:
    """Exercise `dg check yaml --fix-env-requirements` end to end.

    Three invocations run against a project built from the
    undeclared-env-var validation case:

    1. A plain check fails and reports the undeclared variable.
    2. A check with ``--fix-env-requirements`` prints the "files were updated"
       banner and exits with code 1 (presumably because A_STRING is still
       unset in the environment -- confirm against ``check_yaml``).
    3. A plain check with A_STRING set in the environment succeeds.
    """
    case = NON_SPECIFIED_ENV_VAR_TEST_CASE
    with (
        ProxyRunner.test() as runner,
        create_project_from_components(
            runner,
            case.component_path,
            local_component_defn_to_inject=case.component_type_filepath,
        ) as tmpdir,
    ):
        with pushd(tmpdir):
            # Step 1: no flag -> validation error naming the missing variable.
            unfixed = runner.invoke("check", "yaml")
            assert unfixed.exit_code != 0, str(unfixed.stdout)
            assert case.check_error_msg
            case.check_error_msg(str(unfixed.stdout))

            # Step 2: with the flag -> the update banner is printed.
            fixing = runner.invoke("check", "yaml", "--fix-env-requirements")
            assert fixing.exit_code == 1, str(fixing.stdout)
            update_banner = (
                "The following component files were updated to fix environment variable requirements:"
            )
            assert update_banner in str(fixing.stdout)

            # Step 3: with the variable provided, the check passes.
            with set_env_var("A_STRING", "foo"):
                fixed = runner.invoke("check", "yaml")
                assert fixed.exit_code == 0, str(fixed.stdout)

0 commit comments

Comments
 (0)