Skip to content

Commit c6f013d

Browse files
Centralize Pulumi state management, eliminate CLI duplication, and scope all resources to cluster (#74)
1 parent 59c01a7 commit c6f013d

File tree

18 files changed

+712
-437
lines changed

18 files changed

+712
-437
lines changed

keras_remote/backend/execution.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from google.api_core import exceptions as google_exceptions
1717

1818
from keras_remote.backend import gke_client, pathways_client
19-
from keras_remote.constants import get_default_zone, zone_to_region
19+
from keras_remote.constants import (
20+
get_default_cluster_name,
21+
get_default_zone,
22+
zone_to_region,
23+
)
2024
from keras_remote.credentials import ensure_credentials
2125
from keras_remote.data import _make_data_ref
2226
from keras_remote.infra import container_builder
@@ -39,6 +43,7 @@ class JobContext:
3943
container_image: Optional[str]
4044
zone: str
4145
project: str
46+
cluster_name: str
4247

4348
# Generated identifiers
4449
job_id: str = field(default_factory=lambda: f"job-{uuid.uuid4().hex[:8]}")
@@ -58,7 +63,7 @@ class JobContext:
5863
image_uri: Optional[str] = None
5964

6065
def __post_init__(self):
61-
self.bucket_name = f"{self.project}-keras-remote-jobs"
66+
self.bucket_name = f"{self.project}-kr-{self.cluster_name}-jobs"
6267
self.region = zone_to_region(self.zone)
6368
self.display_name = f"keras-remote-{self.func.__name__}-{self.job_id}"
6469

@@ -73,9 +78,10 @@ def from_params(
7378
zone: Optional[str],
7479
project: Optional[str],
7580
env_vars: dict,
81+
cluster_name: Optional[str] = None,
7682
volumes: Optional[dict] = None,
7783
) -> "JobContext":
78-
"""Factory method with default resolution for zone/project."""
84+
"""Factory method with default resolution for zone/project/cluster."""
7985
if not zone:
8086
zone = get_default_zone()
8187
if not project:
@@ -85,6 +91,8 @@ def from_params(
8591
"project must be specified or set KERAS_REMOTE_PROJECT"
8692
" (or GOOGLE_CLOUD_PROJECT) environment variable"
8793
)
94+
if not cluster_name:
95+
cluster_name = get_default_cluster_name()
8896

8997
return cls(
9098
func=func,
@@ -95,6 +103,7 @@ def from_params(
95103
container_image=container_image,
96104
zone=zone,
97105
project=project,
106+
cluster_name=cluster_name,
98107
volumes=volumes,
99108
)
100109

@@ -303,6 +312,7 @@ def _build_container(ctx: JobContext) -> None:
303312
accelerator_type=ctx.accelerator,
304313
project=ctx.project,
305314
zone=ctx.zone,
315+
cluster_name=ctx.cluster_name,
306316
)
307317

308318

keras_remote/backend/execution_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def test_post_init_derived_fields(self):
4040
container_image=None,
4141
zone="europe-west4-b",
4242
project="my-proj",
43+
cluster_name="my-cluster",
4344
)
44-
self.assertEqual(ctx.bucket_name, "my-proj-keras-remote-jobs")
45+
self.assertEqual(ctx.bucket_name, "my-proj-kr-my-cluster-jobs")
4546
self.assertEqual(ctx.region, "europe-west4")
4647
self.assertTrue(ctx.display_name.startswith("keras-remote-my_train-"))
4748
self.assertRegex(ctx.job_id, r"^job-[0-9a-f]{8}$")
@@ -171,6 +172,7 @@ def _make_ctx(self, container_image=None):
171172
container_image=container_image,
172173
zone="us-central1-a",
173174
project="proj",
175+
cluster_name="keras-remote-cluster",
174176
)
175177

176178
def test_success_flow(self):

keras_remote/cli/commands/down.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,18 @@
11
"""keras-remote down command — tear down infrastructure."""
22

33
import click
4-
import pulumi.automation as auto
54

65
from keras_remote.cli.config import InfraConfig
76
from keras_remote.cli.constants import DEFAULT_CLUSTER_NAME, DEFAULT_ZONE
8-
from keras_remote.cli.infra.program import create_program
9-
from keras_remote.cli.infra.stack_manager import get_stack
10-
from keras_remote.cli.output import banner, console, success, warning
7+
from keras_remote.cli.infra.state import apply_destroy
8+
from keras_remote.cli.options import common_options
9+
from keras_remote.cli.output import banner, console, warning
1110
from keras_remote.cli.prerequisites_check import check_all
1211
from keras_remote.cli.prompts import resolve_project
1312

1413

1514
@click.command()
16-
@click.option(
17-
"--project",
18-
envvar="KERAS_REMOTE_PROJECT",
19-
default=None,
20-
help="GCP project ID [env: KERAS_REMOTE_PROJECT]",
21-
)
22-
@click.option(
23-
"--zone",
24-
envvar="KERAS_REMOTE_ZONE",
25-
default=None,
26-
help=(f"GCP zone [env: KERAS_REMOTE_ZONE, default: {DEFAULT_ZONE}]"),
27-
)
28-
@click.option(
29-
"--cluster",
30-
"cluster_name",
31-
envvar="KERAS_REMOTE_CLUSTER",
32-
default=None,
33-
help="GKE cluster name [default: keras-remote-cluster]",
34-
)
15+
@common_options
3516
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
3617
def down(project, zone, cluster_name, yes):
3718
"""Tear down keras-remote GCP infrastructure."""
@@ -60,19 +41,7 @@ def down(project, zone, cluster_name, yes):
6041
console.print()
6142

