99from google .api_core import exceptions as google_exceptions
1010
1111from keras_remote .infra .container_builder import (
12+ _filter_jax_requirements ,
1213 _generate_dockerfile ,
1314 _hash_requirements ,
1415 _image_exists ,
@@ -23,6 +24,64 @@ def _make_temp_path(test_case):
2324 return pathlib .Path (td .name )
2425
2526
27+ class TestFilterJaxRequirements (parameterized .TestCase ):
28+ @parameterized .named_parameters (
29+ dict (testcase_name = "bare_jax" , line = "jax\n " ),
30+ dict (testcase_name = "jax_with_tpu_extras" , line = "jax[tpu]>=0.4.6\n " ),
31+ dict (testcase_name = "jax_cuda" , line = "jax[cuda12]==0.4.30\n " ),
32+ dict (testcase_name = "jax_cpu" , line = "jax[cpu]\n " ),
33+ dict (testcase_name = "jaxlib" , line = "jaxlib>=0.4.6\n " ),
34+ dict (testcase_name = "libtpu" , line = "libtpu\n " ),
35+ dict (testcase_name = "libtpu_nightly_hyphen" , line = "libtpu-nightly\n " ),
36+ dict (testcase_name = "libtpu_nightly_underscore" , line = "libtpu_nightly\n " ),
37+ dict (testcase_name = "jax_uppercase" , line = "JAX\n " ),
38+ dict (testcase_name = "jax_mixed_case" , line = "Jax[tpu]\n " ),
39+ )
40+ def test_filters_jax_packages (self , line ):
41+ self .assertEqual (_filter_jax_requirements (line ), "" )
42+
43+ @parameterized .named_parameters (
44+ dict (testcase_name = "numpy" , line = "numpy==1.26\n " ),
45+ dict (testcase_name = "keras" , line = "keras\n " ),
46+ dict (testcase_name = "scipy" , line = "scipy>=1.12\n " ),
47+ dict (testcase_name = "comment" , line = "# jax should be here\n " ),
48+ dict (testcase_name = "blank" , line = "\n " ),
49+ dict (testcase_name = "pip_flag" , line = "-e git+https://foo\n " ),
50+ dict (testcase_name = "index_url" , line = "--index-url https://pypi.org\n " ),
51+ )
52+ def test_preserves_non_jax_packages (self , line ):
53+ self .assertEqual (_filter_jax_requirements (line ), line )
54+
55+ @parameterized .named_parameters (
56+ dict (testcase_name = "jax_keep" , line = "jax==0.4.30 # kr:keep\n " ),
57+ dict (testcase_name = "jaxlib_keep" , line = "jaxlib # kr:keep\n " ),
58+ dict (testcase_name = "libtpu_keep" , line = "libtpu-nightly # kr:keep\n " ),
59+ )
60+ def test_kr_keep_overrides_filter (self , line ):
61+ self .assertEqual (_filter_jax_requirements (line ), line )
62+
63+ def test_mixed_requirements (self ):
64+ content = (
65+ "numpy==1.26\n jax[tpu]>=0.4.6\n scipy\n "
66+ "jaxlib\n keras\n jax==0.4.30 # kr:keep\n "
67+ )
68+ result = _filter_jax_requirements (content )
69+ self .assertEqual (
70+ result , "numpy==1.26\n scipy\n keras\n jax==0.4.30 # kr:keep\n "
71+ )
72+
73+ def test_empty_string (self ):
74+ self .assertEqual (_filter_jax_requirements ("" ), "" )
75+
76+ def test_only_jax_packages (self ):
77+ self .assertEqual (_filter_jax_requirements ("jax\n jaxlib\n libtpu\n " ), "" )
78+
79+ def test_preserves_comments_and_blanks (self ):
80+ content = "# ML deps\n numpy\n \n jax\n # end\n "
81+ result = _filter_jax_requirements (content )
82+ self .assertEqual (result , "# ML deps\n numpy\n \n # end\n " )
83+
84+
2685class TestHashRequirements (parameterized .TestCase ):
2786 def test_deterministic (self ):
2887 tmp_path = _make_temp_path (self )
@@ -81,6 +140,17 @@ def test_returns_hex_string(self):
81140 h = _hash_requirements (str (req ), "gpu" , "python:3.12-slim" )
82141 self .assertRegex (h , r"^[0-9a-f]{64}$" )
83142
143+ def test_jax_in_requirements_does_not_affect_hash (self ):
144+ tmp_path = _make_temp_path (self )
145+ req_without_jax = tmp_path / "r1.txt"
146+ req_without_jax .write_text ("numpy==1.26\n " )
147+ req_with_jax = tmp_path / "r2.txt"
148+ req_with_jax .write_text ("numpy==1.26\n jax[tpu]>=0.4.6\n " )
149+
150+ h1 = _hash_requirements (str (req_without_jax ), "tpu" , "python:3.12-slim" )
151+ h2 = _hash_requirements (str (req_with_jax ), "tpu" , "python:3.12-slim" )
152+ self .assertEqual (h1 , h2 )
153+
84154
85155class TestGenerateDockerfile (parameterized .TestCase ):
86156 @parameterized .named_parameters (
0 commit comments