Skip to content

Commit 786604d

Browse files
Auto-installs pulumi if it's not available on the user's machine (#59)
* Auto-installs pulumi if it's not available on the user's machine * Install pulumi in keras-remote HOME
1 parent 686e314 commit 786604d

File tree

4 files changed

+27
-48
lines changed

4 files changed

+27
-48
lines changed

keras_remote/cli/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"KERAS_REMOTE_STATE_DIR",
1313
os.path.expanduser("~/.keras-remote/pulumi"),
1414
)
15+
PULUMI_ROOT = os.path.expanduser("~/.keras-remote/pulumi-cli")
1516
REQUIRED_APIS = [
1617
"compute.googleapis.com",
1718
"cloudbuild.googleapis.com",

keras_remote/cli/infra/stack_manager.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22

33
import os
44

5+
import click
56
import pulumi.automation as auto
67

7-
from keras_remote.cli.constants import RESOURCE_NAME_PREFIX, STATE_DIR
8+
from keras_remote.cli.constants import (
9+
PULUMI_ROOT,
10+
RESOURCE_NAME_PREFIX,
11+
STATE_DIR,
12+
)
813

914

1015
def get_stack(program_fn, config):
@@ -19,6 +24,13 @@ def get_stack(program_fn, config):
1924
"""
2025
os.makedirs(STATE_DIR, exist_ok=True)
2126

27+
# Auto-install the Pulumi CLI if not already present.
28+
try:
29+
pulumi_cmd = auto.PulumiCommand(root=PULUMI_ROOT)
30+
except Exception: # noqa: BLE001
31+
click.echo("Pulumi CLI not found. Installing...")
32+
pulumi_cmd = auto.PulumiCommand.install(root=PULUMI_ROOT)
33+
2234
# Use project ID as stack name so each GCP project gets its own stack
2335
stack_name = config.project
2436

@@ -35,6 +47,7 @@ def get_stack(program_fn, config):
3547
opts=auto.LocalWorkspaceOptions(
3648
project_settings=project_settings,
3749
env_vars={"PULUMI_CONFIG_PASSPHRASE": ""},
50+
pulumi_command=pulumi_cmd,
3851
),
3952
)
4053

keras_remote/cli/prerequisites_check.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
33
Delegates common credential checks (gcloud, auth plugin, ADC) to
44
:mod:`keras_remote.credentials` and converts ``RuntimeError`` into
5-
``click.ClickException``. CLI-only tool checks (Pulumi, kubectl) remain
6-
here.
5+
``click.ClickException``. CLI-only tool checks (kubectl) remain here.
76
"""
87

98
import shutil
@@ -21,14 +20,6 @@ def check_gcloud():
2120
raise click.ClickException(str(e)) # noqa: B904
2221

2322

24-
def check_pulumi():
25-
"""Verify Pulumi CLI is installed (required by Automation API)."""
26-
if not shutil.which("pulumi"):
27-
raise click.ClickException(
28-
"Pulumi CLI not found. Install from: https://www.pulumi.com/docs/install/"
29-
)
30-
31-
3223
def check_kubectl():
3324
"""Verify kubectl is installed."""
3425
if not shutil.which("kubectl"):
@@ -56,7 +47,6 @@ def check_gcloud_auth():
5647
def check_all():
5748
"""Run all prerequisite checks."""
5849
check_gcloud()
59-
check_pulumi()
6050
check_kubectl()
6151
check_gke_auth_plugin()
6252
check_gcloud_auth()

keras_remote/cli/prerequisites_check_test.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,31 @@
33
from unittest import mock
44

55
import click
6-
from absl.testing import absltest, parameterized
6+
from absl.testing import absltest
77

88
from keras_remote.cli.prerequisites_check import (
99
check_gcloud,
1010
check_gcloud_auth,
1111
check_gke_auth_plugin,
1212
check_kubectl,
13-
check_pulumi,
1413
)
1514

1615
_MODULE = "keras_remote.cli.prerequisites_check"
1716

1817

19-
class TestToolChecks(parameterized.TestCase):
20-
"""Tests for CLI-only tool checks (pulumi, kubectl)."""
21-
22-
@parameterized.named_parameters(
23-
dict(
24-
testcase_name="pulumi",
25-
check_fn=check_pulumi,
26-
error_match="Pulumi CLI not found",
27-
),
28-
dict(
29-
testcase_name="kubectl",
30-
check_fn=check_kubectl,
31-
error_match="kubectl not found",
32-
),
33-
)
34-
def test_present(self, check_fn, error_match):
35-
with mock.patch("shutil.which", return_value="/usr/bin/tool"):
36-
check_fn()
37-
38-
@parameterized.named_parameters(
39-
dict(
40-
testcase_name="pulumi",
41-
check_fn=check_pulumi,
42-
error_match="Pulumi CLI not found",
43-
),
44-
dict(
45-
testcase_name="kubectl",
46-
check_fn=check_kubectl,
47-
error_match="kubectl not found",
48-
),
49-
)
50-
def test_missing(self, check_fn, error_match):
18+
class TestToolChecks(absltest.TestCase):
19+
"""Tests for CLI-only tool checks (kubectl)."""
20+
21+
def test_kubectl_present(self):
22+
with mock.patch("shutil.which", return_value="/usr/bin/kubectl"):
23+
check_kubectl()
24+
25+
def test_kubectl_missing(self):
5126
with (
5227
mock.patch("shutil.which", return_value=None),
53-
self.assertRaisesRegex(click.ClickException, error_match),
28+
self.assertRaisesRegex(click.ClickException, "kubectl not found"),
5429
):
55-
check_fn()
30+
check_kubectl()
5631

5732

5833
class TestDelegatedChecks(absltest.TestCase):

0 commit comments

Comments
 (0)