Skip to content

Commit 3020cd5

Browse files
reduce redundancy
1 parent 8357645 commit 3020cd5

File tree

2 files changed

+45
-84
lines changed

2 files changed

+45
-84
lines changed

keras_remote/infra/container_builder.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,15 @@ def get_or_build_container(
109109
cluster_name = cluster_name or get_default_cluster_name()
110110
category = accelerators.get_category(accelerator_type)
111111

112+
# Read and filter requirements once, reuse for hashing and building.
113+
filtered_requirements = None
114+
if requirements_path and os.path.exists(requirements_path):
115+
with open(requirements_path, "r") as f:
116+
filtered_requirements = _filter_jax_requirements(f.read())
117+
112118
# Generate deterministic hash from requirements + base image + category
113119
requirements_hash = _hash_requirements(
114-
requirements_path, category, base_image
120+
filtered_requirements, category, base_image
115121
)
116122

117123
# Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
@@ -137,7 +143,7 @@ def get_or_build_container(
137143
logging.info("Building new container (requirements changed): %s", image_uri)
138144
return _build_and_push(
139145
base_image,
140-
requirements_path,
146+
filtered_requirements,
141147
category,
142148
project,
143149
image_uri,
@@ -147,12 +153,12 @@ def get_or_build_container(
147153

148154

149155
def _hash_requirements(
150-
requirements_path: str | None, category: str, base_image: str
156+
filtered_requirements: str | None, category: str, base_image: str
151157
) -> str:
152158
"""Create deterministic hash from requirements + category + remote_runner + base image.
153159
154160
Args:
155-
requirements_path: Path to requirements.txt (or None)
161+
filtered_requirements: Pre-filtered requirements content (or None)
156162
category: Accelerator category ('cpu', 'gpu', 'tpu')
157163
base_image: Base Docker image (e.g., 'python:3.12-slim')
158164
@@ -161,9 +167,8 @@ def _hash_requirements(
161167
"""
162168
content = f"base_image={base_image}\ncategory={category}\n"
163169

164-
if requirements_path and os.path.exists(requirements_path):
165-
with open(requirements_path, "r") as f:
166-
content += _filter_jax_requirements(f.read())
170+
if filtered_requirements:
171+
content += filtered_requirements
167172

168173
# Include remote_runner.py in the hash so container rebuilds when it changes
169174
remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
@@ -218,7 +223,7 @@ def _image_exists(image_uri: str, project: str) -> bool:
218223

219224
def _build_and_push(
220225
base_image: str,
221-
requirements_path: str | None,
226+
filtered_requirements: str | None,
222227
category: str,
223228
project: str,
224229
image_uri: str,
@@ -229,7 +234,7 @@ def _build_and_push(
229234
230235
Args:
231236
base_image: Base Docker image
232-
requirements_path: Path to requirements.txt (or None)
237+
filtered_requirements: Pre-filtered requirements content (or None)
233238
category: Accelerator category ('cpu', 'gpu', 'tpu')
234239
project: GCP project ID
235240
image_uri: Target image URI
@@ -242,20 +247,18 @@ def _build_and_push(
242247
# Generate Dockerfile
243248
dockerfile_content = _generate_dockerfile(
244249
base_image=base_image,
245-
requirements_path=requirements_path,
250+
has_requirements=filtered_requirements is not None,
246251
category=category,
247252
)
248253

249254
dockerfile_path = os.path.join(tmpdir, "Dockerfile")
250255
with open(dockerfile_path, "w") as f:
251256
f.write(dockerfile_content)
252257

253-
# Copy requirements.txt (with JAX-related packages filtered out)
254-
if requirements_path and os.path.exists(requirements_path):
255-
with open(requirements_path, "r") as f:
256-
filtered = _filter_jax_requirements(f.read())
258+
# Write pre-filtered requirements.txt
259+
if filtered_requirements is not None:
257260
with open(os.path.join(tmpdir, "requirements.txt"), "w") as f:
258-
f.write(filtered)
261+
f.write(filtered_requirements)
259262

260263
# Copy remote_runner.py
261264
remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME)
@@ -267,7 +270,7 @@ def _build_and_push(
267270
with tarfile.open(tarball_path, "w:gz") as tar:
268271
tar.add(dockerfile_path, arcname="Dockerfile")
269272
tar.add(remote_runner_dst, arcname=REMOTE_RUNNER_FILE_NAME)
270-
if requirements_path and os.path.exists(requirements_path):
273+
if filtered_requirements is not None:
271274
tar.add(
272275
os.path.join(tmpdir, "requirements.txt"), arcname="requirements.txt"
273276
)
@@ -330,13 +333,13 @@ def _build_and_push(
330333

331334

332335
def _generate_dockerfile(
333-
base_image: str, requirements_path: str | None, category: str
336+
base_image: str, has_requirements: bool, category: str
334337
) -> str:
335338
"""Generate Dockerfile content based on configuration.
336339
337340
Args:
338341
base_image: Base Docker image
339-
requirements_path: Path to requirements.txt (or None)
342+
has_requirements: Whether filtered requirements content is available
340343
category: Accelerator category ('cpu', 'gpu', 'tpu')
341344
342345
Returns:
@@ -354,7 +357,7 @@ def _generate_dockerfile(
354357
jax_install = "RUN python3 -m pip install 'jax[cuda12]'"
355358

356359
requirements_section = ""
357-
if requirements_path and os.path.exists(requirements_path):
360+
if has_requirements:
358361
requirements_section = (
359362
"COPY requirements.txt /tmp/requirements.txt\n"
360363
"RUN python3 -m pip install -r /tmp/requirements.txt\n"

keras_remote/infra/container_builder_test.py

Lines changed: 23 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""
22

3-
import pathlib
4-
import tempfile
53
from unittest import mock
64
from unittest.mock import MagicMock
75

@@ -17,13 +15,6 @@
1715
)
1816

1917

20-
def _make_temp_path(test_case):
21-
"""Create a temp directory that is cleaned up after the test."""
22-
td = tempfile.TemporaryDirectory()
23-
test_case.addCleanup(td.cleanup)
24-
return pathlib.Path(td.name)
25-
26-
2718
class TestFilterJaxRequirements(parameterized.TestCase):
2819
@parameterized.named_parameters(
2920
dict(testcase_name="bare_jax", line="jax\n"),
@@ -84,71 +75,42 @@ def test_preserves_comments_and_blanks(self):
8475

8576
class TestHashRequirements(parameterized.TestCase):
8677
def test_deterministic(self):
87-
tmp_path = _make_temp_path(self)
88-
req = tmp_path / "requirements.txt"
89-
req.write_text("numpy==1.26\n")
90-
91-
h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim")
92-
h2 = _hash_requirements(str(req), "gpu", "python:3.12-slim")
78+
h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim")
79+
h2 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim")
9380
self.assertEqual(h1, h2)
9481

9582
def test_different_requirements_different_hash(self):
96-
tmp_path = _make_temp_path(self)
97-
req1 = tmp_path / "r1.txt"
98-
req1.write_text("numpy==1.26\n")
99-
req2 = tmp_path / "r2.txt"
100-
req2.write_text("scipy==1.12\n")
101-
102-
h1 = _hash_requirements(str(req1), "gpu", "python:3.12-slim")
103-
h2 = _hash_requirements(str(req2), "gpu", "python:3.12-slim")
83+
h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim")
84+
h2 = _hash_requirements("scipy==1.12\n", "gpu", "python:3.12-slim")
10485
self.assertNotEqual(h1, h2)
10586

10687
def test_different_category_different_hash(self):
107-
tmp_path = _make_temp_path(self)
108-
req = tmp_path / "requirements.txt"
109-
req.write_text("numpy\n")
110-
111-
h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim")
112-
h2 = _hash_requirements(str(req), "tpu", "python:3.12-slim")
88+
h1 = _hash_requirements("numpy\n", "gpu", "python:3.12-slim")
89+
h2 = _hash_requirements("numpy\n", "tpu", "python:3.12-slim")
11390
self.assertNotEqual(h1, h2)
11491

11592
def test_different_base_image_different_hash(self):
116-
tmp_path = _make_temp_path(self)
117-
req = tmp_path / "requirements.txt"
118-
req.write_text("numpy\n")
119-
120-
h1 = _hash_requirements(str(req), "gpu", "python:3.12-slim")
121-
h2 = _hash_requirements(str(req), "gpu", "python:3.11-slim")
93+
h1 = _hash_requirements("numpy\n", "gpu", "python:3.12-slim")
94+
h2 = _hash_requirements("numpy\n", "gpu", "python:3.11-slim")
12295
self.assertNotEqual(h1, h2)
12396

124-
@parameterized.named_parameters(
125-
dict(testcase_name="none", requirements_path=None),
126-
dict(
127-
testcase_name="nonexistent",
128-
requirements_path="/nonexistent/path.txt",
129-
),
130-
)
131-
def test_missing_requirements_valid(self, requirements_path):
132-
h = _hash_requirements(requirements_path, "cpu", "python:3.12-slim")
97+
def test_missing_requirements_valid(self):
98+
h = _hash_requirements(None, "cpu", "python:3.12-slim")
13399
self.assertIsInstance(h, str)
134100
self.assertLen(h, 64)
135101

136102
def test_returns_hex_string(self):
137-
tmp_path = _make_temp_path(self)
138-
req = tmp_path / "r.txt"
139-
req.write_text("keras\n")
140-
h = _hash_requirements(str(req), "gpu", "python:3.12-slim")
103+
h = _hash_requirements("keras\n", "gpu", "python:3.12-slim")
141104
self.assertRegex(h, r"^[0-9a-f]{64}$")
142105

143106
def test_jax_in_requirements_does_not_affect_hash(self):
144-
tmp_path = _make_temp_path(self)
145-
req_without_jax = tmp_path / "r1.txt"
146-
req_without_jax.write_text("numpy==1.26\n")
147-
req_with_jax = tmp_path / "r2.txt"
148-
req_with_jax.write_text("numpy==1.26\njax[tpu]>=0.4.6\n")
149-
150-
h1 = _hash_requirements(str(req_without_jax), "tpu", "python:3.12-slim")
151-
h2 = _hash_requirements(str(req_with_jax), "tpu", "python:3.12-slim")
107+
filtered_without_jax = _filter_jax_requirements("numpy==1.26\n")
108+
filtered_with_jax = _filter_jax_requirements(
109+
"numpy==1.26\njax[tpu]>=0.4.6\n"
110+
)
111+
112+
h1 = _hash_requirements(filtered_without_jax, "tpu", "python:3.12-slim")
113+
h2 = _hash_requirements(filtered_with_jax, "tpu", "python:3.12-slim")
152114
self.assertEqual(h1, h2)
153115

154116

@@ -176,7 +138,7 @@ class TestGenerateDockerfile(parameterized.TestCase):
176138
def test_jax_install(self, category, expected, not_expected):
177139
content = _generate_dockerfile(
178140
base_image="python:3.12-slim",
179-
requirements_path=None,
141+
has_requirements=False,
180142
category=category,
181143
)
182144
for s in expected:
@@ -185,13 +147,9 @@ def test_jax_install(self, category, expected, not_expected):
185147
self.assertNotIn(s, content)
186148

187149
def test_with_requirements(self):
188-
tmp_path = _make_temp_path(self)
189-
req = tmp_path / "requirements.txt"
190-
req.write_text("numpy\n")
191-
192150
content = _generate_dockerfile(
193151
base_image="python:3.12-slim",
194-
requirements_path=str(req),
152+
has_requirements=True,
195153
category="cpu",
196154
)
197155
self.assertIn("COPY requirements.txt", content)
@@ -200,7 +158,7 @@ def test_with_requirements(self):
200158
def test_without_requirements(self):
201159
content = _generate_dockerfile(
202160
base_image="python:3.12-slim",
203-
requirements_path=None,
161+
has_requirements=False,
204162
category="cpu",
205163
)
206164
self.assertNotIn("COPY requirements.txt", content)
@@ -218,15 +176,15 @@ def test_without_requirements(self):
218176
def test_contains_expected_content(self, expected_substring):
219177
content = _generate_dockerfile(
220178
base_image="python:3.12-slim",
221-
requirements_path=None,
179+
has_requirements=False,
222180
category="cpu",
223181
)
224182
self.assertIn(expected_substring, content)
225183

226184
def test_uses_base_image(self):
227185
content = _generate_dockerfile(
228186
base_image="python:3.11-bullseye",
229-
requirements_path=None,
187+
has_requirements=False,
230188
category="cpu",
231189
)
232190
self.assertIn("FROM python:3.11-bullseye", content)

0 commit comments

Comments
 (0)