diff --git a/keras_remote/infra/container_builder.py b/keras_remote/infra/container_builder.py index 58464d6..ff9f842 100644 --- a/keras_remote/infra/container_builder.py +++ b/keras_remote/infra/container_builder.py @@ -2,6 +2,7 @@ import hashlib import os +import re import shutil import string import tarfile @@ -28,15 +29,67 @@ ) _RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner") +# JAX-related packages managed by the Dockerfile template. +# User requirements containing these are filtered out to prevent overriding +# the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]). +_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"}) +_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)") +_KEEP_MARKER = "# kr:keep" + + +def _filter_jax_requirements(requirements_content: str) -> str: + """Remove JAX-related packages from requirements content. + + Strips lines that would override the accelerator-specific JAX installation + managed by the Dockerfile template. Logs a warning for each filtered line. + + To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt. + + Args: + requirements_content: Raw text of a requirements.txt file. + + Returns: + Filtered requirements text with JAX-related lines removed. + """ + filtered_lines = [] + for line in requirements_content.splitlines(keepends=True): + stripped = line.strip() + # Preserve blanks, comments, and pip flags (-e, --index-url, etc.) + if not stripped or stripped.startswith("#") or stripped.startswith("-"): + filtered_lines.append(line) + continue + + # Allow users to bypass the filter with an inline marker. + if _KEEP_MARKER in line: + filtered_lines.append(line) + continue + + m = _PACKAGE_NAME_RE.match(stripped) + if m: + # PEP 503 normalization: lowercase, collapse [-_.] 
to '-' + normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower() + if normalized in _JAX_PACKAGE_NAMES: + logging.warning( + "Filtered '%s' from requirements — JAX is installed " + "automatically with the correct accelerator backend. " + "To override, add '# kr:keep' to the line.", + m.group(1), + ) + continue + + filtered_lines.append(line) + + return "".join(filtered_lines) + def get_or_build_container( - base_image, - requirements_path, - accelerator_type, - project, - zone=None, - cluster_name=None, -): + base_image: str, + requirements_path: str | None, + accelerator_type: str, + project: str, + zone: str | None = None, + cluster_name: str | None = None, +) -> str: """Get existing container or build if requirements changed. Uses content-based hashing to detect requirement changes. @@ -56,9 +109,15 @@ def get_or_build_container( cluster_name = cluster_name or get_default_cluster_name() category = accelerators.get_category(accelerator_type) + # Read and filter requirements once, reuse for hashing and building. 
+ filtered_requirements = None + if requirements_path and os.path.exists(requirements_path): + with open(requirements_path, "r") as f: + filtered_requirements = _filter_jax_requirements(f.read()) + # Generate deterministic hash from requirements + base image + category requirements_hash = _hash_requirements( - requirements_path, category, base_image + filtered_requirements, category, base_image ) # Use category for image name (e.g., 'tpu-hash', 'gpu-hash') @@ -84,7 +143,7 @@ def get_or_build_container( logging.info("Building new container (requirements changed): %s", image_uri) return _build_and_push( base_image, - requirements_path, + filtered_requirements, category, project, image_uri, @@ -93,11 +152,13 @@ def get_or_build_container( ) -def _hash_requirements(requirements_path, category, base_image): +def _hash_requirements( + filtered_requirements: str | None, category: str, base_image: str +) -> str: """Create deterministic hash from requirements + category + remote_runner + base image. 
Args: - requirements_path: Path to requirements.txt (or None) + filtered_requirements: Pre-filtered requirements content (or None) category: Accelerator category ('cpu', 'gpu', 'tpu') base_image: Base Docker image (e.g., 'python:3.12-slim') @@ -106,9 +167,8 @@ def _hash_requirements(requirements_path, category, base_image): """ content = f"base_image={base_image}\ncategory={category}\n" - if requirements_path and os.path.exists(requirements_path): - with open(requirements_path, "r") as f: - content += f.read() + if filtered_requirements: + content += filtered_requirements # Include remote_runner.py in the hash so container rebuilds when it changes remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME) @@ -125,7 +185,7 @@ def _hash_requirements(requirements_path, category, base_image): return hashlib.sha256(content.encode()).hexdigest() -def _image_exists(image_uri, project): +def _image_exists(image_uri: str, project: str) -> bool: """Check if image exists in Artifact Registry. Args: @@ -162,19 +222,19 @@ def _image_exists(image_uri, project): def _build_and_push( - base_image, - requirements_path, - category, - project, - image_uri, - ar_location="us", - cluster_name=None, -): + base_image: str, + filtered_requirements: str | None, + category: str, + project: str, + image_uri: str, + ar_location: str = "us", + cluster_name: str | None = None, +) -> str: """Build and push Docker image using Cloud Build. 
Args: base_image: Base Docker image - requirements_path: Path to requirements.txt (or None) + filtered_requirements: Pre-filtered requirements content (or None) category: Accelerator category ('cpu', 'gpu', 'tpu') project: GCP project ID image_uri: Target image URI @@ -187,7 +247,7 @@ # Generate Dockerfile dockerfile_content = _generate_dockerfile( base_image=base_image, - requirements_path=requirements_path, + has_requirements=bool(filtered_requirements), category=category, ) @@ -195,9 +255,10 @@ with open(dockerfile_path, "w") as f: f.write(dockerfile_content) - # Copy requirements.txt if it exists - if requirements_path and os.path.exists(requirements_path): - shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt")) + # Write pre-filtered requirements.txt (skip if empty after filtering, + if filtered_requirements: + with open(os.path.join(tmpdir, "requirements.txt"), "w") as f: + f.write(filtered_requirements) # Copy remote_runner.py remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME) @@ -209,7 +270,7 @@ with tarfile.open(tarball_path, "w:gz") as tar: tar.add(dockerfile_path, arcname="Dockerfile") tar.add(remote_runner_dst, arcname=REMOTE_RUNNER_FILE_NAME) - if requirements_path and os.path.exists(requirements_path): + if filtered_requirements: tar.add( os.path.join(tmpdir, "requirements.txt"), arcname="requirements.txt" ) @@ -271,12 +332,14 @@ raise RuntimeError(f"Build failed with status: {result.status}") -def _generate_dockerfile(base_image, requirements_path, category): +def _generate_dockerfile( + base_image: str, has_requirements: bool, category: str +) -> str: """Generate Dockerfile content based on configuration. 
Args: base_image: Base Docker image - requirements_path: Path to requirements.txt (or None) + has_requirements: Whether filtered requirements content is available category: Accelerator category ('cpu', 'gpu', 'tpu') Returns: @@ -294,7 +357,7 @@ def _generate_dockerfile(base_image, requirements_path, category): jax_install = "RUN python3 -m pip install 'jax[cuda12]'" requirements_section = "" - if requirements_path and os.path.exists(requirements_path): + if has_requirements: requirements_section = ( "COPY requirements.txt /tmp/requirements.txt\n" "RUN python3 -m pip install -r /tmp/requirements.txt\n" @@ -311,7 +374,9 @@ def _generate_dockerfile(base_image, requirements_path, category): ) -def _upload_build_source(tarball_path, bucket_name, project): +def _upload_build_source( + tarball_path: str, bucket_name: str, project: str +) -> str: """Upload build source tarball to Cloud Storage. Args: diff --git a/keras_remote/infra/container_builder_test.py b/keras_remote/infra/container_builder_test.py index 1d16e27..e64e77f 100644 --- a/keras_remote/infra/container_builder_test.py +++ b/keras_remote/infra/container_builder_test.py @@ -1,7 +1,5 @@ """Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching.""" -import pathlib -import tempfile from unittest import mock from unittest.mock import MagicMock @@ -9,6 +7,7 @@ from google.api_core import exceptions as google_exceptions from keras_remote.infra.container_builder import ( + _filter_jax_requirements, _generate_dockerfile, _hash_requirements, _image_exists, @@ -16,71 +15,104 @@ ) -def _make_temp_path(test_case): - """Create a temp directory that is cleaned up after the test.""" - td = tempfile.TemporaryDirectory() - test_case.addCleanup(td.cleanup) - return pathlib.Path(td.name) +class TestFilterJaxRequirements(parameterized.TestCase): + @parameterized.named_parameters( + dict(testcase_name="bare_jax", line="jax\n"), + dict(testcase_name="jax_with_tpu_extras", line="jax[tpu]>=0.4.6\n"), + 
dict(testcase_name="jax_cuda", line="jax[cuda12]==0.4.30\n"), + dict(testcase_name="jax_cpu", line="jax[cpu]\n"), + dict(testcase_name="jaxlib", line="jaxlib>=0.4.6\n"), + dict(testcase_name="libtpu", line="libtpu\n"), + dict(testcase_name="libtpu_nightly_hyphen", line="libtpu-nightly\n"), + dict(testcase_name="libtpu_nightly_underscore", line="libtpu_nightly\n"), + dict(testcase_name="jax_uppercase", line="JAX\n"), + dict(testcase_name="jax_mixed_case", line="Jax[tpu]\n"), + ) + def test_filters_jax_packages(self, line): + self.assertEqual(_filter_jax_requirements(line), "") + + @parameterized.named_parameters( + dict(testcase_name="numpy", line="numpy==1.26\n"), + dict(testcase_name="keras", line="keras\n"), + dict(testcase_name="scipy", line="scipy>=1.12\n"), + dict(testcase_name="comment", line="# jax should be here\n"), + dict(testcase_name="blank", line="\n"), + dict(testcase_name="pip_flag", line="-e git+https://foo\n"), + dict(testcase_name="index_url", line="--index-url https://pypi.org\n"), + ) + def test_preserves_non_jax_packages(self, line): + self.assertEqual(_filter_jax_requirements(line), line) + + @parameterized.named_parameters( + dict(testcase_name="jax_keep", line="jax==0.4.30 # kr:keep\n"), + dict(testcase_name="jaxlib_keep", line="jaxlib # kr:keep\n"), + dict(testcase_name="libtpu_keep", line="libtpu-nightly # kr:keep\n"), + ) + def test_kr_keep_overrides_filter(self, line): + self.assertEqual(_filter_jax_requirements(line), line) + + def test_mixed_requirements(self): + content = ( + "numpy==1.26\njax[tpu]>=0.4.6\nscipy\n" + "jaxlib\nkeras\njax==0.4.30 # kr:keep\n" + ) + result = _filter_jax_requirements(content) + self.assertEqual( + result, "numpy==1.26\nscipy\nkeras\njax==0.4.30 # kr:keep\n" + ) + + def test_empty_string(self): + self.assertEqual(_filter_jax_requirements(""), "") + + def test_only_jax_packages(self): + self.assertEqual(_filter_jax_requirements("jax\njaxlib\nlibtpu\n"), "") + + def test_preserves_comments_and_blanks(self): 
+ content = "# ML deps\nnumpy\n\njax\n# end\n" + result = _filter_jax_requirements(content) + self.assertEqual(result, "# ML deps\nnumpy\n\n# end\n") class TestHashRequirements(parameterized.TestCase): def test_deterministic(self): - tmp_path = _make_temp_path(self) - req = tmp_path / "requirements.txt" - req.write_text("numpy==1.26\n") - - h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim") - h2 = _hash_requirements(str(req), "gpu", "python:3.12-slim") + h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim") + h2 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim") self.assertEqual(h1, h2) def test_different_requirements_different_hash(self): - tmp_path = _make_temp_path(self) - req1 = tmp_path / "r1.txt" - req1.write_text("numpy==1.26\n") - req2 = tmp_path / "r2.txt" - req2.write_text("scipy==1.12\n") - - h1 = _hash_requirements(str(req1), "gpu", "python:3.12-slim") - h2 = _hash_requirements(str(req2), "gpu", "python:3.12-slim") + h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim") + h2 = _hash_requirements("scipy==1.12\n", "gpu", "python:3.12-slim") self.assertNotEqual(h1, h2) def test_different_category_different_hash(self): - tmp_path = _make_temp_path(self) - req = tmp_path / "requirements.txt" - req.write_text("numpy\n") - - h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim") - h2 = _hash_requirements(str(req), "tpu", "python:3.12-slim") + h1 = _hash_requirements("numpy\n", "gpu", "python:3.12-slim") + h2 = _hash_requirements("numpy\n", "tpu", "python:3.12-slim") self.assertNotEqual(h1, h2) def test_different_base_image_different_hash(self): - tmp_path = _make_temp_path(self) - req = tmp_path / "requirements.txt" - req.write_text("numpy\n") - - h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim") - h2 = _hash_requirements(str(req), "gpu", "python:3.11-slim") + h1 = _hash_requirements("numpy\n", "gpu", "python:3.12-slim") + h2 = _hash_requirements("numpy\n", "gpu", "python:3.11-slim") 
self.assertNotEqual(h1, h2) - @parameterized.named_parameters( - dict(testcase_name="none", requirements_path=None), - dict( - testcase_name="nonexistent", - requirements_path="/nonexistent/path.txt", - ), - ) - def test_missing_requirements_valid(self, requirements_path): - h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim") + def test_missing_requirements_valid(self): + h = _hash_requirements(None, "cpu", "python:3.12-slim") self.assertIsInstance(h, str) self.assertLen(h, 64) def test_returns_hex_string(self): - tmp_path = _make_temp_path(self) - req = tmp_path / "r.txt" - req.write_text("keras\n") - h = _hash_requirements(str(req), "gpu", "python:3.12-slim") + h = _hash_requirements("keras\n", "gpu", "python:3.12-slim") self.assertRegex(h, r"^[0-9a-f]{64}$") + def test_jax_in_requirements_does_not_affect_hash(self): + filtered_without_jax = _filter_jax_requirements("numpy==1.26\n") + filtered_with_jax = _filter_jax_requirements( + "numpy==1.26\njax[tpu]>=0.4.6\n" + ) + + h1 = _hash_requirements(filtered_without_jax, "tpu", "python:3.12-slim") + h2 = _hash_requirements(filtered_with_jax, "tpu", "python:3.12-slim") + self.assertEqual(h1, h2) + class TestGenerateDockerfile(parameterized.TestCase): @parameterized.named_parameters( @@ -106,7 +138,7 @@ class TestGenerateDockerfile(parameterized.TestCase): def test_jax_install(self, category, expected, not_expected): content = _generate_dockerfile( base_image="python:3.12-slim", - requirements_path=None, + has_requirements=False, category=category, ) for s in expected: @@ -115,13 +147,9 @@ def test_jax_install(self, category, expected, not_expected): self.assertNotIn(s, content) def test_with_requirements(self): - tmp_path = _make_temp_path(self) - req = tmp_path / "requirements.txt" - req.write_text("numpy\n") - content = _generate_dockerfile( base_image="python:3.12-slim", - requirements_path=str(req), + has_requirements=True, category="cpu", ) self.assertIn("COPY requirements.txt", content) @@ 
-130,7 +158,7 @@ def test_with_requirements(self): def test_without_requirements(self): content = _generate_dockerfile( base_image="python:3.12-slim", - requirements_path=None, + has_requirements=False, category="cpu", ) self.assertNotIn("COPY requirements.txt", content) @@ -148,7 +176,7 @@ def test_without_requirements(self): def test_contains_expected_content(self, expected_substring): content = _generate_dockerfile( base_image="python:3.12-slim", - requirements_path=None, + has_requirements=False, category="cpu", ) self.assertIn(expected_substring, content) @@ -156,7 +184,7 @@ def test_contains_expected_content(self, expected_substring): def test_uses_base_image(self): content = _generate_dockerfile( base_image="python:3.11-bullseye", - requirements_path=None, + has_requirements=False, category="cpu", ) self.assertIn("FROM python:3.11-bullseye", content)