Skip to content

Commit 5d0f1b3

Browse files
Filter JAX packages from user requirements to prevent accelerator backend override (#80)
1 parent d3fd037 commit 5d0f1b3

File tree

2 files changed

+181
-88
lines changed

2 files changed

+181
-88
lines changed

keras_remote/infra/container_builder.py

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import hashlib
44
import os
5+
import re
56
import shutil
67
import string
78
import tarfile
@@ -28,15 +29,67 @@
2829
)
2930
_RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner")
3031

32+
# JAX-related packages managed by the Dockerfile template.
33+
# User requirements containing these are filtered out to prevent overriding
34+
# the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]).
35+
_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"})
36+
_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)")
37+
_KEEP_MARKER = "# kr:keep"
38+
39+
40+
def _filter_jax_requirements(requirements_content: str) -> str:
41+
"""Remove JAX-related packages from requirements content.
42+
43+
Strips lines that would override the accelerator-specific JAX installation
44+
managed by the Dockerfile template. Logs a warning for each filtered line.
45+
46+
To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt.
47+
48+
Args:
49+
requirements_content: Raw text of a requirements.txt file.
50+
51+
Returns:
52+
Filtered requirements text with JAX-related lines removed.
53+
"""
54+
filtered_lines = []
55+
for line in requirements_content.splitlines(keepends=True):
56+
stripped = line.strip()
57+
# Preserve blanks, comments, and pip flags (-e, --index-url, etc.)
58+
if not stripped or stripped.startswith("#") or stripped.startswith("-"):
59+
filtered_lines.append(line)
60+
continue
61+
62+
# Allow users to bypass the filter with an inline marker.
63+
if _KEEP_MARKER in line:
64+
filtered_lines.append(line)
65+
continue
66+
67+
m = _PACKAGE_NAME_RE.match(stripped)
68+
if m:
69+
# PEP 503 normalization: lowercase, collapse [-_.] to '-'
70+
normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower()
71+
if normalized in _JAX_PACKAGE_NAMES:
72+
logging.warning(
73+
"Filtered '%s' from requirements — JAX is installed "
74+
"automatically with the correct accelerator backend. "
75+
"To override, add '# kr:keep' to the line.",
76+
m.group(1),
77+
)
78+
continue
79+
80+
filtered_lines.append(line)
81+
82+
return "".join(filtered_lines)
83+
3184

3285
def get_or_build_container(
33-
base_image,
34-
requirements_path,
35-
accelerator_type,
36-
project,
37-
zone=None,
38-
cluster_name=None,
39-
):
86+
base_image: str,
87+
requirements_path: str | None,
88+
accelerator_type: str,
89+
project: str,
90+
zone: str | None = None,
91+
cluster_name: str | None = None,
92+
) -> str:
4093
"""Get existing container or build if requirements changed.
4194
4295
Uses content-based hashing to detect requirement changes.
@@ -56,9 +109,15 @@ def get_or_build_container(
56109
cluster_name = cluster_name or get_default_cluster_name()
57110
category = accelerators.get_category(accelerator_type)
58111

112+
# Read and filter requirements once, reuse for hashing and building.
113+
filtered_requirements = None
114+
if requirements_path and os.path.exists(requirements_path):
115+
with open(requirements_path, "r") as f:
116+
filtered_requirements = _filter_jax_requirements(f.read())
117+
59118
# Generate deterministic hash from requirements + base image + category
60119
requirements_hash = _hash_requirements(
61-
requirements_path, category, base_image
120+
filtered_requirements, category, base_image
62121
)
63122

64123
# Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
@@ -84,7 +143,7 @@ def get_or_build_container(
84143
logging.info("Building new container (requirements changed): %s", image_uri)
85144
return _build_and_push(
86145
base_image,
87-
requirements_path,
146+
filtered_requirements,
88147
category,
89148
project,
90149
image_uri,
@@ -93,11 +152,13 @@ def get_or_build_container(
93152
)
94153

95154

96-
def _hash_requirements(requirements_path, category, base_image):
155+
def _hash_requirements(
156+
filtered_requirements: str | None, category: str, base_image: str
157+
) -> str:
97158
"""Create deterministic hash from requirements + category + remote_runner + base image.
98159
99160
Args:
100-
requirements_path: Path to requirements.txt (or None)
161+
filtered_requirements: Pre-filtered requirements content (or None)
101162
category: Accelerator category ('cpu', 'gpu', 'tpu')
102163
base_image: Base Docker image (e.g., 'python:3.12-slim')
103164
@@ -106,9 +167,8 @@ def _hash_requirements(requirements_path, category, base_image):
106167
"""
107168
content = f"base_image={base_image}\ncategory={category}\n"
108169

109-
if requirements_path and os.path.exists(requirements_path):
110-
with open(requirements_path, "r") as f:
111-
content += f.read()
170+
if filtered_requirements:
171+
content += filtered_requirements
112172

113173
# Include remote_runner.py in the hash so container rebuilds when it changes
114174
remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
@@ -125,7 +185,7 @@ def _hash_requirements(requirements_path, category, base_image):
125185
return hashlib.sha256(content.encode()).hexdigest()
126186

127187

128-
def _image_exists(image_uri, project):
188+
def _image_exists(image_uri: str, project: str) -> bool:
129189
"""Check if image exists in Artifact Registry.
130190
131191
Args:
@@ -162,19 +222,19 @@ def _image_exists(image_uri, project):
162222

