Skip to content

Commit cd688db

Browse files
Migrate kubernetes backend tests from pytest to absl testing
1 parent 3d89921 commit cd688db

File tree

2 files changed

+278
-202
lines changed

2 files changed

+278
-202
lines changed
Lines changed: 106 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for keras_remote.backend.execution — JobContext and execute_remote."""
22

3-
import re
3+
import os
4+
import pathlib
5+
import tempfile
6+
from unittest import mock
47
from unittest.mock import MagicMock
58

6-
import pytest
9+
from absl.testing import absltest
710

811
from keras_remote.backend.execution import (
912
JobContext,
@@ -12,7 +15,14 @@
1215
)
1316

1417

15-
class TestJobContext:
18+
def _make_temp_path(test_case):
19+
"""Create a temp directory that is cleaned up after the test."""
20+
td = tempfile.TemporaryDirectory()
21+
test_case.addCleanup(td.cleanup)
22+
return pathlib.Path(td.name)
23+
24+
25+
class TestJobContext(absltest.TestCase):
1626
def _make_func(self):
1727
def my_train():
1828
return 42
@@ -30,10 +40,10 @@ def test_post_init_derived_fields(self):
3040
zone="europe-west4-b",
3141
project="my-proj",
3242
)
33-
assert ctx.bucket_name == "my-proj-keras-remote-jobs"
34-
assert ctx.region == "europe-west4"
35-
assert ctx.display_name.startswith("keras-remote-my_train-")
36-
assert re.fullmatch(r"job-[0-9a-f]{8}", ctx.job_id)
43+
self.assertEqual(ctx.bucket_name, "my-proj-keras-remote-jobs")
44+
self.assertEqual(ctx.region, "europe-west4")
45+
self.assertTrue(ctx.display_name.startswith("keras-remote-my_train-"))
46+
self.assertRegex(ctx.job_id, r"^job-[0-9a-f]{8}$")
3747

3848
def test_from_params_explicit(self):
3949
ctx = JobContext.from_params(
@@ -46,35 +56,38 @@ def test_from_params_explicit(self):
4656
project="explicit-proj",
4757
env_vars={"X": "Y"},
4858
)
49-
assert ctx.zone == "us-west1-a"
50-
assert ctx.project == "explicit-proj"
51-
assert ctx.accelerator == "l4"
52-
assert ctx.container_image == "my-image:latest"
53-
assert ctx.args == (1, 2)
54-
assert ctx.kwargs == {"k": "v"}
55-
assert ctx.env_vars == {"X": "Y"}
56-
57-
def test_from_params_resolves_zone_from_env(self, monkeypatch):
58-
monkeypatch.setenv("KERAS_REMOTE_ZONE", "asia-east1-c")
59-
monkeypatch.setenv("KERAS_REMOTE_PROJECT", "env-proj")
60-
61-
ctx = JobContext.from_params(
62-
func=self._make_func(),
63-
args=(),
64-
kwargs={},
65-
accelerator="cpu",
66-
container_image=None,
67-
zone=None,
68-
project=None,
69-
env_vars={},
70-
)
71-
assert ctx.zone == "asia-east1-c"
72-
assert ctx.project == "env-proj"
73-
74-
def test_from_params_no_project_raises(self, monkeypatch):
75-
monkeypatch.delenv("KERAS_REMOTE_PROJECT", raising=False)
76-
77-
with pytest.raises(ValueError, match="project must be specified"):
59+
self.assertEqual(ctx.zone, "us-west1-a")
60+
self.assertEqual(ctx.project, "explicit-proj")
61+
self.assertEqual(ctx.accelerator, "l4")
62+
self.assertEqual(ctx.container_image, "my-image:latest")
63+
self.assertEqual(ctx.args, (1, 2))
64+
self.assertEqual(ctx.kwargs, {"k": "v"})
65+
self.assertEqual(ctx.env_vars, {"X": "Y"})
66+
67+
def test_from_params_resolves_zone_from_env(self):
68+
with mock.patch.dict(
69+
os.environ,
70+
{"KERAS_REMOTE_ZONE": "asia-east1-c", "KERAS_REMOTE_PROJECT": "env-proj"},
71+
):
72+
ctx = JobContext.from_params(
73+
func=self._make_func(),
74+
args=(),
75+
kwargs={},
76+
accelerator="cpu",
77+
container_image=None,
78+
zone=None,
79+
project=None,
80+
env_vars={},
81+
)
82+
self.assertEqual(ctx.zone, "asia-east1-c")
83+
self.assertEqual(ctx.project, "env-proj")
84+
85+
def test_from_params_no_project_raises(self):
86+
env = {k: v for k, v in os.environ.items() if k != "KERAS_REMOTE_PROJECT"}
87+
with (
88+
mock.patch.dict(os.environ, env, clear=True),
89+
self.assertRaisesRegex(ValueError, "project must be specified"),
90+
):
7891
JobContext.from_params(
7992
func=self._make_func(),
8093
args=(),
@@ -87,29 +100,36 @@ def test_from_params_no_project_raises(self, monkeypatch):
87100
)
88101

89102

90-
class TestFindRequirements:
91-
def test_finds_in_start_dir(self, tmp_path):
103+
class TestFindRequirements(absltest.TestCase):
104+
def test_finds_in_start_dir(self):
92105
"""Returns the path when requirements.txt exists in the start directory."""
106+
tmp_path = _make_temp_path(self)
93107
(tmp_path / "requirements.txt").write_text("numpy\n")
94-
assert _find_requirements(str(tmp_path)) == str(
95-
tmp_path / "requirements.txt"
108+
self.assertEqual(
109+
_find_requirements(str(tmp_path)),
110+
str(tmp_path / "requirements.txt"),
96111
)
97112

98-
def test_finds_in_parent_dir(self, tmp_path):
113+
def test_finds_in_parent_dir(self):
99114
"""Walks up the directory tree to find requirements.txt in a parent."""
115+
tmp_path = _make_temp_path(self)
100116
(tmp_path / "requirements.txt").write_text("numpy\n")
101117
child = tmp_path / "subdir"
102118
child.mkdir()
103-
assert _find_requirements(str(child)) == str(tmp_path / "requirements.txt")
119+
self.assertEqual(
120+
_find_requirements(str(child)),
121+
str(tmp_path / "requirements.txt"),
122+
)
104123

105-
def test_returns_none_when_not_found(self, tmp_path):
124+
def test_returns_none_when_not_found(self):
106125
"""Returns None when no requirements.txt exists in any ancestor."""
126+
tmp_path = _make_temp_path(self)
107127
empty = tmp_path / "empty"
108128
empty.mkdir()
109-
assert _find_requirements(str(empty)) is None
129+
self.assertIsNone(_find_requirements(str(empty)))
110130

111131

112-
class TestExecuteRemote:
132+
class TestExecuteRemote(absltest.TestCase):
113133
def _make_func(self):
114134
def my_train():
115135
return 42
@@ -128,38 +148,44 @@ def _make_ctx(self, container_image=None):
128148
project="proj",
129149
)
130150

131-
def test_success_flow(self, mocker):
132-
mocker.patch("keras_remote.backend.execution._build_container")
133-
mocker.patch("keras_remote.backend.execution._upload_artifacts")
134-
mocker.patch(
135-
"keras_remote.backend.execution._download_result",
136-
return_value={"success": True, "result": 42},
137-
)
138-
mocker.patch(
139-
"keras_remote.backend.execution._cleanup_and_return",
140-
return_value=42,
141-
)
142-
143-
ctx = self._make_ctx()
144-
backend = MagicMock()
145-
146-
result = execute_remote(ctx, backend)
147-
148-
backend.submit_job.assert_called_once_with(ctx)
149-
backend.wait_for_job.assert_called_once()
150-
backend.cleanup_job.assert_called_once()
151-
assert result == 42
152-
153-
def test_cleanup_on_wait_failure(self, mocker):
154-
mocker.patch("keras_remote.backend.execution._build_container")
155-
mocker.patch("keras_remote.backend.execution._upload_artifacts")
156-
157-
ctx = self._make_ctx()
158-
backend = MagicMock()
159-
backend.wait_for_job.side_effect = RuntimeError("job failed")
160-
161-
with pytest.raises(RuntimeError, match="job failed"):
162-
execute_remote(ctx, backend)
163-
164-
# cleanup_job is called in finally block even when wait fails
165-
backend.cleanup_job.assert_called_once()
151+
def test_success_flow(self):
152+
with (
153+
mock.patch("keras_remote.backend.execution._build_container"),
154+
mock.patch("keras_remote.backend.execution._upload_artifacts"),
155+
mock.patch(
156+
"keras_remote.backend.execution._download_result",
157+
return_value={"success": True, "result": 42},
158+
),
159+
mock.patch(
160+
"keras_remote.backend.execution._cleanup_and_return",
161+
return_value=42,
162+
),
163+
):
164+
ctx = self._make_ctx()
165+
backend = MagicMock()
166+
167+
result = execute_remote(ctx, backend)
168+
169+
backend.submit_job.assert_called_once_with(ctx)
170+
backend.wait_for_job.assert_called_once()
171+
backend.cleanup_job.assert_called_once()
172+
self.assertEqual(result, 42)
173+
174+
def test_cleanup_on_wait_failure(self):
175+
with (
176+
mock.patch("keras_remote.backend.execution._build_container"),
177+
mock.patch("keras_remote.backend.execution._upload_artifacts"),
178+
):
179+
ctx = self._make_ctx()
180+
backend = MagicMock()
181+
backend.wait_for_job.side_effect = RuntimeError("job failed")
182+
183+
with self.assertRaisesRegex(RuntimeError, "job failed"):
184+
execute_remote(ctx, backend)
185+
186+
# cleanup_job is called in finally block even when wait fails
187+
backend.cleanup_job.assert_called_once()
188+
189+
190+
if __name__ == "__main__":
191+
absltest.main()

0 commit comments

Comments
 (0)