Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 98 additions & 33 deletions keras_remote/infra/container_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hashlib
import os
import re
import shutil
import string
import tarfile
Expand All @@ -28,15 +29,67 @@
)
_RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner")

# JAX-related packages managed by the Dockerfile template.
# User requirements containing these are filtered out to prevent overriding
# the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]).
_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"})
# Captures the leading distribution name of a requirement; anything after it
# (extras, version specifiers, URLs) is left unmatched.
_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)")
# Inline marker a user appends to a requirement to opt out of the filtering.
_KEEP_MARKER = "# kr:keep"


def _filter_jax_requirements(requirements_content: str) -> str:
    """Remove JAX-related packages from requirements content.

    Strips requirements that would override the accelerator-specific JAX
    installation managed by the Dockerfile template. Logs a warning for each
    filtered requirement.

    A requirement may span several physical lines joined by a trailing
    backslash (for example a pinned version followed by ``--hash=...``
    continuation lines). Such a requirement is kept or removed as a whole so
    that no orphaned continuation lines are left behind.

    To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt.

    Args:
        requirements_content: Raw text of a requirements.txt file.

    Returns:
        Filtered requirements text with JAX-related requirements removed.
        All other lines are returned unchanged, including their original
        line endings.
    """
    # TODO: only requirements.txt-style content is handled here; support for
    # requirements declared in pyproject.toml is tracked in issue #81.

    # Pass 1: group physical lines into logical requirements. A line ending
    # in a backslash continues onto the next one, except for full-line
    # comments, which never continue.
    logical_groups = []
    pending = []
    for line in requirements_content.splitlines(keepends=True):
        pending.append(line)
        # Drop the line terminator (whatever kind splitlines() recognised).
        text = line.splitlines()[0]
        is_comment = text.lstrip().startswith("#")
        if not (text.endswith("\\") and not is_comment):
            logical_groups.append(pending)
            pending = []
    if pending:
        # Dangling continuation at end of input: treat it as a final group.
        logical_groups.append(pending)

    # Pass 2: decide per logical requirement, emitting or dropping all of its
    # physical lines together.
    filtered_lines = []
    for group in logical_groups:
        pieces = []
        for part in group:
            text = part.splitlines()[0]
            pieces.append(text[:-1] if text.endswith("\\") else text)
        stripped = "".join(pieces).strip()

        # Preserve blanks, comments, and pip flags (-e, --index-url, etc.)
        if not stripped or stripped.startswith(("#", "-")):
            filtered_lines.extend(group)
            continue

        # Allow users to bypass the filter with an inline marker, which may
        # sit on any physical line of the requirement.
        if any(_KEEP_MARKER in part for part in group):
            filtered_lines.extend(group)
            continue

        m = _PACKAGE_NAME_RE.match(stripped)
        if m:
            # PEP 503 normalization: lowercase, collapse [-_.] to '-'
            normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower()
            if normalized in _JAX_PACKAGE_NAMES:
                logging.warning(
                    "Filtered '%s' from requirements — JAX is installed "
                    "automatically with the correct accelerator backend. "
                    "To override, add '# kr:keep' to the line.",
                    m.group(1),
                )
                continue

        filtered_lines.extend(group)

    return "".join(filtered_lines)


def get_or_build_container(
base_image,
requirements_path,
accelerator_type,
project,
zone=None,
cluster_name=None,
):
base_image: str,
requirements_path: str | None,
accelerator_type: str,
project: str,
zone: str | None = None,
cluster_name: str | None = None,
) -> str:
"""Get existing container or build if requirements changed.

Uses content-based hashing to detect requirement changes.
Expand All @@ -56,9 +109,15 @@ def get_or_build_container(
cluster_name = cluster_name or get_default_cluster_name()
category = accelerators.get_category(accelerator_type)

# Read and filter requirements once, reuse for hashing and building.
filtered_requirements = None
if requirements_path and os.path.exists(requirements_path):
with open(requirements_path, "r") as f:
filtered_requirements = _filter_jax_requirements(f.read())

# Generate deterministic hash from requirements + base image + category
requirements_hash = _hash_requirements(
requirements_path, category, base_image
filtered_requirements, category, base_image
)

# Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
Expand All @@ -84,7 +143,7 @@ def get_or_build_container(
logging.info("Building new container (requirements changed): %s", image_uri)
return _build_and_push(
base_image,
requirements_path,
filtered_requirements,
category,
project,
image_uri,
Expand All @@ -93,11 +152,13 @@ def get_or_build_container(
)


def _hash_requirements(requirements_path, category, base_image):
def _hash_requirements(
filtered_requirements: str | None, category: str, base_image: str
) -> str:
"""Create deterministic hash from requirements + category + remote_runner + base image.

