Skip to content

Commit 6ef9fce

Browse files
Adds cloud wrapper tests for storage and container builder
1 parent 67477a8 commit 6ef9fce

File tree

2 files changed

+374
-0
lines changed

2 files changed

+374
-0
lines changed
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""
2+
3+
import re
4+
from unittest.mock import MagicMock
5+
6+
import pytest
7+
from google.api_core import exceptions as google_exceptions
8+
9+
from keras_remote.infra.container_builder import (
10+
_generate_dockerfile,
11+
_hash_requirements,
12+
_image_exists,
13+
get_or_build_container,
14+
)
15+
16+
17+
class TestHashRequirements:
18+
def test_deterministic(self, tmp_path):
19+
req = tmp_path / "requirements.txt"
20+
req.write_text("numpy==1.26\n")
21+
22+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
23+
h2 = _hash_requirements(str(req), "l4", "python:3.12-slim")
24+
assert h1 == h2
25+
26+
def test_different_requirements_different_hash(self, tmp_path):
27+
req1 = tmp_path / "r1.txt"
28+
req1.write_text("numpy==1.26\n")
29+
req2 = tmp_path / "r2.txt"
30+
req2.write_text("scipy==1.12\n")
31+
32+
h1 = _hash_requirements(str(req1), "l4", "python:3.12-slim")
33+
h2 = _hash_requirements(str(req2), "l4", "python:3.12-slim")
34+
assert h1 != h2
35+
36+
def test_different_accelerator_different_hash(self, tmp_path):
37+
req = tmp_path / "requirements.txt"
38+
req.write_text("numpy\n")
39+
40+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
41+
h2 = _hash_requirements(str(req), "v3-8", "python:3.12-slim")
42+
assert h1 != h2
43+
44+
def test_different_base_image_different_hash(self, tmp_path):
45+
req = tmp_path / "requirements.txt"
46+
req.write_text("numpy\n")
47+
48+
h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
49+
h2 = _hash_requirements(str(req), "l4", "python:3.11-slim")
50+
assert h1 != h2
51+
52+
@pytest.mark.parametrize(
53+
"requirements_path",
54+
[None, "/nonexistent/path.txt"],
55+
ids=["none", "nonexistent"],
56+
)
57+
def test_missing_requirements_valid(self, requirements_path):
58+
h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim")
59+
assert isinstance(h, str)
60+
assert len(h) == 64
61+
62+
def test_returns_hex_string(self, tmp_path):
63+
req = tmp_path / "r.txt"
64+
req.write_text("keras\n")
65+
h = _hash_requirements(str(req), "l4", "python:3.12-slim")
66+
assert re.fullmatch(r"[0-9a-f]{64}", h)
67+
68+
69+
class TestGenerateDockerfile:
70+
@pytest.mark.parametrize(
71+
("accelerator_type", "expected", "not_expected"),
72+
[
73+
pytest.param("cpu", ["pip install jax"], ["cuda", "tpu"], id="cpu"),
74+
pytest.param("l4", ["jax[cuda12]"], [], id="gpu"),
75+
pytest.param("v3-8", ["jax[tpu]", "libtpu_releases"], [], id="tpu"),
76+
],
77+
)
78+
def test_jax_install(self, accelerator_type, expected, not_expected):
79+
content = _generate_dockerfile(
80+
base_image="python:3.12-slim",
81+
requirements_path=None,
82+
accelerator_type=accelerator_type,
83+
)
84+
for s in expected:
85+
assert s in content
86+
for s in not_expected:
87+
assert s not in content
88+
89+
def test_with_requirements(self, tmp_path):
90+
req = tmp_path / "requirements.txt"
91+
req.write_text("numpy\n")
92+
93+
content = _generate_dockerfile(
94+
base_image="python:3.12-slim",
95+
requirements_path=str(req),
96+
accelerator_type="cpu",
97+
)
98+
assert "COPY requirements.txt" in content
99+
assert "pip install -r" in content
100+
101+
def test_without_requirements(self):
102+
content = _generate_dockerfile(
103+
base_image="python:3.12-slim",
104+
requirements_path=None,
105+
accelerator_type="cpu",
106+
)
107+
assert "COPY requirements.txt" not in content
108+
109+
@pytest.mark.parametrize(
110+
"expected_substring",
111+
[
112+
pytest.param(
113+
"COPY remote_runner.py /app/remote_runner.py",
114+
id="remote_runner_copy",
115+
),
116+
pytest.param("ENV KERAS_BACKEND=jax", id="keras_backend_env"),
117+
],
118+
)
119+
def test_contains_expected_content(self, expected_substring):
120+
content = _generate_dockerfile(
121+
base_image="python:3.12-slim",
122+
requirements_path=None,
123+
accelerator_type="cpu",
124+
)
125+
assert expected_substring in content
126+
127+
def test_uses_base_image(self):
128+
content = _generate_dockerfile(
129+
base_image="python:3.11-bullseye",
130+
requirements_path=None,
131+
accelerator_type="cpu",
132+
)
133+
assert "FROM python:3.11-bullseye" in content
134+
135+
136+
class TestImageExists:
137+
def test_returns_true_when_tag_found(self, mocker):
138+
mock_client = MagicMock()
139+
mocker.patch(
140+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
141+
return_value=mock_client,
142+
)
143+
result = _image_exists(
144+
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
145+
"my-proj",
146+
)
147+
assert result is True
148+
mock_client.get_tag.assert_called_once()
149+
150+
@pytest.mark.parametrize(
151+
"side_effect",
152+
[
153+
pytest.param(google_exceptions.NotFound("nope"), id="not_found"),
154+
pytest.param(RuntimeError("unexpected"), id="other_error"),
155+
],
156+
)
157+
def test_returns_false_on_error(self, mocker, side_effect):
158+
mock_client = MagicMock()
159+
mock_client.get_tag.side_effect = side_effect
160+
mocker.patch(
161+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
162+
return_value=mock_client,
163+
)
164+
result = _image_exists(
165+
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
166+
"my-proj",
167+
)
168+
assert result is False
169+
170+
def test_correct_resource_name(self, mocker):
171+
mock_client = MagicMock()
172+
mocker.patch(
173+
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
174+
return_value=mock_client,
175+
)
176+
_image_exists(
177+
"us-docker.pkg.dev/my-proj/keras-remote/base:v3-8-abc123def456",
178+
"my-proj",
179+
)
180+
call_args = mock_client.get_tag.call_args
181+
request = call_args.kwargs.get("request") or call_args[1].get("request")
182+
assert request.name == (
183+
"projects/my-proj/locations/us"
184+
"/repositories/keras-remote"
185+
"/packages/base/tags/v3-8-abc123def456"
186+
)
187+
188+
189+
class TestGetOrBuildContainer:
190+
def test_returns_cached_when_image_exists(self, mocker):
191+
mocker.patch(
192+
"keras_remote.infra.container_builder._image_exists",
193+
return_value=True,
194+
)
195+
mock_build = mocker.patch(
196+
"keras_remote.infra.container_builder._build_and_push",
197+
)
198+
199+
result = get_or_build_container(
200+
base_image="python:3.12-slim",
201+
requirements_path=None,
202+
accelerator_type="l4",
203+
project="test-proj",
204+
zone="us-central1-a",
205+
)
206+
207+
mock_build.assert_not_called()
208+
assert "us-docker.pkg.dev/test-proj/keras-remote/base:" in result
209+
210+
def test_builds_when_image_missing(self, mocker):
211+
mocker.patch(
212+
"keras_remote.infra.container_builder._image_exists",
213+
return_value=False,
214+
)
215+
mock_build = mocker.patch(
216+
"keras_remote.infra.container_builder._build_and_push",
217+
return_value="us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb",
218+
)
219+
220+
result = get_or_build_container(
221+
base_image="python:3.12-slim",
222+
requirements_path=None,
223+
accelerator_type="l4",
224+
project="proj",
225+
zone="us-central1-a",
226+
)
227+
228+
mock_build.assert_called_once()
229+
assert result == "us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb"
230+
231+
def _get_image_uri(self, mocker, accelerator_type, project, zone):
232+
mocker.patch(
233+
"keras_remote.infra.container_builder._image_exists",
234+
return_value=True,
235+
)
236+
return get_or_build_container(
237+
base_image="python:3.12-slim",
238+
requirements_path=None,
239+
accelerator_type=accelerator_type,
240+
project=project,
241+
zone=zone,
242+
)
243+
244+
def test_image_uri_format_tpu_europe(self, mocker):
245+
result = self._get_image_uri(mocker, "v3-8", "my-proj", "europe-west4-b")
246+
247+
assert result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:")
248+
tag = result.split(":")[-1]
249+
assert re.fullmatch(r"v3-8-[0-9a-f]{12}", tag)
250+
251+
def test_image_uri_format_gpu_us(self, mocker):
252+
result = self._get_image_uri(mocker, "a100-80gb", "proj", "us-central1-a")
253+
254+
assert result.startswith("us-docker.pkg.dev/proj/keras-remote/base:")
255+
tag = result.split(":")[-1]
256+
assert re.fullmatch(r"a100-80gb-[0-9a-f]{12}", tag)

