Skip to content

Commit bf7094e

Browse files
Adds unit tests for @run, remote runner, and prereq checks (#35)
* Adds tests for run decorator, remote runner, and CLI prerequisites Tests cover: - run() decorator behavior, env var capture (exact, wildcard, mixed) - Remote runner GCS helpers, execution flow, and error handling - CLI prerequisite checks for gcloud, pulumi, kubectl, docker, and auth * address reviews * Migrate decorator, runner, and CLI tests from pytest to absl testing * test fix * Reduce duplication * remove unnecessary mock
1 parent c6b4dad commit bf7094e

File tree

3 files changed

+514
-0
lines changed

3 files changed

+514
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Tests for keras_remote.cli.prerequisites_check — tool availability checks."""
2+
3+
from unittest import mock
4+
5+
import click
6+
from absl.testing import absltest, parameterized
7+
8+
from keras_remote.cli.prerequisites_check import (
9+
check_docker,
10+
check_gcloud,
11+
check_gcloud_auth,
12+
check_kubectl,
13+
check_pulumi,
14+
)
15+
16+
17+
class TestToolChecks(parameterized.TestCase):
18+
@parameterized.named_parameters(
19+
dict(
20+
testcase_name="gcloud",
21+
check_fn=check_gcloud,
22+
error_match="gcloud CLI not found",
23+
),
24+
dict(
25+
testcase_name="pulumi",
26+
check_fn=check_pulumi,
27+
error_match="Pulumi CLI not found",
28+
),
29+
dict(
30+
testcase_name="kubectl",
31+
check_fn=check_kubectl,
32+
error_match="kubectl not found",
33+
),
34+
dict(
35+
testcase_name="docker",
36+
check_fn=check_docker,
37+
error_match="Docker not found",
38+
),
39+
)
40+
def test_present(self, check_fn, error_match):
41+
with mock.patch("shutil.which", return_value="/usr/bin/tool"):
42+
check_fn()
43+
44+
@parameterized.named_parameters(
45+
dict(
46+
testcase_name="gcloud",
47+
check_fn=check_gcloud,
48+
error_match="gcloud CLI not found",
49+
),
50+
dict(
51+
testcase_name="pulumi",
52+
check_fn=check_pulumi,
53+
error_match="Pulumi CLI not found",
54+
),
55+
dict(
56+
testcase_name="kubectl",
57+
check_fn=check_kubectl,
58+
error_match="kubectl not found",
59+
),
60+
dict(
61+
testcase_name="docker",
62+
check_fn=check_docker,
63+
error_match="Docker not found",
64+
),
65+
)
66+
def test_missing(self, check_fn, error_match):
67+
with (
68+
mock.patch("shutil.which", return_value=None),
69+
self.assertRaisesRegex(click.ClickException, error_match),
70+
):
71+
check_fn()
72+
73+
74+
class TestCheckGcloudAuth(absltest.TestCase):
75+
def test_token_success(self):
76+
"""When print-access-token succeeds, no login is triggered."""
77+
with mock.patch(
78+
"keras_remote.cli.prerequisites_check.subprocess.run",
79+
) as mock_run:
80+
mock_run.return_value.returncode = 0
81+
check_gcloud_auth()
82+
# Only called once (the token check), not a second time for login
83+
self.assertEqual(mock_run.call_count, 1)
84+
85+
def test_token_failure_triggers_login(self):
86+
"""When print-access-token fails, gcloud auth login is run."""
87+
with (
88+
mock.patch(
89+
"keras_remote.cli.prerequisites_check.subprocess.run",
90+
) as mock_run,
91+
mock.patch("keras_remote.cli.prerequisites_check.warning"),
92+
mock.patch("click.echo"),
93+
):
94+
token_result = mock.MagicMock()
95+
token_result.returncode = 1
96+
mock_run.return_value = token_result
97+
98+
check_gcloud_auth()
99+
100+
self.assertEqual(mock_run.call_count, 2)
101+
# Second call should be the login command
102+
login_call = mock_run.call_args_list[1]
103+
self.assertIn("login", login_call[0][0])
104+
105+
106+
if __name__ == "__main__":
107+
absltest.main()

keras_remote/core/test_core.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Tests for keras_remote.core.core — run decorator and env var capture."""
2+
3+
import os
4+
from unittest import mock
5+
from unittest.mock import MagicMock
6+
7+
from absl.testing import absltest
8+
9+
from keras_remote.core.core import run
10+
11+
12+
class TestRunDecorator(absltest.TestCase):
13+
def test_decorator_preserves_wrapped_function(self):
14+
@run(accelerator="cpu")
15+
def my_func():
16+
"""My docstring."""
17+
18+
self.assertTrue(callable(my_func))
19+
self.assertEqual(my_func.__name__, "my_func")
20+
self.assertEqual(my_func.__doc__, "My docstring.")
21+
22+
23+
class TestEnvVarCapture(absltest.TestCase):
24+
def test_exact_match(self):
25+
with (
26+
mock.patch.dict(os.environ, {"MY_VAR": "my_val"}),
27+
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
28+
):
29+
30+
@run(accelerator="cpu", capture_env_vars=["MY_VAR"])
31+
def func():
32+
pass
33+
34+
func()
35+
call_args = mock_exec.call_args
36+
env_vars = call_args[0][-1] # last positional arg
37+
self.assertEqual(env_vars, {"MY_VAR": "my_val"})
38+
39+
def test_wildcard_pattern(self):
40+
env = {
41+
k: v
42+
for k, v in os.environ.items()
43+
if k not in ("PREFIX_A", "PREFIX_B", "OTHER")
44+
}
45+
env.update({"PREFIX_A": "1", "PREFIX_B": "2", "OTHER": "3"})
46+
with (
47+
mock.patch.dict(os.environ, env, clear=True),
48+
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
49+
):
50+
51+
@run(accelerator="cpu", capture_env_vars=["PREFIX_*"])
52+
def func():
53+
pass
54+
55+
func()
56+
env_vars = mock_exec.call_args[0][-1]
57+
self.assertIn("PREFIX_A", env_vars)
58+
self.assertIn("PREFIX_B", env_vars)
59+
self.assertNotIn("OTHER", env_vars)
60+
61+
def test_missing_var_skipped(self):
62+
env = {k: v for k, v in os.environ.items() if k != "NONEXISTENT"}
63+
with (
64+
mock.patch.dict(os.environ, env, clear=True),
65+
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
66+
):
67+
68+
@run(accelerator="cpu", capture_env_vars=["NONEXISTENT"])
69+
def func():
70+
pass
71+
72+
func()
73+
env_vars = mock_exec.call_args[0][-1]
74+
self.assertEqual(env_vars, {})
75+
76+
def test_none_capture(self):
77+
with mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec:
78+
79+
@run(accelerator="cpu", capture_env_vars=None)
80+
def func():
81+
pass
82+
83+
func()
84+
env_vars = mock_exec.call_args[0][-1]
85+
self.assertEqual(env_vars, {})
86+
87+
def test_mixed_exact_and_wildcard(self):
88+
env = {
89+
k: v
90+
for k, v in os.environ.items()
91+
if k not in ("EXACT_VAR", "WILD_A", "WILD_B")
92+
}
93+
env.update({"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"})
94+
with (
95+
mock.patch.dict(os.environ, env, clear=True),
96+
mock.patch("keras_remote.core.core._execute_on_gke") as mock_exec,
97+
):
98+
99+
@run(
100+
accelerator="cpu",
101+
capture_env_vars=["EXACT_VAR", "WILD_*"],
102+
)
103+
def func():
104+
pass
105+
106+
func()
107+
env_vars = mock_exec.call_args[0][-1]
108+
self.assertEqual(
109+
env_vars, {"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"}
110+
)
111+
112+
113+
class TestExecuteOnGkeDefaults(absltest.TestCase):
114+
def test_cluster_from_env(self):
115+
"""When cluster=None, falls back to KERAS_REMOTE_CLUSTER env var."""
116+
with (
117+
mock.patch.dict(
118+
os.environ,
119+
{
120+
"KERAS_REMOTE_CLUSTER": "env-cluster",
121+
"KERAS_REMOTE_PROJECT": "proj",
122+
},
123+
),
124+
mock.patch(
125+
"keras_remote.core.core.execute_remote",
126+
return_value=42,
127+
) as mock_exec,
128+
mock.patch(
129+
"keras_remote.core.core.JobContext.from_params",
130+
return_value=MagicMock(),
131+
),
132+
):
133+
134+
@run(accelerator="cpu", cluster=None)
135+
def func():
136+
pass
137+
138+
func()
139+
140+
call_args = mock_exec.call_args
141+
backend = call_args[0][1]
142+
self.assertEqual(backend.cluster, "env-cluster")
143+
144+
def test_namespace_from_env(self):
145+
"""When namespace=None, falls back to KERAS_REMOTE_GKE_NAMESPACE env var."""
146+
with (
147+
mock.patch.dict(
148+
os.environ,
149+
{
150+
"KERAS_REMOTE_GKE_NAMESPACE": "custom-ns",
151+
"KERAS_REMOTE_PROJECT": "proj",
152+
},
153+
),
154+
mock.patch(
155+
"keras_remote.core.core.execute_remote",
156+
return_value=42,
157+
) as mock_exec,
158+
mock.patch(
159+
"keras_remote.core.core.JobContext.from_params",
160+
return_value=MagicMock(),
161+
),
162+
):
163+
164+
@run(accelerator="cpu", namespace=None)
165+
def func():
166+
pass
167+
168+
func()
169+
170+
backend = mock_exec.call_args[0][1]
171+
self.assertEqual(backend.namespace, "custom-ns")
172+
173+
174+
if __name__ == "__main__":
175+
absltest.main()

0 commit comments

Comments
 (0)