Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras_remote/cli/infra/stack_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Pulumi Automation API wrapper for keras-remote."""

import os
import shutil

import click
import pulumi.automation as auto

from keras_remote.cli.constants import RESOURCE_NAME_PREFIX, STATE_DIR
Expand All @@ -19,6 +21,11 @@ def get_stack(program_fn, config):
"""
os.makedirs(STATE_DIR, exist_ok=True)

# Auto-install the Pulumi CLI if not already present.
if not shutil.which("pulumi"):
click.echo("Pulumi CLI not found. Installing...")
pulumi_cmd = auto.PulumiCommand.install()

# Use project ID as stack name so each GCP project gets its own stack
stack_name = config.project

Expand All @@ -35,6 +42,7 @@ def get_stack(program_fn, config):
opts=auto.LocalWorkspaceOptions(
project_settings=project_settings,
env_vars={"PULUMI_CONFIG_PASSPHRASE": ""},
pulumi_command=pulumi_cmd,
),
)

Expand Down
12 changes: 1 addition & 11 deletions keras_remote/cli/prerequisites_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

Delegates common credential checks (gcloud, auth plugin, ADC) to
:mod:`keras_remote.credentials` and converts ``RuntimeError`` into
``click.ClickException``. CLI-only tool checks (Pulumi, kubectl) remain
here.
``click.ClickException``. CLI-only tool checks (kubectl) remain here.
"""

import shutil
Expand All @@ -21,14 +20,6 @@ def check_gcloud():
raise click.ClickException(str(e)) # noqa: B904


def check_pulumi():
"""Verify Pulumi CLI is installed (required by Automation API)."""
if not shutil.which("pulumi"):
raise click.ClickException(
"Pulumi CLI not found. Install from: https://www.pulumi.com/docs/install/"
)


def check_kubectl():
"""Verify kubectl is installed."""
if not shutil.which("kubectl"):
Expand Down Expand Up @@ -56,7 +47,6 @@ def check_gcloud_auth():
def check_all():
"""Run all prerequisite checks."""
check_gcloud()
check_pulumi()
check_kubectl()
check_gke_auth_plugin()
check_gcloud_auth()
47 changes: 11 additions & 36 deletions keras_remote/cli/prerequisites_check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,31 @@
from unittest import mock

import click
from absl.testing import absltest, parameterized
from absl.testing import absltest

from keras_remote.cli.prerequisites_check import (
check_gcloud,
check_gcloud_auth,
check_gke_auth_plugin,
check_kubectl,
check_pulumi,
)

_MODULE = "keras_remote.cli.prerequisites_check"


class TestToolChecks(parameterized.TestCase):
"""Tests for CLI-only tool checks (pulumi, kubectl)."""

@parameterized.named_parameters(
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
)
def test_present(self, check_fn, error_match):
with mock.patch("shutil.which", return_value="/usr/bin/tool"):
check_fn()

@parameterized.named_parameters(
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
)
def test_missing(self, check_fn, error_match):
class TestToolChecks(absltest.TestCase):
"""Tests for CLI-only tool checks (kubectl)."""

def test_kubectl_present(self):
with mock.patch("shutil.which", return_value="/usr/bin/kubectl"):
check_kubectl()

def test_kubectl_missing(self):
with (
mock.patch("shutil.which", return_value=None),
self.assertRaisesRegex(click.ClickException, error_match),
self.assertRaisesRegex(click.ClickException, "kubectl not found"),
):
check_fn()
check_kubectl()


class TestDelegatedChecks(absltest.TestCase):
Expand Down
Loading