Skip to content

Commit 6b5fbaa

Browse files
Adds Kubernetes and execution pipeline tests
1 parent 67477a8 commit 6b5fbaa

File tree

2 files changed

+447
-0
lines changed

2 files changed

+447
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Tests for keras_remote.backend.execution — JobContext and execute_remote."""
2+
3+
import re
4+
from unittest.mock import MagicMock
5+
6+
import pytest
7+
8+
from keras_remote.backend.execution import (
9+
JobContext,
10+
_find_requirements,
11+
execute_remote,
12+
)
13+
14+
15+
class TestJobContext:
16+
def _make_func(self):
17+
def my_train():
18+
return 42
19+
20+
return my_train
21+
22+
def test_post_init_derived_fields(self):
23+
ctx = JobContext(
24+
func=self._make_func(),
25+
args=(),
26+
kwargs={},
27+
env_vars={},
28+
accelerator="cpu",
29+
container_image=None,
30+
zone="europe-west4-b",
31+
project="my-proj",
32+
)
33+
assert ctx.bucket_name == "my-proj-keras-remote-jobs"
34+
assert ctx.region == "europe-west4"
35+
assert ctx.display_name.startswith("keras-remote-my_train-")
36+
assert re.fullmatch(r"job-[0-9a-f]{8}", ctx.job_id)
37+
38+
def test_from_params_explicit(self):
39+
ctx = JobContext.from_params(
40+
func=self._make_func(),
41+
args=(1, 2),
42+
kwargs={"k": "v"},
43+
accelerator="l4",
44+
container_image="my-image:latest",
45+
zone="us-west1-a",
46+
project="explicit-proj",
47+
env_vars={"X": "Y"},
48+
)
49+
assert ctx.zone == "us-west1-a"
50+
assert ctx.project == "explicit-proj"
51+
assert ctx.accelerator == "l4"
52+
assert ctx.container_image == "my-image:latest"
53+
assert ctx.args == (1, 2)
54+
assert ctx.kwargs == {"k": "v"}
55+
assert ctx.env_vars == {"X": "Y"}
56+
57+
def test_from_params_resolves_zone_from_env(self, monkeypatch):
58+
monkeypatch.setenv("KERAS_REMOTE_ZONE", "asia-east1-c")
59+
monkeypatch.setenv("KERAS_REMOTE_PROJECT", "env-proj")
60+
61+
ctx = JobContext.from_params(
62+
func=self._make_func(),
63+
args=(),
64+
kwargs={},
65+
accelerator="cpu",
66+
container_image=None,
67+
zone=None,
68+
project=None,
69+
env_vars={},
70+
)
71+
assert ctx.zone == "asia-east1-c"
72+
assert ctx.project == "env-proj"
73+
74+
def test_from_params_no_project_raises(self, monkeypatch):
75+
monkeypatch.delenv("KERAS_REMOTE_PROJECT", raising=False)
76+
77+
with pytest.raises(ValueError, match="project must be specified"):
78+
JobContext.from_params(
79+
func=self._make_func(),
80+
args=(),
81+
kwargs={},
82+
accelerator="cpu",
83+
container_image=None,
84+
zone="us-central1-a",
85+
project=None,
86+
env_vars={},
87+
)
88+
89+
90+
class TestFindRequirements:
91+
def test_finds_in_start_dir(self, tmp_path):
92+
"""Returns the path when requirements.txt exists in the start directory."""
93+
(tmp_path / "requirements.txt").write_text("numpy\n")
94+
assert _find_requirements(str(tmp_path)) == str(
95+
tmp_path / "requirements.txt"
96+
)
97+
98+
def test_finds_in_parent_dir(self, tmp_path):
99+
"""Walks up the directory tree to find requirements.txt in a parent."""
100+
(tmp_path / "requirements.txt").write_text("numpy\n")
101+
child = tmp_path / "subdir"
102+
child.mkdir()
103+
assert _find_requirements(str(child)) == str(tmp_path / "requirements.txt")
104+
105+
def test_returns_none_when_not_found(self, tmp_path):
106+
"""Returns None when no requirements.txt exists in any ancestor."""
107+
empty = tmp_path / "empty"
108+
empty.mkdir()
109+
assert _find_requirements(str(empty)) is None
110+
111+
112+
class TestExecuteRemote:
113+
def _make_func(self):
114+
def my_train():
115+
return 42
116+
117+
return my_train
118+
119+
def _make_ctx(self, container_image=None):
120+
return JobContext(
121+
func=self._make_func(),
122+
args=(),
123+
kwargs={},
124+
env_vars={},
125+
accelerator="cpu",
126+
container_image=container_image,
127+
zone="us-central1-a",
128+
project="proj",
129+
)
130+
131+
def test_success_flow(self, mocker):
132+
mocker.patch("keras_remote.backend.execution._build_container")
133+
mocker.patch("keras_remote.backend.execution._upload_artifacts")
134+
mocker.patch(
135+
"keras_remote.backend.execution._download_result",
136+
return_value={"success": True, "result": 42},
137+
)
138+
mocker.patch(
139+
"keras_remote.backend.execution._cleanup_and_return",
140+
return_value=42,
141+
)
142+
143+
ctx = self._make_ctx()
144+
backend = MagicMock()
145+
146+
result = execute_remote(ctx, backend)
147+
148+
backend.submit_job.assert_called_once_with(ctx)
149+
backend.wait_for_job.assert_called_once()
150+
backend.cleanup_job.assert_called_once()
151+
assert result == 42
152+
153+
def test_cleanup_on_wait_failure(self, mocker):
154+
mocker.patch("keras_remote.backend.execution._build_container")
155+
mocker.patch("keras_remote.backend.execution._upload_artifacts")
156+
157+
ctx = self._make_ctx()
158+
backend = MagicMock()
159+
backend.wait_for_job.side_effect = RuntimeError("job failed")
160+
161+
with pytest.raises(RuntimeError, match="job failed"):
162+
execute_remote(ctx, backend)
163+
164+
# cleanup_job is called in finally block even when wait fails
165+
backend.cleanup_job.assert_called_once()

0 commit comments

Comments
 (0)