Skip to content

Commit 05d95c9

Browse files
Adds unit tests for Storage and ContainerBuilder (#34)
* Adds cloud wrapper tests for storage and container builder * Migrate cloud wrapper tests from pytest to absl testing * address reviews * rename test files * rename test files * fix ci
1 parent 7b711a4 commit 05d95c9

File tree

13 files changed

+438
-1
lines changed

13 files changed

+438
-1
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
run: >
2525
coverage run
2626
-m unittest discover
27-
-s keras_remote -p "test_*.py"
27+
-s keras_remote -p "*_test.py"
2828
-v
2929
3030
- name: Generate coverage report
File renamed without changes.
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""
2+
3+
import pathlib
4+
import tempfile
5+
from unittest import mock
6+
from unittest.mock import MagicMock
7+
8+
from absl.testing import absltest, parameterized
9+
from google.api_core import exceptions as google_exceptions
10+
11+
from keras_remote.infra.container_builder import (
12+
_generate_dockerfile,
13+
_hash_requirements,
14+
_image_exists,
15+
get_or_build_container,
16+
)
17+
18+
19+
def _make_temp_path(test_case):
20+
"""Create a temp directory that is cleaned up after the test."""
21+
td = tempfile.TemporaryDirectory()
22+
test_case.addCleanup(td.cleanup)
23+
return pathlib.Path(td.name)
24+
25+
26+
class TestHashRequirements(parameterized.TestCase):
27+
def test_deterministic(self):
28+
tmp_path = _make_temp_path(self)
29+
req = tmp_path / "requirements.txt"
30+
req.write_text("numpy==1.26\n")
31+
32+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
33+
h2 = _hash_requirements(str(req), "l4", "python:3.12-slim")
34+
self.assertEqual(h1, h2)
35+
36+
def test_different_requirements_different_hash(self):
37+
tmp_path = _make_temp_path(self)
38+
req1 = tmp_path / "r1.txt"
39+
req1.write_text("numpy==1.26\n")
40+
req2 = tmp_path / "r2.txt"
41+
req2.write_text("scipy==1.12\n")
42+
43+
h1 = _hash_requirements(str(req1), "l4", "python:3.12-slim")
44+
h2 = _hash_requirements(str(req2), "l4", "python:3.12-slim")
45+
self.assertNotEqual(h1, h2)
46+
47+
def test_different_accelerator_different_hash(self):
48+
tmp_path = _make_temp_path(self)
49+
req = tmp_path / "requirements.txt"
50+
req.write_text("numpy\n")
51+
52+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
53+
h2 = _hash_requirements(str(req), "v3-8", "python:3.12-slim")
54+
self.assertNotEqual(h1, h2)
55+
56+
def test_different_base_image_different_hash(self):
57+
tmp_path = _make_temp_path(self)
58+
req = tmp_path / "requirements.txt"
59+
req.write_text("numpy\n")
60+
61+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
62+
h2 = _hash_requirements(str(req), "l4", "python:3.11-slim")
63+
self.assertNotEqual(h1, h2)
64+
65+
@parameterized.named_parameters(
66+
dict(testcase_name="none", requirements_path=None),
67+
dict(
68+
testcase_name="nonexistent",
69+
requirements_path="/nonexistent/path.txt",
70+
),
71+
)
72+
def test_missing_requirements_valid(self, requirements_path):
73+
h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim")
74+
self.assertIsInstance(h, str)
75+
self.assertLen(h, 64)
76+
77+
def test_returns_hex_string(self):
78+
tmp_path = _make_temp_path(self)
79+
req = tmp_path / "r.txt"
80+
req.write_text("keras\n")
81+
h = _hash_requirements(str(req), "l4", "python:3.12-slim")
82+
self.assertRegex(h, r"^[0-9a-f]{64}$")
83+
84+
85+
class TestGenerateDockerfile(parameterized.TestCase):
86+
@parameterized.named_parameters(
87+
dict(
88+
testcase_name="cpu",
89+
accelerator_type="cpu",
90+
expected=["pip install jax"],
91+
not_expected=["cuda", "tpu"],
92+
),
93+
dict(
94+
testcase_name="gpu",
95+
accelerator_type="l4",
96+
expected=["jax[cuda12]"],
97+
not_expected=[],
98+
),
99+
dict(
100+
testcase_name="tpu",
101+
accelerator_type="v3-8",
102+
expected=["jax[tpu]", "libtpu_releases"],
103+
not_expected=[],
104+
),
105+
)
106+
def test_jax_install(self, accelerator_type, expected, not_expected):
107+
content = _generate_dockerfile(
108+
base_image="python:3.12-slim",
109+
requirements_path=None,
110+
accelerator_type=accelerator_type,
111+
)
112+
for s in expected:
113+
self.assertIn(s, content)
114+
for s in not_expected:
115+
self.assertNotIn(s, content)
116+
117+
def test_with_requirements(self):
118+
tmp_path = _make_temp_path(self)
119+
req = tmp_path / "requirements.txt"
120+
req.write_text("numpy\n")
121+
122+
content = _generate_dockerfile(
123+
base_image="python:3.12-slim",
124+
requirements_path=str(req),
125+
accelerator_type="cpu",
126+
)
127+
self.assertIn("COPY requirements.txt", content)
128+
self.assertIn("pip install -r", content)
129+
130+
def test_without_requirements(self):
131+
content = _generate_dockerfile(
132+
base_image="python:3.12-slim",
133+
requirements_path=None,
134+
accelerator_type="cpu",
135+
)
136+
self.assertNotIn("COPY requirements.txt", content)
137+
138+
@parameterized.named_parameters(
139+
dict(
140+
testcase_name="remote_runner_copy",
141+
expected_substring="COPY remote_runner.py /app/remote_runner.py",
142+
),
143+
dict(
144+
testcase_name="keras_backend_env",
145+
expected_substring="ENV KERAS_BACKEND=jax",
146+
),
147+
)
148+
def test_contains_expected_content(self, expected_substring):
149+
content = _generate_dockerfile(
150+
base_image="python:3.12-slim",
151+
requirements_path=None,
152+
accelerator_type="cpu",
153+
)
154+
self.assertIn(expected_substring, content)
155+
156+
def test_uses_base_image(self):
157+
content = _generate_dockerfile(
158+
base_image="python:3.11-bullseye",
159+
requirements_path=None,
160+
accelerator_type="cpu",
161+
)
162+
self.assertIn("FROM python:3.11-bullseye", content)
163+
164+
165+
class TestImageExists(parameterized.TestCase):
166+
def test_returns_true_when_tag_found(self):
167+
mock_client = MagicMock()
168+
with mock.patch(
169+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
170+
return_value=mock_client,
171+
):
172+
result = _image_exists(
173+
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
174+
"my-proj",
175+
)
176+
self.assertTrue(result)
177+
mock_client.get_tag.assert_called_once()
178+
179+
@parameterized.named_parameters(
180+
dict(
181+
testcase_name="not_found",
182+
side_effect=google_exceptions.NotFound("nope"),
183+
),
184+
dict(
185+
testcase_name="other_error",
186+
side_effect=RuntimeError("unexpected"),
187+
),
188+
)
189+
def test_returns_false_on_error(self, side_effect):
190+
mock_client = MagicMock()
191+
mock_client.get_tag.side_effect = side_effect
192+
with mock.patch(
193+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
194+
return_value=mock_client,
195+
):
196+
result = _image_exists(
197+
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
198+
"my-proj",
199+
)
200+
self.assertFalse(result)
201+
202+
def test_correct_resource_name(self):
203+
mock_client = MagicMock()
204+
with mock.patch(
205+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
206+
return_value=mock_client,
207+
):
208+
_image_exists(
209+
"us-docker.pkg.dev/my-proj/keras-remote/base:v3-8-abc123def456",
210+
"my-proj",
211+
)
212+
call_args = mock_client.get_tag.call_args
213+
request = call_args.kwargs["request"]
214+
self.assertEqual(
215+
request.name,
216+
"projects/my-proj/locations/us"
217+
"/repositories/keras-remote"
218+
"/packages/base/tags/v3-8-abc123def456",
219+
)
220+
221+
222+
class TestGetOrBuildContainer(absltest.TestCase):
223+
def test_returns_cached_when_image_exists(self):
224+
with (
225+
mock.patch(
226+
"keras_remote.infra.container_builder._image_exists",
227+
return_value=True,
228+
),
229+
mock.patch(
230+
"keras_remote.infra.container_builder._build_and_push",
231+
) as mock_build,
232+
):
233+
result = get_or_build_container(
234+
base_image="python:3.12-slim",
235+
requirements_path=None,
236+
accelerator_type="l4",
237+
project="test-proj",
238+
zone="us-central1-a",
239+
)
240+
241+
mock_build.assert_not_called()
242+
self.assertIn("us-docker.pkg.dev/test-proj/keras-remote/base:", result)
243+
244+
def test_builds_when_image_missing(self):
245+
with (
246+
mock.patch(
247+
"keras_remote.infra.container_builder._image_exists",
248+
return_value=False,
249+
),
250+
mock.patch(
251+
"keras_remote.infra.container_builder._build_and_push",
252+
return_value="us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb",
253+
) as mock_build,
254+
):
255+
result = get_or_build_container(
256+
base_image="python:3.12-slim",
257+
requirements_path=None,
258+
accelerator_type="l4",
259+
project="proj",
260+
zone="us-central1-a",
261+
)
262+
263+
mock_build.assert_called_once()
264+
self.assertEqual(
265+
result, "us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb"
266+
)
267+
268+
def _get_image_uri(self, accelerator_type, project, zone):
269+
with mock.patch(
270+
"keras_remote.infra.container_builder._image_exists",
271+
return_value=True,
272+
):
273+
return get_or_build_container(
274+
base_image="python:3.12-slim",
275+
requirements_path=None,
276+
accelerator_type=accelerator_type,
277+
project=project,
278+
zone=zone,
279+
)
280+
281+
def test_image_uri_format_tpu_europe(self):
282+
result = self._get_image_uri("v3-8", "my-proj", "europe-west4-b")
283+
284+
self.assertTrue(
285+
result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:")
286+
)
287+
tag = result.split(":")[-1]
288+
self.assertRegex(tag, r"^v3-8-[0-9a-f]{12}$")
289+
290+
def test_image_uri_format_gpu_us(self):
291+
result = self._get_image_uri("a100-80gb", "proj", "us-central1-a")
292+
293+
self.assertTrue(
294+
result.startswith("us-docker.pkg.dev/proj/keras-remote/base:")
295+
)
296+
tag = result.split(":")[-1]
297+
self.assertRegex(tag, r"^a100-80gb-[0-9a-f]{12}$")
298+
299+
300+
if __name__ == "__main__":
301+
absltest.main()

0 commit comments

Comments
 (0)