Skip to content

Commit 5cfba9a

Browse files
Filters JAX packages from requirements.txt
1 parent 70bd83e commit 5cfba9a

File tree

2 files changed

+129
-3
lines changed

2 files changed

+129
-3
lines changed

keras_remote/infra/container_builder.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import hashlib
44
import os
5+
import re
56
import shutil
67
import string
78
import tarfile
@@ -28,6 +29,58 @@
2829
)
2930
_RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner")
3031

32+
# JAX-related packages managed by the Dockerfile template.
33+
# User requirements containing these are filtered out to prevent overriding
34+
# the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]).
35+
_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"})
36+
_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)")
37+
_KEEP_MARKER = "# kr:keep"
38+
39+
40+
def _filter_jax_requirements(requirements_content):
41+
"""Remove JAX-related packages from requirements content.
42+
43+
Strips lines that would override the accelerator-specific JAX installation
44+
managed by the Dockerfile template. Logs a warning for each filtered line.
45+
46+
To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt.
47+
48+
Args:
49+
requirements_content: Raw text of a requirements.txt file.
50+
51+
Returns:
52+
Filtered requirements text with JAX-related lines removed.
53+
"""
54+
filtered_lines = []
55+
for line in requirements_content.splitlines(keepends=True):
56+
stripped = line.strip()
57+
# Preserve blanks, comments, and pip flags (-e, --index-url, etc.)
58+
if not stripped or stripped.startswith("#") or stripped.startswith("-"):
59+
filtered_lines.append(line)
60+
continue
61+
62+
# Allow users to bypass the filter with an inline marker.
63+
if _KEEP_MARKER in line:
64+
filtered_lines.append(line)
65+
continue
66+
67+
m = _PACKAGE_NAME_RE.match(stripped)
68+
if m:
69+
# PEP 503 normalization: lowercase, collapse [-_.] to '-'
70+
normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower()
71+
if normalized in _JAX_PACKAGE_NAMES:
72+
logging.warning(
73+
"Filtered '%s' from requirements — JAX is installed "
74+
"automatically with the correct accelerator backend. "
75+
"To override, add '# kr:keep' to the line.",
76+
stripped,
77+
)
78+
continue
79+
80+
filtered_lines.append(line)
81+
82+
return "".join(filtered_lines)
83+
3184

