Skip to content

Commit ba3be24

Browse files
Adds tests for run decorator, remote runner, and CLI prerequisites
Tests cover: - run() decorator behavior, env var capture (exact, wildcard, mixed) - Remote runner GCS helpers, execution flow, and error handling - CLI prerequisite checks for gcloud, pulumi, kubectl, docker, and auth
1 parent 67477a8 commit ba3be24

File tree

3 files changed

+520
-0
lines changed

3 files changed

+520
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for keras_remote.cli.prerequisites_check — tool availability checks."""
2+
3+
import click
4+
import pytest
5+
6+
from keras_remote.cli.prerequisites_check import (
7+
check_docker,
8+
check_gcloud,
9+
check_gcloud_auth,
10+
check_kubectl,
11+
check_pulumi,
12+
)
13+
14+
15+
@pytest.mark.parametrize(
16+
("check_fn", "error_match"),
17+
[
18+
(check_gcloud, "gcloud CLI not found"),
19+
(check_pulumi, "Pulumi CLI not found"),
20+
(check_kubectl, "kubectl not found"),
21+
(check_docker, "Docker not found"),
22+
],
23+
ids=["gcloud", "pulumi", "kubectl", "docker"],
24+
)
25+
class TestToolChecks:
26+
def test_present(self, mocker, check_fn, error_match):
27+
mocker.patch("shutil.which", return_value="/usr/bin/tool")
28+
check_fn()
29+
30+
def test_missing(self, mocker, check_fn, error_match):
31+
mocker.patch("shutil.which", return_value=None)
32+
with pytest.raises(click.ClickException, match=error_match):
33+
check_fn()
34+
35+
36+
class TestCheckGcloudAuth:
37+
def test_token_success(self, mocker):
38+
"""When print-access-token succeeds, no login is triggered."""
39+
mock_run = mocker.patch(
40+
"keras_remote.cli.prerequisites_check.subprocess.run",
41+
)
42+
mock_run.return_value.returncode = 0
43+
44+
check_gcloud_auth()
45+
46+
# Only called once (the token check), not a second time for login
47+
assert mock_run.call_count == 1
48+
49+
def test_token_failure_triggers_login(self, mocker):
50+
"""When print-access-token fails, gcloud auth login is run."""
51+
mock_run = mocker.patch(
52+
"keras_remote.cli.prerequisites_check.subprocess.run",
53+
)
54+
# First call = token check (fails), second call = login (succeeds)
55+
token_result = mocker.MagicMock()
56+
token_result.returncode = 1
57+
mock_run.return_value = token_result
58+
59+
mocker.patch("keras_remote.cli.prerequisites_check.warning")
60+
mocker.patch("click.echo")
61+
62+
check_gcloud_auth()
63+
64+
assert mock_run.call_count == 2
65+
# Second call should be the login command
66+
login_call = mock_run.call_args_list[1]
67+
assert "login" in login_call[0][0]

keras_remote/core/test_core.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Tests for keras_remote.core.core — run decorator and env var capture."""
2+
3+
from unittest.mock import MagicMock
4+
5+
from keras_remote.core.core import run
6+
7+
8+
class TestRunDecorator:
9+
def test_decorator_preserves_wrapped_function(self):
10+
@run(accelerator="cpu")
11+
def my_func():
12+
"""My docstring."""
13+
14+
assert callable(my_func)
15+
assert my_func.__name__ == "my_func"
16+
assert my_func.__doc__ == "My docstring."
17+
18+
19+
class TestEnvVarCapture:
20+
def test_exact_match(self, monkeypatch, mocker):
21+
monkeypatch.setenv("MY_VAR", "my_val")
22+
mock_exec = mocker.patch("keras_remote.core.core._execute_on_gke")
23+
24+
@run(accelerator="cpu", capture_env_vars=["MY_VAR"])
25+
def func():
26+
pass
27+
28+
func()
29+
call_args = mock_exec.call_args
30+
env_vars = call_args[0][-1] # last positional arg
31+
assert env_vars == {"MY_VAR": "my_val"}
32+
33+
def test_wildcard_pattern(self, monkeypatch, mocker):
34+
monkeypatch.setenv("PREFIX_A", "1")
35+
monkeypatch.setenv("PREFIX_B", "2")
36+
monkeypatch.setenv("OTHER", "3")
37+
mock_exec = mocker.patch("keras_remote.core.core._execute_on_gke")
38+
39+
@run(accelerator="cpu", capture_env_vars=["PREFIX_*"])
40+
def func():
41+
pass
42+
43+
func()
44+
env_vars = mock_exec.call_args[0][-1]
45+
assert "PREFIX_A" in env_vars
46+
assert "PREFIX_B" in env_vars
47+
assert "OTHER" not in env_vars
48+
49+
def test_missing_var_skipped(self, monkeypatch, mocker):
50+
monkeypatch.delenv("NONEXISTENT", raising=False)
51+
mock_exec = mocker.patch("keras_remote.core.core._execute_on_gke")
52+
53+
@run(accelerator="cpu", capture_env_vars=["NONEXISTENT"])
54+
def func():
55+
pass
56+
57+
func()
58+
env_vars = mock_exec.call_args[0][-1]
59+
assert env_vars == {}
60+
61+
def test_none_capture(self, mocker):
62+
mock_exec = mocker.patch("keras_remote.core.core._execute_on_gke")
63+
64+
@run(accelerator="cpu", capture_env_vars=None)
65+
def func():
66+
pass
67+
68+
func()
69+
env_vars = mock_exec.call_args[0][-1]
70+
assert env_vars == {}
71+
72+
def test_mixed_exact_and_wildcard(self, monkeypatch, mocker):
73+
monkeypatch.setenv("EXACT_VAR", "exact")
74+
monkeypatch.setenv("WILD_A", "a")
75+
monkeypatch.setenv("WILD_B", "b")
76+
mock_exec = mocker.patch("keras_remote.core.core._execute_on_gke")
77+
78+
@run(
79+
accelerator="cpu",
80+
capture_env_vars=["EXACT_VAR", "WILD_*"],
81+
)
82+
def func():
83+
pass
84+
85+
func()
86+
env_vars = mock_exec.call_args[0][-1]
87+
assert env_vars == {"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"}
88+
89+
90+
class TestExecuteOnGkeDefaults:
91+
def test_cluster_from_env(self, monkeypatch, mocker):
92+
"""When cluster=None, falls back to KERAS_REMOTE_GKE_CLUSTER env var."""
93+
monkeypatch.setenv("KERAS_REMOTE_GKE_CLUSTER", "env-cluster")
94+
monkeypatch.setenv("KERAS_REMOTE_PROJECT", "proj")
95+
mock_exec = mocker.patch(
96+
"keras_remote.core.core.execute_remote",
97+
return_value=42,
98+
)
99+
mocker.patch(
100+
"keras_remote.core.core.JobContext.from_params",
101+
return_value=MagicMock(),
102+
)
103+
104+
@run(accelerator="cpu", cluster=None)
105+
def func():
106+
pass
107+
108+
func()
109+
110+
call_args = mock_exec.call_args
111+
backend = call_args[0][1]
112+
assert backend.cluster == "env-cluster"
113+
114+
def test_namespace_from_env(self, monkeypatch, mocker):
115+
"""When namespace=None, falls back to KERAS_REMOTE_GKE_NAMESPACE env var."""
116+
monkeypatch.setenv("KERAS_REMOTE_GKE_NAMESPACE", "custom-ns")
117+
monkeypatch.setenv("KERAS_REMOTE_PROJECT", "proj")
118+
mock_exec = mocker.patch(
119+
"keras_remote.core.core.execute_remote",
120+
return_value=42,
121+
)
122+
mocker.patch(
123+
"keras_remote.core.core.JobContext.from_params",
124+
return_value=MagicMock(),
125+
)
126+
127+
@run(accelerator="cpu", namespace=None)
128+
def func():
129+
pass
130+
131+
func()
132+
133+
backend = mock_exec.call_args[0][1]
134+
assert backend.namespace == "custom-ns"

0 commit comments

Comments
 (0)