Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
run: >
coverage run
-m unittest discover
-s keras_remote -p "test_*.py"
-s keras_remote -p "*_test.py"
-v

- name: Generate coverage report
Expand Down
File renamed without changes.
File renamed without changes.
301 changes: 301 additions & 0 deletions keras_remote/infra/container_builder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""

import pathlib
import tempfile
from unittest import mock
from unittest.mock import MagicMock

from absl.testing import absltest, parameterized
from google.api_core import exceptions as google_exceptions

from keras_remote.infra.container_builder import (
_generate_dockerfile,
_hash_requirements,
_image_exists,
get_or_build_container,
)


def _make_temp_path(test_case):
"""Create a temp directory that is cleaned up after the test."""
td = tempfile.TemporaryDirectory()
test_case.addCleanup(td.cleanup)
return pathlib.Path(td.name)


class TestHashRequirements(parameterized.TestCase):
def test_deterministic(self):
tmp_path = _make_temp_path(self)
req = tmp_path / "requirements.txt"
req.write_text("numpy==1.26\n")

h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
h2 = _hash_requirements(str(req), "l4", "python:3.12-slim")
self.assertEqual(h1, h2)

def test_different_requirements_different_hash(self):
tmp_path = _make_temp_path(self)
req1 = tmp_path / "r1.txt"
req1.write_text("numpy==1.26\n")
req2 = tmp_path / "r2.txt"
req2.write_text("scipy==1.12\n")

h1 = _hash_requirements(str(req1), "l4", "python:3.12-slim")
h2 = _hash_requirements(str(req2), "l4", "python:3.12-slim")
self.assertNotEqual(h1, h2)

def test_different_accelerator_different_hash(self):
tmp_path = _make_temp_path(self)
req = tmp_path / "requirements.txt"
req.write_text("numpy\n")

h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
h2 = _hash_requirements(str(req), "v3-8", "python:3.12-slim")
self.assertNotEqual(h1, h2)

def test_different_base_image_different_hash(self):
tmp_path = _make_temp_path(self)
req = tmp_path / "requirements.txt"
req.write_text("numpy\n")

h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
h2 = _hash_requirements(str(req), "l4", "python:3.11-slim")
self.assertNotEqual(h1, h2)

@parameterized.named_parameters(
dict(testcase_name="none", requirements_path=None),
dict(
testcase_name="nonexistent",
requirements_path="/nonexistent/path.txt",
),
)
def test_missing_requirements_valid(self, requirements_path):
h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim")
self.assertIsInstance(h, str)
self.assertLen(h, 64)

def test_returns_hex_string(self):
tmp_path = _make_temp_path(self)
req = tmp_path / "r.txt"
req.write_text("keras\n")
h = _hash_requirements(str(req), "l4", "python:3.12-slim")
self.assertRegex(h, r"^[0-9a-f]{64}$")


class TestGenerateDockerfile(parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name="cpu",
accelerator_type="cpu",
expected=["pip install jax"],
not_expected=["cuda", "tpu"],
),
dict(
testcase_name="gpu",
accelerator_type="l4",
expected=["jax[cuda12]"],
not_expected=[],
),
dict(
testcase_name="tpu",
accelerator_type="v3-8",
expected=["jax[tpu]", "libtpu_releases"],
not_expected=[],
),
)
def test_jax_install(self, accelerator_type, expected, not_expected):
content = _generate_dockerfile(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type=accelerator_type,
)
for s in expected:
self.assertIn(s, content)
for s in not_expected:
self.assertNotIn(s, content)

def test_with_requirements(self):
tmp_path = _make_temp_path(self)
req = tmp_path / "requirements.txt"
req.write_text("numpy\n")

content = _generate_dockerfile(
base_image="python:3.12-slim",
requirements_path=str(req),
accelerator_type="cpu",
)
self.assertIn("COPY requirements.txt", content)
self.assertIn("pip install -r", content)

def test_without_requirements(self):
content = _generate_dockerfile(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type="cpu",
)
self.assertNotIn("COPY requirements.txt", content)