keras_remote/utils/test_storage.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Tests for keras_remote.utils.storage — Cloud Storage operations."""
2+
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from keras_remote.utils.storage import (
8+
_get_project,
9+
cleanup_artifacts,
10+
download_result,
11+
upload_artifacts,
12+
)
13+
14+
15+
@pytest.fixture
16+
def mock_gcs(mocker):
17+
"""Mock google.cloud.storage.Client at the import site."""
18+
mock_client = MagicMock()
19+
mocker.patch(
20+
"keras_remote.utils.storage.storage.Client",
21+
return_value=mock_client,
22+
)
23+
return mock_client
24+
25+
26+
class TestUploadArtifacts:
27+
def test_uploads_payload_and_context(self, mock_gcs):
28+
mock_bucket = mock_gcs.bucket.return_value
29+
mock_blob = mock_bucket.blob.return_value
30+
31+
upload_artifacts(
32+
bucket_name="my-bucket",
33+
job_id="job-abc123",
34+
payload_path="/tmp/payload.pkl",
35+
context_path="/tmp/context.zip",
36+
project="test-project",
37+
)
38+
39+
mock_bucket.blob.assert_any_call("job-abc123/payload.pkl")
40+
mock_bucket.blob.assert_any_call("job-abc123/context.zip")
41+
assert mock_blob.upload_from_filename.call_count == 2
42+
43+
def test_uses_correct_bucket(self, mock_gcs):
44+
upload_artifacts(
45+
bucket_name="my-custom-bucket",
46+
job_id="job-123",
47+
payload_path="/tmp/p.pkl",
48+
context_path="/tmp/c.zip",
49+
project="proj",
50+
)
51+
mock_gcs.bucket.assert_called_with("my-custom-bucket")
52+
53+
54+
class TestDownloadResult:
55+
def test_downloads_result_blob(self, mock_gcs):
56+
mock_bucket = mock_gcs.bucket.return_value
57+
mock_blob = mock_bucket.blob.return_value
58+
59+
download_result("my-bucket", "job-abc", project="proj")
60+
61+
mock_bucket.blob.assert_called_once_with("job-abc/result.pkl")
62+
mock_blob.download_to_filename.assert_called_once()
63+
64+
def test_returns_path_with_job_id(self, mock_gcs):
65+
result = download_result("my-bucket", "job-xyz", project="proj")
66+
assert "result-job-xyz.pkl" in result
67+
68+
69+
class TestCleanupArtifacts:
70+
def test_deletes_all_blobs(self, mock_gcs):
71+
mock_bucket = mock_gcs.bucket.return_value
72+
blob1 = MagicMock()
73+
blob2 = MagicMock()
74+
blob3 = MagicMock()
75+
mock_bucket.list_blobs.return_value = [blob1, blob2, blob3]
76+
77+
cleanup_artifacts("my-bucket", "job-abc", project="proj")
78+
79+
mock_bucket.list_blobs.assert_called_once_with(prefix="job-abc/")
80+
blob1.delete.assert_called_once()
81+
blob2.delete.assert_called_once()
82+
blob3.delete.assert_called_once()
83+
84+
def test_no_blobs_no_error(self, mock_gcs):
85+
mock_bucket = mock_gcs.bucket.return_value
86+
mock_bucket.list_blobs.return_value = []
87+
88+
cleanup_artifacts("my-bucket", "job-abc", project="proj")
89+
90+
mock_bucket.list_blobs.assert_called_once_with(prefix="job-abc/")
91+
92+
93+
class TestGetProject:
94+
@pytest.mark.parametrize(
95+
("kr_project", "gc_project", "expected"),
96+
[
97+
# Only KERAS_REMOTE_PROJECT set: use it directly
98+
("kr-proj", None, "kr-proj"),
99+
# Only GOOGLE_CLOUD_PROJECT set: fall back to it
100+
(None, "gc-proj", "gc-proj"),
101+
# Neither set: no project resolved
102+
(None, None, None),
103+
# Both set: KERAS_REMOTE_PROJECT takes precedence
104+
("kr-proj", "gc-proj", "kr-proj"),
105+
],
106+
)
107+
def test_resolves_project(
108+
self, monkeypatch, kr_project, gc_project, expected
109+
):
110+
if kr_project:
111+
monkeypatch.setenv("KERAS_REMOTE_PROJECT", kr_project)
112+
else:
113+
monkeypatch.delenv("KERAS_REMOTE_PROJECT", raising=False)
114+
if gc_project:
115+
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", gc_project)
116+
else:
117+
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
118+
assert _get_project() == expected

0 commit comments

Comments
 (0)