diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2e554a5..d851152 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -24,7 +24,7 @@ jobs: run: > coverage run -m unittest discover - -s keras_remote -p "test_*.py" + -s keras_remote -p "*_test.py" -v - name: Generate coverage report diff --git a/keras_remote/backend/test_execution.py b/keras_remote/backend/execution_test.py similarity index 100% rename from keras_remote/backend/test_execution.py rename to keras_remote/backend/execution_test.py diff --git a/keras_remote/backend/test_gke_client.py b/keras_remote/backend/gke_client_test.py similarity index 100% rename from keras_remote/backend/test_gke_client.py rename to keras_remote/backend/gke_client_test.py diff --git a/keras_remote/cli/test_prerequisites_check.py b/keras_remote/cli/prerequisites_check_test.py similarity index 100% rename from keras_remote/cli/test_prerequisites_check.py rename to keras_remote/cli/prerequisites_check_test.py diff --git a/keras_remote/test_constants.py b/keras_remote/constants_test.py similarity index 100% rename from keras_remote/test_constants.py rename to keras_remote/constants_test.py diff --git a/keras_remote/core/test_accelerators.py b/keras_remote/core/accelerators_test.py similarity index 100% rename from keras_remote/core/test_accelerators.py rename to keras_remote/core/accelerators_test.py diff --git a/keras_remote/core/test_core.py b/keras_remote/core/core_test.py similarity index 100% rename from keras_remote/core/test_core.py rename to keras_remote/core/core_test.py diff --git a/keras_remote/infra/container_builder_test.py b/keras_remote/infra/container_builder_test.py new file mode 100644 index 0000000..abcc019 --- /dev/null +++ b/keras_remote/infra/container_builder_test.py @@ -0,0 +1,301 @@ +"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching.""" + +import pathlib +import tempfile +from unittest import mock +from unittest.mock import MagicMock + +from absl.testing import absltest, parameterized +from google.api_core import exceptions as google_exceptions + +from keras_remote.infra.container_builder import ( + _generate_dockerfile, + _hash_requirements, + _image_exists, + get_or_build_container, +) + + +def _make_temp_path(test_case): + """Create a temp directory that is cleaned up after the test.""" + td = tempfile.TemporaryDirectory() + test_case.addCleanup(td.cleanup) + return pathlib.Path(td.name) + + +class TestHashRequirements(parameterized.TestCase): + def test_deterministic(self): + tmp_path = _make_temp_path(self) + req = tmp_path / "requirements.txt" + req.write_text("numpy==1.26\n") + + h1 = _hash_requirements(str(req), "l4", "python:3.12-slim") + h2 = _hash_requirements(str(req), "l4", "python:3.12-slim") + self.assertEqual(h1, h2) + + def test_different_requirements_different_hash(self): + tmp_path = _make_temp_path(self) + req1 = tmp_path / "r1.txt" + req1.write_text("numpy==1.26\n") + req2 = tmp_path / "r2.txt" + req2.write_text("scipy==1.12\n") + + h1 = _hash_requirements(str(req1), "l4", "python:3.12-slim") + h2 = _hash_requirements(str(req2), "l4", "python:3.12-slim") + self.assertNotEqual(h1, h2) + + def test_different_accelerator_different_hash(self): + tmp_path = _make_temp_path(self) + req = tmp_path / "requirements.txt" + req.write_text("numpy\n") + + h1 = _hash_requirements(str(req), "l4", "python:3.12-slim") + h2 = _hash_requirements(str(req), "v3-8", "python:3.12-slim") + self.assertNotEqual(h1, h2) + + def test_different_base_image_different_hash(self): + tmp_path = _make_temp_path(self) + req = tmp_path / "requirements.txt" + req.write_text("numpy\n") + + h1 = _hash_requirements(str(req), "l4", "python:3.12-slim") + h2 = _hash_requirements(str(req), "l4", "python:3.11-slim") + self.assertNotEqual(h1, h2) + + @parameterized.named_parameters( + dict(testcase_name="none", requirements_path=None), + dict( + testcase_name="nonexistent", + requirements_path="/nonexistent/path.txt", + ), + ) + def test_missing_requirements_valid(self, requirements_path): + h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim") + self.assertIsInstance(h, str) + self.assertLen(h, 64) + + def test_returns_hex_string(self): + tmp_path = _make_temp_path(self) + req = tmp_path / "r.txt" + req.write_text("keras\n") + h = _hash_requirements(str(req), "l4", "python:3.12-slim") + self.assertRegex(h, r"^[0-9a-f]{64}$") + + +class TestGenerateDockerfile(parameterized.TestCase): + @parameterized.named_parameters( + dict( + testcase_name="cpu", + accelerator_type="cpu", + expected=["pip install jax"], + not_expected=["cuda", "tpu"], + ), + dict( + testcase_name="gpu", + accelerator_type="l4", + expected=["jax[cuda12]"], + not_expected=[], + ), + dict( + testcase_name="tpu", + accelerator_type="v3-8", + expected=["jax[tpu]", "libtpu_releases"], + not_expected=[], + ), + ) + def test_jax_install(self, accelerator_type, expected, not_expected): + content = _generate_dockerfile( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type=accelerator_type, + ) + for s in expected: + self.assertIn(s, content) + for s in not_expected: + self.assertNotIn(s, content) + + def test_with_requirements(self): + tmp_path = _make_temp_path(self) + req = tmp_path / "requirements.txt" + req.write_text("numpy\n") + + content = _generate_dockerfile( + base_image="python:3.12-slim", + requirements_path=str(req), + accelerator_type="cpu", + ) + self.assertIn("COPY requirements.txt", content) + self.assertIn("pip install -r", content) + + def test_without_requirements(self): + content = _generate_dockerfile( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type="cpu", + ) + self.assertNotIn("COPY requirements.txt", content) + + @parameterized.named_parameters( + dict( + testcase_name="remote_runner_copy", + expected_substring="COPY remote_runner.py /app/remote_runner.py", + ), + dict( + testcase_name="keras_backend_env", + expected_substring="ENV KERAS_BACKEND=jax", + ), + ) + def test_contains_expected_content(self, expected_substring): + content = _generate_dockerfile( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type="cpu", + ) + self.assertIn(expected_substring, content) + + def test_uses_base_image(self): + content = _generate_dockerfile( + base_image="python:3.11-bullseye", + requirements_path=None, + accelerator_type="cpu", + ) + self.assertIn("FROM python:3.11-bullseye", content) + + +class TestImageExists(parameterized.TestCase): + def test_returns_true_when_tag_found(self): + mock_client = MagicMock() + with mock.patch( + "keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient", + return_value=mock_client, + ): + result = _image_exists( + "us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123", + "my-proj", + ) + self.assertTrue(result) + mock_client.get_tag.assert_called_once() + + @parameterized.named_parameters( + dict( + testcase_name="not_found", + side_effect=google_exceptions.NotFound("nope"), + ), + dict( + testcase_name="other_error", + side_effect=RuntimeError("unexpected"), + ), + ) + def test_returns_false_on_error(self, side_effect): + mock_client = MagicMock() + mock_client.get_tag.side_effect = side_effect + with mock.patch( + "keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient", + return_value=mock_client, + ): + result = _image_exists( + "us-docker.pkg.dev/my-proj/keras-remote/base:l4-abc123", + "my-proj", + ) + self.assertFalse(result) + + def test_correct_resource_name(self): + mock_client = MagicMock() + with mock.patch( + "keras_remote.infra.container_builder.artifactregistry_v1.ArtifactRegistryClient", + return_value=mock_client, + ): + _image_exists( + "us-docker.pkg.dev/my-proj/keras-remote/base:v3-8-abc123def456", + "my-proj", + ) + call_args = mock_client.get_tag.call_args + request = call_args.kwargs["request"] + self.assertEqual( + request.name, + "projects/my-proj/locations/us" + "/repositories/keras-remote" + "/packages/base/tags/v3-8-abc123def456", + ) + + +class TestGetOrBuildContainer(absltest.TestCase): + def test_returns_cached_when_image_exists(self): + with ( + mock.patch( + "keras_remote.infra.container_builder._image_exists", + return_value=True, + ), + mock.patch( + "keras_remote.infra.container_builder._build_and_push", + ) as mock_build, + ): + result = get_or_build_container( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type="l4", + project="test-proj", + zone="us-central1-a", + ) + + mock_build.assert_not_called() + self.assertIn("us-docker.pkg.dev/test-proj/keras-remote/base:", result) + + def test_builds_when_image_missing(self): + with ( + mock.patch( + "keras_remote.infra.container_builder._image_exists", + return_value=False, + ), + mock.patch( + "keras_remote.infra.container_builder._build_and_push", + return_value="us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb", + ) as mock_build, + ): + result = get_or_build_container( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type="l4", + project="proj", + zone="us-central1-a", + ) + + mock_build.assert_called_once() + self.assertEqual( + result, "us-docker.pkg.dev/proj/keras-remote/base:l4-bbbbbbbbbbbb" + ) + + def _get_image_uri(self, accelerator_type, project, zone): + with mock.patch( + "keras_remote.infra.container_builder._image_exists", + return_value=True, + ): + return get_or_build_container( + base_image="python:3.12-slim", + requirements_path=None, + accelerator_type=accelerator_type, + project=project, + zone=zone, + ) + + def test_image_uri_format_tpu_europe(self): + result = self._get_image_uri("v3-8", "my-proj", "europe-west4-b") + + self.assertTrue( + result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:") + ) + tag = result.split(":")[-1] + self.assertRegex(tag, r"^v3-8-[0-9a-f]{12}$") + + def test_image_uri_format_gpu_us(self): + result = self._get_image_uri("a100-80gb", "proj", "us-central1-a") + + self.assertTrue( + result.startswith("us-docker.pkg.dev/proj/keras-remote/base:") + ) + tag = result.split(":")[-1] + self.assertRegex(tag, r"^a100-80gb-[0-9a-f]{12}$") + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_remote/runner/test_remote_runner.py b/keras_remote/runner/remote_runner_test.py similarity index 100% rename from keras_remote/runner/test_remote_runner.py rename to keras_remote/runner/remote_runner_test.py diff --git a/keras_remote/utils/test_packager.py b/keras_remote/utils/packager_test.py similarity index 100% rename from keras_remote/utils/test_packager.py rename to keras_remote/utils/packager_test.py diff --git a/keras_remote/utils/storage_test.py b/keras_remote/utils/storage_test.py new file mode 100644 index 0000000..b31e623 --- /dev/null +++ b/keras_remote/utils/storage_test.py @@ -0,0 +1,136 @@ +"""Tests for keras_remote.utils.storage — Cloud Storage operations.""" + +import os +from unittest import mock +from unittest.mock import MagicMock + +from absl.testing import absltest, parameterized + +from keras_remote.utils.storage import ( + _get_project, + cleanup_artifacts, + download_result, + upload_artifacts, +) + + +class _GcsTestBase(absltest.TestCase): + """Base class that provides a mocked GCS client.""" + + def setUp(self): + super().setUp() + self.mock_gcs = MagicMock() + patcher = mock.patch( + "keras_remote.utils.storage.storage.Client", + return_value=self.mock_gcs, + ) + patcher.start() + self.addCleanup(patcher.stop) + + +class TestUploadArtifacts(_GcsTestBase): + def test_uploads_payload_and_context(self): + mock_bucket = self.mock_gcs.bucket.return_value + mock_blob = mock_bucket.blob.return_value + + upload_artifacts( + bucket_name="my-bucket", + job_id="job-abc123", + payload_path="/tmp/payload.pkl", + context_path="/tmp/context.zip", + project="test-project", + ) + + mock_bucket.blob.assert_any_call("job-abc123/payload.pkl") + mock_bucket.blob.assert_any_call("job-abc123/context.zip") + self.assertEqual(mock_blob.upload_from_filename.call_count, 2) + + def test_uses_correct_bucket(self): + upload_artifacts( + bucket_name="my-custom-bucket", + job_id="job-123", + payload_path="/tmp/p.pkl", + context_path="/tmp/c.zip", + project="proj", + ) + self.mock_gcs.bucket.assert_called_with("my-custom-bucket") + + +class TestDownloadResult(_GcsTestBase): + def test_downloads_result_blob(self): + mock_bucket = self.mock_gcs.bucket.return_value + mock_blob = mock_bucket.blob.return_value + + download_result("my-bucket", "job-abc", project="proj") + + mock_bucket.blob.assert_called_once_with("job-abc/result.pkl") + mock_blob.download_to_filename.assert_called_once() + + def test_returns_path_with_job_id(self): + result = download_result("my-bucket", "job-xyz", project="proj") + self.assertIn("result-job-xyz.pkl", result) + + +class TestCleanupArtifacts(_GcsTestBase): + def test_deletes_all_blobs(self): + mock_bucket = self.mock_gcs.bucket.return_value + blob1 = MagicMock() + blob2 = MagicMock() + blob3 = MagicMock() + mock_bucket.list_blobs.return_value = [blob1, blob2, blob3] + + cleanup_artifacts("my-bucket", "job-abc", project="proj") + + mock_bucket.list_blobs.assert_called_once_with(prefix="job-abc/") + blob1.delete.assert_called_once() + blob2.delete.assert_called_once() + blob3.delete.assert_called_once() + + def test_no_blobs_no_error(self): + mock_bucket = self.mock_gcs.bucket.return_value + mock_bucket.list_blobs.return_value = [] + + cleanup_artifacts("my-bucket", "job-abc", project="proj") + + mock_bucket.list_blobs.assert_called_once_with(prefix="job-abc/") + + +class TestGetProject(parameterized.TestCase): + @parameterized.named_parameters( + dict( + testcase_name="keras_remote_project_only", + kr_project="kr-proj", + gc_project=None, + expected="kr-proj", + ), + dict( + testcase_name="google_cloud_project_fallback", + kr_project=None, + gc_project="gc-proj", + expected="gc-proj", + ), + dict( + testcase_name="neither_set", + kr_project=None, + gc_project=None, + expected=None, + ), + dict( + testcase_name="keras_remote_takes_precedence", + kr_project="kr-proj", + gc_project="gc-proj", + expected="kr-proj", + ), + ) + def test_resolves_project(self, kr_project, gc_project, expected): + env = {} + if kr_project: + env["KERAS_REMOTE_PROJECT"] = kr_project + if gc_project: + env["GOOGLE_CLOUD_PROJECT"] = gc_project + with mock.patch.dict(os.environ, env, clear=True): + self.assertEqual(_get_project(), expected) + + +if __name__ == "__main__": + absltest.main() diff --git a/keras_remote/test_utils.py b/keras_remote/utils_test.py similarity index 100% rename from keras_remote/test_utils.py rename to keras_remote/utils_test.py diff --git a/tests/integration/test_packager_roundtrip.py b/tests/integration/packager_roundtrip_test.py similarity index 100% rename from tests/integration/test_packager_roundtrip.py rename to tests/integration/packager_roundtrip_test.py