Args:
requirements_path: Path to requirements.txt (or None)
filtered_requirements: Pre-filtered requirements content (or None)
category: Accelerator category ('cpu', 'gpu', 'tpu')
base_image: Base Docker image (e.g., 'python:3.12-slim')

Expand All @@ -106,9 +167,8 @@ def _hash_requirements(requirements_path, category, base_image):
"""
content = f"base_image={base_image}\ncategory={category}\n"

if requirements_path and os.path.exists(requirements_path):
with open(requirements_path, "r") as f:
content += f.read()
if filtered_requirements:
content += filtered_requirements

# Include remote_runner.py in the hash so container rebuilds when it changes
remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
Expand All @@ -125,7 +185,7 @@ def _hash_requirements(requirements_path, category, base_image):
return hashlib.sha256(content.encode()).hexdigest()


def _image_exists(image_uri, project):
def _image_exists(image_uri: str, project: str) -> bool:
"""Check if image exists in Artifact Registry.

Args:
Expand Down Expand Up @@ -162,19 +222,19 @@ def _image_exists(image_uri, project):


def _build_and_push(
base_image,
requirements_path,
category,
project,
image_uri,
ar_location="us",
cluster_name=None,
):
base_image: str,
filtered_requirements: str | None,
category: str,
project: str,
image_uri: str,
ar_location: str = "us",
cluster_name: str | None = None,
) -> str:
"""Build and push Docker image using Cloud Build.

Args:
base_image: Base Docker image
requirements_path: Path to requirements.txt (or None)
filtered_requirements: Pre-filtered requirements content (or None)
category: Accelerator category ('cpu', 'gpu', 'tpu')
project: GCP project ID
image_uri: Target image URI
Expand All @@ -187,17 +247,18 @@ def _build_and_push(
# Generate Dockerfile
dockerfile_content = _generate_dockerfile(
base_image=base_image,
requirements_path=requirements_path,
has_requirements=filtered_requirements is not None,
category=category,
)

dockerfile_path = os.path.join(tmpdir, "Dockerfile")
with open(dockerfile_path, "w") as f:
f.write(dockerfile_content)

# Copy requirements.txt if it exists
if requirements_path and os.path.exists(requirements_path):
shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt"))
# Write pre-filtered requirements.txt
if filtered_requirements is not None:
with open(os.path.join(tmpdir, "requirements.txt"), "w") as f:
f.write(filtered_requirements)

# Copy remote_runner.py
remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
Expand All @@ -209,7 +270,7 @@ def _build_and_push(
with tarfile.open(tarball_path, "w:gz") as tar:
tar.add(dockerfile_path, arcname="Dockerfile")
tar.add(remote_runner_dst, arcname=REMOTE_RUNNER_FILE_NAME)
if requirements_path and os.path.exists(requirements_path):
if filtered_requirements is not None:
tar.add(
os.path.join(tmpdir, "requirements.txt"), arcname="requirements.txt"
)
Expand Down Expand Up @@ -271,12 +332,14 @@ def _build_and_push(
raise RuntimeError(f"Build failed with status: {result.status}")


def _generate_dockerfile(base_image, requirements_path, category):
def _generate_dockerfile(
base_image: str, has_requirements: bool, category: str
) -> str:
"""Generate Dockerfile content based on configuration.

Args:
base_image: Base Docker image
requirements_path: Path to requirements.txt (or None)
has_requirements: Whether filtered requirements content is available
category: Accelerator category ('cpu', 'gpu', 'tpu')

Returns:
Expand All @@ -294,7 +357,7 @@ def _generate_dockerfile(base_image, requirements_path, category):
jax_install = "RUN python3 -m pip install 'jax[cuda12]'"

requirements_section = ""
if requirements_path and os.path.exists(requirements_path):
if has_requirements:
requirements_section = (
"COPY requirements.txt /tmp/requirements.txt\n"
"RUN python3 -m pip install -r /tmp/requirements.txt\n"
Expand All @@ -311,7 +374,9 @@ def _generate_dockerfile(base_image, requirements_path, category):
)


def _upload_build_source(tarball_path, bucket_name, project):
def _upload_build_source(
tarball_path: str, bucket_name: str, project: str
) -> str:
"""Upload build source tarball to Cloud Storage.

Args:
Expand Down
Loading
Loading