Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_remote/cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"KERAS_REMOTE_STATE_DIR",
os.path.expanduser("~/.keras-remote/pulumi"),
)
PULUMI_ROOT = os.path.expanduser("~/.keras-remote/pulumi-cli")
REQUIRED_APIS = [
"compute.googleapis.com",
"cloudbuild.googleapis.com",
Expand Down
15 changes: 14 additions & 1 deletion keras_remote/cli/infra/stack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import os

import click
import pulumi.automation as auto

from keras_remote.cli.constants import RESOURCE_NAME_PREFIX, STATE_DIR
from keras_remote.cli.constants import (
PULUMI_ROOT,
RESOURCE_NAME_PREFIX,
STATE_DIR,
)


def get_stack(program_fn, config):
Expand All @@ -19,6 +24,13 @@ def get_stack(program_fn, config):
"""
os.makedirs(STATE_DIR, exist_ok=True)

# Auto-install the Pulumi CLI if not already present.
try:
pulumi_cmd = auto.PulumiCommand(root=PULUMI_ROOT)
except Exception: # noqa: BLE001
click.echo("Pulumi CLI not found. Installing...")
pulumi_cmd = auto.PulumiCommand.install(root=PULUMI_ROOT)

# Use project ID as stack name so each GCP project gets its own stack
stack_name = config.project

Expand All @@ -35,6 +47,7 @@ def get_stack(program_fn, config):
opts=auto.LocalWorkspaceOptions(
project_settings=project_settings,
env_vars={"PULUMI_CONFIG_PASSPHRASE": ""},
pulumi_command=pulumi_cmd,
),
)

Expand Down
12 changes: 1 addition & 11 deletions keras_remote/cli/prerequisites_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

Delegates common credential checks (gcloud, auth plugin, ADC) to
:mod:`keras_remote.credentials` and converts ``RuntimeError`` into
``click.ClickException``. CLI-only tool checks (Pulumi, kubectl) remain
here.
``click.ClickException``. CLI-only tool checks (kubectl) remain here.
"""

import shutil
Expand All @@ -21,14 +20,6 @@ def check_gcloud():
raise click.ClickException(str(e)) # noqa: B904


def check_pulumi():
"""Verify Pulumi CLI is installed (required by Automation API)."""
if not shutil.which("pulumi"):
raise click.ClickException(
"Pulumi CLI not found. Install from: https://www.pulumi.com/docs/install/"
)


def check_kubectl():
"""Verify kubectl is installed."""
if not shutil.which("kubectl"):
Expand Down Expand Up @@ -56,7 +47,6 @@ def check_gcloud_auth():
def check_all():
"""Run all prerequisite checks."""
check_gcloud()
check_pulumi()
check_kubectl()
check_gke_auth_plugin()
check_gcloud_auth()
47 changes: 11 additions & 36 deletions keras_remote/cli/prerequisites_check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,31 @@
from unittest import mock

import click
from absl.testing import absltest, parameterized
from absl.testing import absltest

from keras_remote.cli.prerequisites_check import (
check_gcloud,
check_gcloud_auth,
check_gke_auth_plugin,
check_kubectl,
check_pulumi,
)

_MODULE = "keras_remote.cli.prerequisites_check"


class TestToolChecks(parameterized.TestCase):
"""Tests for CLI-only tool checks (pulumi, kubectl)."""

@parameterized.named_parameters(
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
)
def test_present(self, check_fn, error_match):
with mock.patch("shutil.which", return_value="/usr/bin/tool"):
check_fn()

@parameterized.named_parameters(
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
)
def test_missing(self, check_fn, error_match):
class TestToolChecks(absltest.TestCase):
"""Tests for CLI-only tool checks (kubectl)."""

def test_kubectl_present(self):
with mock.patch("shutil.which", return_value="/usr/bin/kubectl"):
check_kubectl()

def test_kubectl_missing(self):
with (
mock.patch("shutil.which", return_value=None),
self.assertRaisesRegex(click.ClickException, error_match),
self.assertRaisesRegex(click.ClickException, "kubectl not found"),
):
check_fn()
check_kubectl()


class TestDelegatedChecks(absltest.TestCase):
Expand Down
Loading