3285
def get_or_build_container(
3386
base_image,
@@ -108,7 +161,7 @@ def _hash_requirements(requirements_path, category, base_image):
108161

109162
if requirements_path and os.path.exists(requirements_path):
110163
with open(requirements_path, "r") as f:
111-
content += f.read()
164+
content += _filter_jax_requirements(f.read())
112165

113166
# Include remote_runner.py in the hash so container rebuilds when it changes
114167
remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
@@ -195,9 +248,12 @@ def _build_and_push(
195248
with open(dockerfile_path, "w") as f:
196249
f.write(dockerfile_content)
197250

198-
# Copy requirements.txt if it exists
251+
# Copy requirements.txt (with JAX-related packages filtered out)
199252
if requirements_path and os.path.exists(requirements_path):
200-
shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt"))
253+
with open(requirements_path, "r") as f:
254+
filtered = _filter_jax_requirements(f.read())
255+
with open(os.path.join(tmpdir, "requirements.txt"), "w") as f:
256+
f.write(filtered)
201257

202258
# Copy remote_runner.py
203259
remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)

keras_remote/infra/container_builder_test.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from google.api_core import exceptions as google_exceptions
1010

1111
from keras_remote.infra.container_builder import (
12+
_filter_jax_requirements,
1213
_generate_dockerfile,
1314
_hash_requirements,
1415
_image_exists,
@@ -23,6 +24,64 @@ def _make_temp_path(test_case):
2324
return pathlib.Path(td.name)
2425

2526

27+
class TestFilterJaxRequirements(parameterized.TestCase):
    """Tests for ``_filter_jax_requirements``."""

    @parameterized.named_parameters(
        ("bare_jax", "jax\n"),
        ("jax_with_tpu_extras", "jax[tpu]>=0.4.6\n"),
        ("jax_cuda", "jax[cuda12]==0.4.30\n"),
        ("jax_cpu", "jax[cpu]\n"),
        ("jaxlib", "jaxlib>=0.4.6\n"),
        ("libtpu", "libtpu\n"),
        ("libtpu_nightly_hyphen", "libtpu-nightly\n"),
        ("libtpu_nightly_underscore", "libtpu_nightly\n"),
        ("jax_uppercase", "JAX\n"),
        ("jax_mixed_case", "Jax[tpu]\n"),
    )
    def test_filters_jax_packages(self, line):
        """A single JAX-related requirement line is removed entirely."""
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, "")

    @parameterized.named_parameters(
        ("numpy", "numpy==1.26\n"),
        ("keras", "keras\n"),
        ("scipy", "scipy>=1.12\n"),
        ("comment", "# jax should be here\n"),
        ("blank", "\n"),
        ("pip_flag", "-e git+https://foo\n"),
        ("index_url", "--index-url https://pypi.org\n"),
    )
    def test_preserves_non_jax_packages(self, line):
        """Non-JAX packages, comments, blanks and pip flags pass through."""
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, line)

    @parameterized.named_parameters(
        ("jax_keep", "jax==0.4.30 # kr:keep\n"),
        ("jaxlib_keep", "jaxlib # kr:keep\n"),
        ("libtpu_keep", "libtpu-nightly # kr:keep\n"),
    )
    def test_kr_keep_overrides_filter(self, line):
        """A JAX line carrying the keep marker is returned unchanged."""
        filtered = _filter_jax_requirements(line)
        self.assertEqual(filtered, line)

    def test_mixed_requirements(self):
        """Only unmarked JAX lines are dropped from a mixed file."""
        source_lines = [
            "numpy==1.26\n",
            "jax[tpu]>=0.4.6\n",
            "scipy\n",
            "jaxlib\n",
            "keras\n",
            "jax==0.4.30 # kr:keep\n",
        ]
        expected = "numpy==1.26\nscipy\nkeras\njax==0.4.30 # kr:keep\n"
        actual = _filter_jax_requirements("".join(source_lines))
        self.assertEqual(actual, expected)

    def test_empty_string(self):
        """Empty input yields empty output."""
        filtered = _filter_jax_requirements("")
        self.assertEqual(filtered, "")

    def test_only_jax_packages(self):
        """A file listing only JAX packages filters down to nothing."""
        only_jax = "jax\njaxlib\nlibtpu\n"
        self.assertEqual(_filter_jax_requirements(only_jax), "")

    def test_preserves_comments_and_blanks(self):
        """Comments and blank lines around a removed JAX line survive."""
        source = "# ML deps\nnumpy\n\njax\n# end\n"
        expected = "# ML deps\nnumpy\n\n# end\n"
        self.assertEqual(_filter_jax_requirements(source), expected)
84+
2685
class TestHashRequirements(parameterized.TestCase):
2786
def test_deterministic(self):
2887
tmp_path = _make_temp_path(self)
@@ -81,6 +140,17 @@ def test_returns_hex_string(self):
81140
h = _hash_requirements(str(req), "gpu", "python:3.12-slim")
82141
self.assertRegex(h, r"^[0-9a-f]{64}$")
83142

143+
def test_jax_in_requirements_does_not_affect_hash(self):
    """Two requirements files differing only by a JAX line hash equally."""
    tmp_path = _make_temp_path(self)
    file_contents = (
        ("r1.txt", "numpy==1.26\n"),
        ("r2.txt", "numpy==1.26\njax[tpu]>=0.4.6\n"),
    )

    digests = []
    for filename, text in file_contents:
        req_file = tmp_path / filename
        req_file.write_text(text)
        digests.append(
            _hash_requirements(str(req_file), "tpu", "python:3.12-slim")
        )

    hash_without_jax, hash_with_jax = digests
    self.assertEqual(hash_without_jax, hash_with_jax)
153+
84154

85155
class TestGenerateDockerfile(parameterized.TestCase):
86156
@parameterized.named_parameters(

0 commit comments

Comments
 (0)