11"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""
22
3- import pathlib
4- import tempfile
53from unittest import mock
64from unittest .mock import MagicMock
75
1715)
1816
1917
20- def _make_temp_path (test_case ):
21- """Create a temp directory that is cleaned up after the test."""
22- td = tempfile .TemporaryDirectory ()
23- test_case .addCleanup (td .cleanup )
24- return pathlib .Path (td .name )
25-
26-
2718class TestFilterJaxRequirements (parameterized .TestCase ):
2819 @parameterized .named_parameters (
2920 dict (testcase_name = "bare_jax" , line = "jax\n " ),
@@ -84,71 +75,42 @@ def test_preserves_comments_and_blanks(self):
8475
8576class TestHashRequirements (parameterized .TestCase ):
8677 def test_deterministic (self ):
87- tmp_path = _make_temp_path (self )
88- req = tmp_path / "requirements.txt"
89- req .write_text ("numpy==1.26\n " )
90-
91- h1 = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
92- h2 = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
78+ h1 = _hash_requirements ("numpy==1.26\n " , "gpu" , "python:3.12-slim" )
79+ h2 = _hash_requirements ("numpy==1.26\n " , "gpu" , "python:3.12-slim" )
9380 self .assertEqual (h1 , h2 )
9481
9582 def test_different_requirements_different_hash (self ):
96- tmp_path = _make_temp_path (self )
97- req1 = tmp_path / "r1.txt"
98- req1 .write_text ("numpy==1.26\n " )
99- req2 = tmp_path / "r2.txt"
100- req2 .write_text ("scipy==1.12\n " )
101-
102- h1 = _hash_requirements (str (req1 ), "gpu" , "python:3.12-slim" )
103- h2 = _hash_requirements (str (req2 ), "gpu" , "python:3.12-slim" )
83+ h1 = _hash_requirements ("numpy==1.26\n " , "gpu" , "python:3.12-slim" )
84+ h2 = _hash_requirements ("scipy==1.12\n " , "gpu" , "python:3.12-slim" )
10485 self .assertNotEqual (h1 , h2 )
10586
10687 def test_different_category_different_hash (self ):
107- tmp_path = _make_temp_path (self )
108- req = tmp_path / "requirements.txt"
109- req .write_text ("numpy\n " )
110-
111- h1 = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
112- h2 = _hash_requirements (str (req ), "tpu" , "python:3.12-slim" )
88+ h1 = _hash_requirements ("numpy\n " , "gpu" , "python:3.12-slim" )
89+ h2 = _hash_requirements ("numpy\n " , "tpu" , "python:3.12-slim" )
11390 self .assertNotEqual (h1 , h2 )
11491
11592 def test_different_base_image_different_hash (self ):
116- tmp_path = _make_temp_path (self )
117- req = tmp_path / "requirements.txt"
118- req .write_text ("numpy\n " )
119-
120- h1 = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
121- h2 = _hash_requirements (str (req ), "gpu" , "python:3.11-slim" )
93+ h1 = _hash_requirements ("numpy\n " , "gpu" , "python:3.12-slim" )
94+ h2 = _hash_requirements ("numpy\n " , "gpu" , "python:3.11-slim" )
12295 self .assertNotEqual (h1 , h2 )
12396
124- @parameterized .named_parameters (
125- dict (testcase_name = "none" , requirements_path = None ),
126- dict (
127- testcase_name = "nonexistent" ,
128- requirements_path = "/nonexistent/path.txt" ,
129- ),
130- )
131- def test_missing_requirements_valid (self , requirements_path ):
132- h = _hash_requirements (requirements_path , "cpu" , "python:3.12-slim" )
97+ def test_missing_requirements_valid (self ):
98+ h = _hash_requirements (None , "cpu" , "python:3.12-slim" )
13399 self .assertIsInstance (h , str )
134100 self .assertLen (h , 64 )
135101
136102 def test_returns_hex_string (self ):
137- tmp_path = _make_temp_path (self )
138- req = tmp_path / "r.txt"
139- req .write_text ("keras\n " )
140- h = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
103+ h = _hash_requirements ("keras\n " , "gpu" , "python:3.12-slim" )
141104 self .assertRegex (h , r"^[0-9a-f]{64}$" )
142105
143106 def test_jax_in_requirements_does_not_affect_hash (self ):
144- tmp_path = _make_temp_path (self )
145- req_without_jax = tmp_path / "r1.txt"
146- req_without_jax .write_text ("numpy==1.26\n " )
147- req_with_jax = tmp_path / "r2.txt"
148- req_with_jax .write_text ("numpy==1.26\n jax[tpu]>=0.4.6\n " )
149-
150- h1 = _hash_requirements (str (req_without_jax ), "tpu" , "python:3.12-slim" )
151- h2 = _hash_requirements (str (req_with_jax ), "tpu" , "python:3.12-slim" )
107+ filtered_without_jax = _filter_jax_requirements ("numpy==1.26\n " )
108+ filtered_with_jax = _filter_jax_requirements (
109+ "numpy==1.26\n jax[tpu]>=0.4.6\n "
110+ )
111+
112+ h1 = _hash_requirements (filtered_without_jax , "tpu" , "python:3.12-slim" )
113+ h2 = _hash_requirements (filtered_with_jax , "tpu" , "python:3.12-slim" )
152114 self .assertEqual (h1 , h2 )
153115
154116
@@ -176,7 +138,7 @@ class TestGenerateDockerfile(parameterized.TestCase):
176138 def test_jax_install (self , category , expected , not_expected ):
177139 content = _generate_dockerfile (
178140 base_image = "python:3.12-slim" ,
179- requirements_path = None ,
141+ has_requirements = False ,
180142 category = category ,
181143 )
182144 for s in expected :
@@ -185,13 +147,9 @@ def test_jax_install(self, category, expected, not_expected):
185147 self .assertNotIn (s , content )
186148
187149 def test_with_requirements (self ):
188- tmp_path = _make_temp_path (self )
189- req = tmp_path / "requirements.txt"
190- req .write_text ("numpy\n " )
191-
192150 content = _generate_dockerfile (
193151 base_image = "python:3.12-slim" ,
194- requirements_path = str ( req ) ,
152+ has_requirements = True ,
195153 category = "cpu" ,
196154 )
197155 self .assertIn ("COPY requirements.txt" , content )
@@ -200,7 +158,7 @@ def test_with_requirements(self):
200158 def test_without_requirements (self ):
201159 content = _generate_dockerfile (
202160 base_image = "python:3.12-slim" ,
203- requirements_path = None ,
161+ has_requirements = False ,
204162 category = "cpu" ,
205163 )
206164 self .assertNotIn ("COPY requirements.txt" , content )
@@ -218,15 +176,15 @@ def test_without_requirements(self):
218176 def test_contains_expected_content (self , expected_substring ):
219177 content = _generate_dockerfile (
220178 base_image = "python:3.12-slim" ,
221- requirements_path = None ,
179+ has_requirements = False ,
222180 category = "cpu" ,
223181 )
224182 self .assertIn (expected_substring , content )
225183
226184 def test_uses_base_image (self ):
227185 content = _generate_dockerfile (
228186 base_image = "python:3.11-bullseye" ,
229- requirements_path = None ,
187+ has_requirements = False ,
230188 category = "cpu" ,
231189 )
232190 self .assertIn ("FROM python:3.11-bullseye" , content )
0 commit comments