Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 84 additions & 22 deletions keras_remote/infra/container_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hashlib
import os
import re
import shutil
import string
import tarfile
Expand All @@ -28,15 +29,67 @@
)
_RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner")

# JAX-related packages managed by the Dockerfile template.
# User requirements containing these are filtered out to prevent overriding
# the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]).
_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"})
_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)")
_KEEP_MARKER = "# kr:keep"


def _filter_jax_requirements(requirements_content: str) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we also support requirements specified in pyproject.toml?

Copy link
Collaborator Author

@JyotinderSingh JyotinderSingh Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should add support for that in subsequent PRs. Created #81 for tracking.

"""Remove JAX-related packages from requirements content.

Strips lines that would override the accelerator-specific JAX installation
managed by the Dockerfile template. Logs a warning for each filtered line.

To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt.

Args:
requirements_content: Raw text of a requirements.txt file.

Returns:
Filtered requirements text with JAX-related lines removed.
"""
filtered_lines = []
for line in requirements_content.splitlines(keepends=True):
stripped = line.strip()
# Preserve blanks, comments, and pip flags (-e, --index-url, etc.)
if not stripped or stripped.startswith("#") or stripped.startswith("-"):
filtered_lines.append(line)
continue

# Allow users to bypass the filter with an inline marker.
if _KEEP_MARKER in line:
filtered_lines.append(line)
continue

m = _PACKAGE_NAME_RE.match(stripped)
if m:
# PEP 503 normalization: lowercase, collapse [-_.] to '-'
normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower()
if normalized in _JAX_PACKAGE_NAMES:
logging.warning(
"Filtered '%s' from requirements — JAX is installed "
"automatically with the correct accelerator backend. "
"To override, add '# kr:keep' to the line.",
stripped,
)
continue

filtered_lines.append(line)

return "".join(filtered_lines)


def get_or_build_container(
base_image,
requirements_path,
accelerator_type,
project,
zone=None,
cluster_name=None,
):
base_image: str,
requirements_path: str | None,
accelerator_type: str,
project: str,
zone: str | None = None,
cluster_name: str | None = None,
) -> str:
"""Get existing container or build if requirements changed.

Uses content-based hashing to detect requirement changes.
Expand Down Expand Up @@ -93,7 +146,9 @@ def get_or_build_container(
)


def _hash_requirements(requirements_path, category, base_image):
def _hash_requirements(
requirements_path: str | None, category: str, base_image: str
) -> str:
"""Create deterministic hash from requirements + category + remote_runner + base image.

Args:
Expand All @@ -108,7 +163,7 @@ def _hash_requirements(requirements_path, category, base_image):

if requirements_path and os.path.exists(requirements_path):
with open(requirements_path, "r") as f:
content += f.read()
content += _filter_jax_requirements(f.read())

# Include remote_runner.py in the hash so container rebuilds when it changes
remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
Expand All @@ -125,7 +180,7 @@ def _hash_requirements(requirements_path, category, base_image):
return hashlib.sha256(content.encode()).hexdigest()


def _image_exists(image_uri, project):
def _image_exists(image_uri: str, project: str) -> bool:
"""Check if image exists in Artifact Registry.

Args:
Expand Down Expand Up @@ -162,14 +217,14 @@ def _image_exists(image_uri, project):


def _build_and_push(
base_image,
requirements_path,
category,
project,
image_uri,
ar_location="us",
cluster_name=None,
):
base_image: str,
requirements_path: str | None,
category: str,
project: str,
image_uri: str,
ar_location: str = "us",
cluster_name: str | None = None,
) -> str:
"""Build and push Docker image using Cloud Build.