6243
config = InfraConfig(project=project, zone=zone, cluster_name=cluster_name)
63-
64-
# Pulumi destroy
65-
try:
66-
# Minimal config to load the stack — accelerator is not
67-
# needed for destroy since the stack already has its state.
68-
program = create_program(config)
69-
stack = get_stack(program, config)
70-
console.print("[bold]Destroying Pulumi-managed resources...[/bold]\n")
71-
result = stack.destroy(on_output=print)
72-
console.print()
73-
success(f"Pulumi destroy complete. {result.summary.resource_changes}")
74-
except auto.errors.CommandError as e:
75-
warning(f"Pulumi destroy encountered an issue: {e}")
44+
apply_destroy(config)
7645

7746
# Summary
7847
console.print()
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Tests for keras_remote.cli.commands.down — destroy infrastructure."""
2+
3+
from unittest import mock
4+
5+
from absl.testing import absltest
6+
from click.testing import CliRunner
7+
8+
from keras_remote.cli.commands.down import down
9+
10+
# Shared CLI args that skip interactive prompts.
11+
_CLI_ARGS = [
12+
"--project",
13+
"test-project",
14+
"--zone",
15+
"us-central2-b",
16+
"--yes",
17+
]
18+
19+
# Patches applied to every test to bypass prerequisites and infrastructure.
20+
_BASE_PATCHES = {
21+
"check_all": mock.patch("keras_remote.cli.commands.down.check_all"),
22+
"resolve_project": mock.patch(
23+
"keras_remote.cli.commands.down.resolve_project",
24+
return_value="test-project",
25+
),
26+
"apply_destroy": mock.patch(
27+
"keras_remote.cli.commands.down.apply_destroy", return_value=True
28+
),
29+
}
30+
31+
32+
def _start_patches(test_case):
33+
"""Start all base patches and return a dict of mock objects."""
34+
mocks = {}
35+
for name, patcher in _BASE_PATCHES.items():
36+
mocks[name] = test_case.enterContext(patcher)
37+
return mocks
38+
39+
40+
class DownCommandTest(absltest.TestCase):
41+
def setUp(self):
42+
super().setUp()
43+
self.runner = CliRunner()
44+
self.mocks = _start_patches(self)
45+
46+
def test_successful_destroy(self):
47+
"""Successful destroy — exit code 0, 'Cleanup Complete' shown."""
48+
result = self.runner.invoke(down, _CLI_ARGS)
49+
50+
self.assertEqual(result.exit_code, 0, result.output)
51+
self.assertIn("Cleanup Complete", result.output)
52+
self.mocks["apply_destroy"].assert_called_once()
53+
54+
def test_destroy_failure_still_shows_summary(self):
55+
"""apply_destroy returns False — summary still displayed."""
56+
self.mocks["apply_destroy"].return_value = False
57+
58+
result = self.runner.invoke(down, _CLI_ARGS)
59+
60+
self.assertEqual(result.exit_code, 0, result.output)
61+
self.assertIn("Cleanup Complete", result.output)
62+
self.assertIn("Check manually", result.output)
63+
64+
def test_abort_on_no_confirmation(self):
65+
"""User declines confirmation — apply_destroy not called."""
66+
args = ["--project", "test-project", "--zone", "us-central2-b"]
67+
68+
self.runner.invoke(down, args, input="n\n")
69+
70+
self.mocks["apply_destroy"].assert_not_called()
71+
72+
def test_yes_flag_skips_confirmation(self):
73+
"""--yes flag skips confirmation prompt."""
74+
result = self.runner.invoke(down, _CLI_ARGS)
75+
76+
self.assertEqual(result.exit_code, 0, result.output)
77+
self.mocks["apply_destroy"].assert_called_once()
78+
79+
def test_config_passed_correctly(self):
80+
"""InfraConfig args match CLI options."""
81+
args = [
82+
"--project",
83+
"my-proj",
84+
"--zone",
85+
"europe-west1-b",
86+
"--cluster",
87+
"my-cluster",
88+
"--yes",
89+
]
90+
91+
result = self.runner.invoke(down, args)
92+
93+
self.assertEqual(result.exit_code, 0, result.output)
94+
config = self.mocks["apply_destroy"].call_args[0][0]
95+
self.assertEqual(config.project, "my-proj")
96+
self.assertEqual(config.zone, "europe-west1-b")
97+
self.assertEqual(config.cluster_name, "my-cluster")
98+
99+
def test_resolve_project_allow_create_false(self):
100+
"""When --project not given, resolve_project(allow_create=False) is called."""
101+
args = ["--zone", "us-central1-a", "--yes"]
102+
103+
result = self.runner.invoke(down, args)
104+
105+
self.assertEqual(result.exit_code, 0, result.output)
106+
self.mocks["resolve_project"].assert_called_once_with(allow_create=False)
107+
108+
109+
if __name__ == "__main__":
110+
absltest.main()

0 commit comments

Comments
 (0)