Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions keras_remote/cli/test_prerequisites_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tests for keras_remote.cli.prerequisites_check — tool availability checks."""

from unittest import mock

import click
from absl.testing import absltest, parameterized

from keras_remote.cli.prerequisites_check import (
check_docker,
check_gcloud,
check_gcloud_auth,
check_kubectl,
check_pulumi,
)


class TestToolChecks(parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name="gcloud",
check_fn=check_gcloud,
error_match="gcloud CLI not found",
),
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
dict(
testcase_name="docker",
check_fn=check_docker,
error_match="Docker not found",
),
)
def test_present(self, check_fn, error_match):
with mock.patch("shutil.which", return_value="/usr/bin/tool"):
check_fn()

@parameterized.named_parameters(
dict(
testcase_name="gcloud",
check_fn=check_gcloud,
error_match="gcloud CLI not found",
),
dict(
testcase_name="pulumi",
check_fn=check_pulumi,
error_match="Pulumi CLI not found",
),
dict(
testcase_name="kubectl",
check_fn=check_kubectl,
error_match="kubectl not found",
),
dict(
testcase_name="docker",
check_fn=check_docker,
error_match="Docker not found",
),
)
def test_missing(self, check_fn, error_match):
with (
mock.patch("shutil.which", return_value=None),
self.assertRaisesRegex(click.ClickException, error_match),
):
check_fn()


class TestCheckGcloudAuth(absltest.TestCase):
def test_token_success(self):
"""When print-access-token succeeds, no login is triggered."""
with mock.patch(
"keras_remote.cli.prerequisites_check.subprocess.run",
) as mock_run:
mock_run.return_value.returncode = 0
check_gcloud_auth()
# Only called once (the token check), not a second time for login
self.assertEqual(mock_run.call_count, 1)

def test_token_failure_triggers_login(self):
"""When print-access-token fails, gcloud auth login is run."""
with (
mock.patch(
"keras_remote.cli.prerequisites_check.subprocess.run",
) as mock_run,
mock.patch("keras_remote.cli.prerequisites_check.warning"),
mock.patch("click.echo"),
):
token_result = mock.MagicMock()
token_result.returncode = 1
mock_run.return_value = token_result

check_gcloud_auth()

self.assertEqual(mock_run.call_count, 2)
# Second call should be the login command
login_call = mock_run.call_args_list[1]
self.assertIn("login", login_call[0][0])


if __name__ == "__main__":
absltest.main()
175 changes: 175 additions & 0 deletions keras_remote/core/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Tests for keras_remote.core.core — run decorator and env var capture."""

import os
from unittest import mock
from unittest.mock import MagicMock

from absl.testing import absltest

from keras_remote.core.core import run


class TestRunDecorator(absltest.TestCase):
def test_decorator_preserves_wrapped_function(self):
@run(accelerator="cpu")
def my_func():
"""My docstring."""

self.assertTrue(callable(my_func))
self.assertEqual(my_func.__name__, "my_func")
self.assertEqual(my_func.__doc__, "My docstring.")


class TestEnvVarCapture(absltest.TestCase):
def test_exact_match(self):
with (
mock.patch.dict(os.environ, {"MY_VAR": "my_val"}),
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
):

@run(accelerator="cpu", capture_env_vars=["MY_VAR"])
def func():
pass

func()
call_args = mock_exec.call_args
env_vars = call_args[0][-1] # last positional arg
self.assertEqual(env_vars, {"MY_VAR": "my_val"})

def test_wildcard_pattern(self):
env = {
k: v
for k, v in os.environ.items()
if k not in ("PREFIX_A", "PREFIX_B", "OTHER")
}
env.update({"PREFIX_A": "1", "PREFIX_B": "2", "OTHER": "3"})
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
):

@run(accelerator="cpu", capture_env_vars=["PREFIX_*"])
def func():
pass

func()
env_vars = mock_exec.call_args[0][-1]
self.assertIn("PREFIX_A", env_vars)
self.assertIn("PREFIX_B", env_vars)
self.assertNotIn("OTHER", env_vars)

def test_missing_var_skipped(self):
env = {k: v for k, v in os.environ.items() if k != "NONEXISTENT"}
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
):

@run(accelerator="cpu", capture_env_vars=["NONEXISTENT"])
def func():
pass

func()
env_vars = mock_exec.call_args[0][-1]
self.assertEqual(env_vars, {})

def test_none_capture(self):
with mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec:

@run(accelerator="cpu", capture_env_vars=None)
def func():
pass

func()
env_vars = mock_exec.call_args[0][-1]
self.assertEqual(env_vars, {})

def test_mixed_exact_and_wildcard(self):
env = {
k: v
for k, v in os.environ.items()
if k not in ("EXACT_VAR", "WILD_A", "WILD_B")
}
env.update({"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"})
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
):

@run(
accelerator="cpu",
capture_env_vars=["EXACT_VAR", "WILD_*"],
)
def func():
pass

func()
env_vars = mock_exec.call_args[0][-1]
self.assertEqual(
env_vars, {"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"}
)


class TestExecuteOnGkeDefaults(absltest.TestCase):
def test_cluster_from_env(self):
"""When cluster=None, falls back to KERAS_REMOTE_CLUSTER env var."""
with (
mock.patch.dict(
os.environ,
{
"KERAS_REMOTE_CLUSTER": "env-cluster",
"KERAS_REMOTE_PROJECT": "proj",
},
),
mock.patch(
"keras_remote.core.core.execute_remote",
return_value=42,
) as mock_exec,
mock.patch(
"keras_remote.core.core.JobContext.from_params",
return_value=MagicMock(),
),
):

@run(accelerator="cpu", cluster=None)
def func():
pass

func()

call_args = mock_exec.call_args
backend = call_args[0][1]
self.assertEqual(backend.cluster, "env-cluster")

def test_namespace_from_env(self):
"""When namespace=None, falls back to KERAS_REMOTE_GKE_NAMESPACE env var."""
with (
mock.patch.dict(
os.environ,
{
"KERAS_REMOTE_GKE_NAMESPACE": "custom-ns",
"KERAS_REMOTE_PROJECT": "proj",
},
),
mock.patch(
"keras_remote.core.core.execute_remote",
return_value=42,
) as mock_exec,
mock.patch(
"keras_remote.core.core.JobContext.from_params",
return_value=MagicMock(),
),
):

@run(accelerator="cpu", namespace=None)
def func():
pass

func()

backend = mock_exec.call_args[0][1]
self.assertEqual(backend.namespace, "custom-ns")


if __name__ == "__main__":
absltest.main()
Loading
Loading