163223

164224
def _build_and_push(
165-
base_image,
166-
requirements_path,
167-
category,
168-
project,
169-
image_uri,
170-
ar_location="us",
171-
cluster_name=None,
172-
):
225+
base_image: str,
226+
filtered_requirements: str | None,
227+
category: str,
228+
project: str,
229+
image_uri: str,
230+
ar_location: str = "us",
231+
cluster_name: str | None = None,
232+
) -> str:
173233
"""Build and push Docker image using Cloud Build.
174234
175235
Args:
176236
base_image: Base Docker image
177-
requirements_path: Path to requirements.txt (or None)
237+
filtered_requirements: Pre-filtered requirements content (or None)
178238
category: Accelerator category ('cpu', 'gpu', 'tpu')
179239
project: GCP project ID
180240
image_uri: Target image URI
@@ -187,17 +247,18 @@ def _build_and_push(
187247
# Generate Dockerfile
188248
dockerfile_content = _generate_dockerfile(
189249
base_image=base_image,
190-
requirements_path=requirements_path,
250+
has_requirements=filtered_requirements is not None,
191251
category=category,
192252
)
193253

194254
dockerfile_path = os.path.join(tmpdir, "Dockerfile")
195255
with open(dockerfile_path, "w") as f:
196256
f.write(dockerfile_content)
197257

198-
# Copy requirements.txt if it exists
199-
if requirements_path and os.path.exists(requirements_path):
200-
shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt"))
258+
# Write pre-filtered requirements.txt
259+
if filtered_requirements is not None:
260+
with open(os.path.join(tmpdir, "requirements.txt"), "w") as f:
261+
f.write(filtered_requirements)
201262

202263
# Copy remote_runner.py
203264
remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
@@ -209,7 +270,7 @@ def _build_and_push(
209270
with tarfile.open(tarball_path, "w:gz") as tar:
210271
tar.add(dockerfile_path, arcname="Dockerfile")
211272
tar.add(remote_runner_dst, arcname=REMOTE_RUNNER_FILE_NAME)
212-
if requirements_path and os.path.exists(requirements_path):
273+
if filtered_requirements is not None:
213274
tar.add(
214275
os.path.join(tmpdir, "requirements.txt"), arcname="requirements.txt"
215276
)
@@ -271,12 +332,14 @@ def _build_and_push(
271332
raise RuntimeError(f"Build failed with status: {result.status}")
272333

273334

274-
def _generate_dockerfile(base_image, requirements_path, category):
335+
def _generate_dockerfile(
336+
base_image: str, has_requirements: bool, category: str
337+
) -> str:
275338
"""Generate Dockerfile content based on configuration.
276339
277340
Args:
278341
base_image: Base Docker image
279-
requirements_path: Path to requirements.txt (or None)
342+
has_requirements: Whether filtered requirements content is available
280343
category: Accelerator category ('cpu', 'gpu', 'tpu')
281344
282345
Returns:
@@ -294,7 +357,7 @@ def _generate_dockerfile(base_image, requirements_path, category):
294357
jax_install = "RUN python3 -m pip install 'jax[cuda12]'"
295358

296359
requirements_section = ""
297-
if requirements_path and os.path.exists(requirements_path):
360+
if has_requirements:
298361
requirements_section = (
299362
"COPY requirements.txt /tmp/requirements.txt\n"
300363
"RUN python3 -m pip install -r /tmp/requirements.txt\n"
@@ -311,7 +374,9 @@ def _generate_dockerfile(base_image, requirements_path, category):
311374
)
312375

313376

314-
def _upload_build_source(tarball_path, bucket_name, project):
377+
def _upload_build_source(
378+
tarball_path: str, bucket_name: str, project: str
379+
) -> str:
315380
"""Upload build source tarball to Cloud Storage.
316381
317382
Args:

0 commit comments

Comments
 (0)