@parameterized.named_parameters(
dict(
testcase_name="remote_runner_copy",
expected_substring="COPY remote_runner.py /app/remote_runner.py",
),
dict(
testcase_name="keras_backend_env",
expected_substring="ENV KERAS_BACKEND=jax",
),
)
def test_contains_expected_content(self, expected_substring):
content = _generate_dockerfile(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type="cpu",
)
self.assertIn(expected_substring, content)

def test_uses_base_image(self):
content = _generate_dockerfile(
base_image="python:3.11-bullseye",
requirements_path=None,
accelerator_type="cpu",
)
self.assertIn("FROM python:3.11-bullseye", content)


class TestImageExists(parameterized.TestCase):
def test_returns_true_when_tag_found(self):
mock_client = MagicMock()
with mock.patch(
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
return_value=mock_client,
):
result = _image_exists(
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
"my-proj",
)
self.assertTrue(result)
mock_client.get_tag.assert_called_once()

@parameterized.named_parameters(
dict(
testcase_name="not_found",
side_effect=google_exceptions.NotFound("nope"),
),
dict(
testcase_name="other_error",
side_effect=RuntimeError("unexpected"),
),
)
def test_returns_false_on_error(self, side_effect):
mock_client = MagicMock()
mock_client.get_tag.side_effect = side_effect
with mock.patch(
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
return_value=mock_client,
):
result = _image_exists(
"us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123",
"my-proj",
)
self.assertFalse(result)

def test_correct_resource_name(self):
mock_client = MagicMock()
with mock.patch(
"keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient",
return_value=mock_client,
):
_image_exists(
"us-docker.pkg.dev/my-proj/keras-remote/base:v3-8-abc123def456",
"my-proj",
)
call_args = mock_client.get_tag.call_args
request = call_args.kwargs["request"]
self.assertEqual(
request.name,
"projects/my-proj/locations/us"
"/repositories/keras-remote"
"/packages/base/tags/v3-8-abc123def456",
)


class TestGetOrBuildContainer(absltest.TestCase):
def test_returns_cached_when_image_exists(self):
with (
mock.patch(
"keras_remote.infra.container_builder._image_exists",
return_value=True,
),
mock.patch(
"keras_remote.infra.container_builder._build_and_push",
) as mock_build,
):
result = get_or_build_container(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type="l4",
project="test-proj",
zone="us-central1-a",
)

mock_build.assert_not_called()
self.assertIn("us-docker.pkg.dev/test-proj/keras-remote/base:", result)

def test_builds_when_image_missing(self):
with (
mock.patch(
"keras_remote.infra.container_builder._image_exists",
return_value=False,
),
mock.patch(
"keras_remote.infra.container_builder._build_and_push",
return_value="us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb",
) as mock_build,
):
result = get_or_build_container(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type="l4",
project="proj",
zone="us-central1-a",
)

mock_build.assert_called_once()
self.assertEqual(
result, "us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb"
)

def _get_image_uri(self, accelerator_type, project, zone):
with mock.patch(
"keras_remote.infra.container_builder._image_exists",
return_value=True,
):
return get_or_build_container(
base_image="python:3.12-slim",
requirements_path=None,
accelerator_type=accelerator_type,
project=project,
zone=zone,
)

def test_image_uri_format_tpu_europe(self):
result = self._get_image_uri("v3-8", "my-proj", "europe-west4-b")

self.assertTrue(
result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:")
)
tag = result.split(":")[-1]
self.assertRegex(tag, r"^v3-8-[0-9a-f]{12}$")

def test_image_uri_format_gpu_us(self):
result = self._get_image_uri("a100-80gb", "proj", "us-central1-a")

self.assertTrue(
result.startswith("us-docker.pkg.dev/proj/keras-remote/base:")
)
tag = result.split(":")[-1]
self.assertRegex(tag, r"^a100-80gb-[0-9a-f]{12}$")


if __name__ == "__main__":
absltest.main()
Loading
Loading