Args:
Expand All @@ -195,9 +250,12 @@ def _build_and_push(
with open(dockerfile_path, "w") as f:
f.write(dockerfile_content)

# Copy requirements.txt if it exists
# Copy requirements.txt (with JAX-related packages filtered out)
if requirements_path and os.path.exists(requirements_path):
shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt"))
with open(requirements_path, "r") as f:
filtered = _filter_jax_requirements(f.read())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The requirements.txt file is being read and filtered here, but it's also read and filtered earlier in _hash_requirements (lines 162-164). This duplicates work and file I/O.

Consider refactoring to read and filter the requirements file only once within the get_or_build_container function. The resulting filtered content could then be passed as an argument to both _hash_requirements and _build_and_push to avoid this duplication.

with open(os.path.join(tmpdir, "requirements.txt"), "w") as f:
f.write(filtered)

# Copy remote_runner.py
remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
Expand Down Expand Up @@ -271,7 +329,9 @@ def _build_and_push(
raise RuntimeError(f"Build failed with status: {result.status}")


def _generate_dockerfile(base_image, requirements_path, category):
def _generate_dockerfile(
base_image: str, requirements_path: str | None, category: str
) -> str:
"""Generate Dockerfile content based on configuration.

Args:
Expand Down Expand Up @@ -311,7 +371,9 @@ def _generate_dockerfile(base_image, requirements_path, category):
)


def _upload_build_source(tarball_path, bucket_name, project):
def _upload_build_source(
tarball_path: str, bucket_name: str, project: str
) -> str:
"""Upload build source tarball to Cloud Storage.

Args:
Expand Down
70 changes: 70 additions & 0 deletions keras_remote/infra/container_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from google.api_core import exceptions as google_exceptions

from keras_remote.infra.container_builder import (
_filter_jax_requirements,
_generate_dockerfile,
_hash_requirements,
_image_exists,
Expand All @@ -23,6 +24,64 @@ def _make_temp_path(test_case):
return pathlib.Path(td.name)


class TestFilterJaxRequirements(parameterized.TestCase):
    """Tests for `_filter_jax_requirements`."""

    @parameterized.named_parameters(
        dict(testcase_name="bare_jax", line="jax\n"),
        dict(testcase_name="jax_with_tpu_extras", line="jax[tpu]>=0.4.6\n"),
        dict(testcase_name="jax_cuda", line="jax[cuda12]==0.4.30\n"),
        dict(testcase_name="jax_cpu", line="jax[cpu]\n"),
        dict(testcase_name="jaxlib", line="jaxlib>=0.4.6\n"),
        dict(testcase_name="libtpu", line="libtpu\n"),
        dict(testcase_name="libtpu_nightly_hyphen", line="libtpu-nightly\n"),
        dict(testcase_name="libtpu_nightly_underscore", line="libtpu_nightly\n"),
        dict(testcase_name="jax_uppercase", line="JAX\n"),
        dict(testcase_name="jax_mixed_case", line="Jax[tpu]\n"),
    )
    def test_filters_jax_packages(self, line):
        # Every managed JAX spelling must be stripped out entirely.
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, "")

    @parameterized.named_parameters(
        dict(testcase_name="numpy", line="numpy==1.26\n"),
        dict(testcase_name="keras", line="keras\n"),
        dict(testcase_name="scipy", line="scipy>=1.12\n"),
        dict(testcase_name="comment", line="# jax should be here\n"),
        dict(testcase_name="blank", line="\n"),
        dict(testcase_name="pip_flag", line="-e git+https://foo\n"),
        dict(testcase_name="index_url", line="--index-url https://pypi.org\n"),
    )
    def test_preserves_non_jax_packages(self, line):
        # Unrelated packages, comments, blanks, and pip flags pass through.
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, line)

    @parameterized.named_parameters(
        dict(testcase_name="jax_keep", line="jax==0.4.30 # kr:keep\n"),
        dict(testcase_name="jaxlib_keep", line="jaxlib # kr:keep\n"),
        dict(testcase_name="libtpu_keep", line="libtpu-nightly # kr:keep\n"),
    )
    def test_kr_keep_overrides_filter(self, line):
        # The inline marker is the user's escape hatch from the filter.
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, line)

    def test_mixed_requirements(self):
        requirements_text = "\n".join(
            [
                "numpy==1.26",
                "jax[tpu]>=0.4.6",
                "scipy",
                "jaxlib",
                "keras",
                "jax==0.4.30 # kr:keep",
            ]
        ) + "\n"
        filtered = _filter_jax_requirements(requirements_text)
        self.assertEqual(
            filtered, "numpy==1.26\nscipy\nkeras\njax==0.4.30 # kr:keep\n"
        )

    def test_empty_string(self):
        self.assertEqual(_filter_jax_requirements(""), "")

    def test_only_jax_packages(self):
        self.assertEqual(_filter_jax_requirements("jax\njaxlib\nlibtpu\n"), "")

    def test_preserves_comments_and_blanks(self):
        requirements_text = "# ML deps\nnumpy\n\njax\n# end\n"
        filtered = _filter_jax_requirements(requirements_text)
        self.assertEqual(filtered, "# ML deps\nnumpy\n\n# end\n")


class TestHashRequirements(parameterized.TestCase):
def test_deterministic(self):
tmp_path = _make_temp_path(self)
Expand Down Expand Up @@ -81,6 +140,17 @@ def test_returns_hex_string(self):
h = _hash_requirements(str(req), "gpu", "python:3.12-slim")
self.assertRegex(h, r"^[0-9a-f]{64}$")

def test_jax_in_requirements_does_not_affect_hash(self):
    """A filtered-out JAX pin must not change the container hash."""
    tmp_dir = _make_temp_path(self)
    plain_req = tmp_dir / "r1.txt"
    plain_req.write_text("numpy==1.26\n")
    jax_req = tmp_dir / "r2.txt"
    jax_req.write_text("numpy==1.26\njax[tpu]>=0.4.6\n")

    # Both files hash identically because the jax line is filtered first.
    self.assertEqual(
        _hash_requirements(str(plain_req), "tpu", "python:3.12-slim"),
        _hash_requirements(str(jax_req), "tpu", "python:3.12-slim"),
    )


class TestGenerateDockerfile(parameterized.TestCase):
@parameterized.named_parameters(
Expand Down
Loading