22
33import hashlib
44import os
5+ import re
56import shutil
67import string
78import tarfile
2829)
# Path to the sibling "runner" package (../runner relative to this module);
# remote_runner.py is copied from here into the container build context.
_RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner")
3031
# Distribution names whose installation is owned by the Dockerfile template.
# Matching entries in a user's requirements file are dropped so they cannot
# clobber the accelerator-specific JAX install (e.g. jax[tpu], jax[cuda12]).
_JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"})
_PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)")
_KEEP_MARKER = "# kr:keep"


def _filter_jax_requirements(requirements_content: str) -> str:
    """Drop JAX-related requirement lines from raw requirements text.

    The Dockerfile template installs JAX with the correct accelerator
    backend, so any user-specified jax/jaxlib/libtpu pin is removed here
    (with a warning) rather than being allowed to override it. Appending
    ``# kr:keep`` to a line exempts it from filtering.

    Args:
        requirements_content: Raw text of a requirements.txt file.

    Returns:
        The same text with JAX-related requirement lines removed; blank
        lines, comments, and pip option lines pass through untouched.
    """
    kept = []
    for raw in requirements_content.splitlines(keepends=True):
        bare = raw.strip()

        # Lines that can never name a package (blank, comment, pip flag
        # such as -e / --index-url) and lines explicitly marked to keep
        # are passed through without inspection.
        passthrough = (
            not bare
            or bare.startswith(("#", "-"))
            or _KEEP_MARKER in raw
        )

        if not passthrough:
            match = _PACKAGE_NAME_RE.match(bare)
            if match is not None:
                # PEP 503 normalization: lowercase, runs of [-_.] become '-'.
                canonical = re.sub(r"[-_.]+", "-", match.group(1)).lower()
                if canonical in _JAX_PACKAGE_NAMES:
                    logging.warning(
                        "Filtered '%s' from requirements — JAX is installed "
                        "automatically with the correct accelerator backend. "
                        "To override, add '# kr:keep' to the line.",
                        match.group(1),
                    )
                    continue

        kept.append(raw)

    return "".join(kept)
83+
3184
3285def get_or_build_container (
33- base_image ,
34- requirements_path ,
35- accelerator_type ,
36- project ,
37- zone = None ,
38- cluster_name = None ,
39- ):
86+ base_image : str ,
87+ requirements_path : str | None ,
88+ accelerator_type : str ,
89+ project : str ,
90+ zone : str | None = None ,
91+ cluster_name : str | None = None ,
92+ ) -> str :
4093 """Get existing container or build if requirements changed.
4194
4295 Uses content-based hashing to detect requirement changes.
@@ -56,9 +109,15 @@ def get_or_build_container(
56109 cluster_name = cluster_name or get_default_cluster_name ()
57110 category = accelerators .get_category (accelerator_type )
58111
112+ # Read and filter requirements once, reuse for hashing and building.
113+ filtered_requirements = None
114+ if requirements_path and os .path .exists (requirements_path ):
115+ with open (requirements_path , "r" ) as f :
116+ filtered_requirements = _filter_jax_requirements (f .read ())
117+
59118 # Generate deterministic hash from requirements + base image + category
60119 requirements_hash = _hash_requirements (
61- requirements_path , category , base_image
120+ filtered_requirements , category , base_image
62121 )
63122
64123 # Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
@@ -84,7 +143,7 @@ def get_or_build_container(
84143 logging .info ("Building new container (requirements changed): %s" , image_uri )
85144 return _build_and_push (
86145 base_image ,
87- requirements_path ,
146+ filtered_requirements ,
88147 category ,
89148 project ,
90149 image_uri ,
@@ -93,11 +152,13 @@ def get_or_build_container(
93152 )
94153
95154
96- def _hash_requirements (requirements_path , category , base_image ):
155+ def _hash_requirements (
156+ filtered_requirements : str | None , category : str , base_image : str
157+ ) -> str :
97158 """Create deterministic hash from requirements + category + remote_runner + base image.
98159
99160 Args:
100- requirements_path: Path to requirements.txt (or None)
161+ filtered_requirements: Pre-filtered requirements content (or None)
101162 category: Accelerator category ('cpu', 'gpu', 'tpu')
102163 base_image: Base Docker image (e.g., 'python:3.12-slim')
103164
@@ -106,9 +167,8 @@ def _hash_requirements(requirements_path, category, base_image):
106167 """
107168 content = f"base_image={ base_image } \n category={ category } \n "
108169
109- if requirements_path and os .path .exists (requirements_path ):
110- with open (requirements_path , "r" ) as f :
111- content += f .read ()
170+ if filtered_requirements :
171+ content += filtered_requirements
112172
113173 # Include remote_runner.py in the hash so container rebuilds when it changes
114174 remote_runner_path = os .path .join (_RUNNER_DIR , REMOTE_RUNNER_FILE_NAME )
@@ -125,7 +185,7 @@ def _hash_requirements(requirements_path, category, base_image):
125185 return hashlib .sha256 (content .encode ()).hexdigest ()
126186
127187
128- def _image_exists (image_uri , project ) :
188+ def _image_exists (image_uri : str , project : str ) -> bool :
129189 """Check if image exists in Artifact Registry.
130190
131191 Args:
@@ -162,19 +222,19 @@ def _image_exists(image_uri, project):
162222
163223
164224def _build_and_push (
165- base_image ,
166- requirements_path ,
167- category ,
168- project ,
169- image_uri ,
170- ar_location = "us" ,
171- cluster_name = None ,
172- ):
225+ base_image : str ,
226+ filtered_requirements : str | None ,
227+ category : str ,
228+ project : str ,
229+ image_uri : str ,
230+ ar_location : str = "us" ,
231+ cluster_name : str | None = None ,
232+ ) -> str :
173233 """Build and push Docker image using Cloud Build.
174234
175235 Args:
176236 base_image: Base Docker image
177- requirements_path: Path to requirements.txt (or None)
237+ filtered_requirements: Pre-filtered requirements content (or None)
178238 category: Accelerator category ('cpu', 'gpu', 'tpu')
179239 project: GCP project ID
180240 image_uri: Target image URI
@@ -187,17 +247,18 @@ def _build_and_push(
187247 # Generate Dockerfile
188248 dockerfile_content = _generate_dockerfile (
189249 base_image = base_image ,
190- requirements_path = requirements_path ,
250+ has_requirements = filtered_requirements is not None ,
191251 category = category ,
192252 )
193253
194254 dockerfile_path = os .path .join (tmpdir , "Dockerfile" )
195255 with open (dockerfile_path , "w" ) as f :
196256 f .write (dockerfile_content )
197257
198- # Copy requirements.txt if it exists
199- if requirements_path and os .path .exists (requirements_path ):
200- shutil .copy (requirements_path , os .path .join (tmpdir , "requirements.txt" ))
258+ # Write pre-filtered requirements.txt
259+ if filtered_requirements is not None :
260+ with open (os .path .join (tmpdir , "requirements.txt" ), "w" ) as f :
261+ f .write (filtered_requirements )
201262
202263 # Copy remote_runner.py
203264 remote_runner_src = os .path .join (_RUNNER_DIR , REMOTE_RUNNER_FILE_NAME )
@@ -209,7 +270,7 @@ def _build_and_push(
209270 with tarfile .open (tarball_path , "w:gz" ) as tar :
210271 tar .add (dockerfile_path , arcname = "Dockerfile" )
211272 tar .add (remote_runner_dst , arcname = REMOTE_RUNNER_FILE_NAME )
212- if requirements_path and os . path . exists ( requirements_path ) :
273+ if filtered_requirements is not None :
213274 tar .add (
214275 os .path .join (tmpdir , "requirements.txt" ), arcname = "requirements.txt"
215276 )
@@ -271,12 +332,14 @@ def _build_and_push(
271332 raise RuntimeError (f"Build failed with status: { result .status } " )
272333
273334
274- def _generate_dockerfile (base_image , requirements_path , category ):
335+ def _generate_dockerfile (
336+ base_image : str , has_requirements : bool , category : str
337+ ) -> str :
275338 """Generate Dockerfile content based on configuration.
276339
277340 Args:
278341 base_image: Base Docker image
279- requirements_path: Path to requirements.txt (or None)
342+ has_requirements: Whether filtered requirements content is available
280343 category: Accelerator category ('cpu', 'gpu', 'tpu')
281344
282345 Returns:
@@ -294,7 +357,7 @@ def _generate_dockerfile(base_image, requirements_path, category):
294357 jax_install = "RUN python3 -m pip install 'jax[cuda12]'"
295358
296359 requirements_section = ""
297- if requirements_path and os . path . exists ( requirements_path ) :
360+ if has_requirements :
298361 requirements_section = (
299362 "COPY requirements.txt /tmp/requirements.txt\n "
300363 "RUN python3 -m pip install -r /tmp/requirements.txt\n "
@@ -311,7 +374,9 @@ def _generate_dockerfile(base_image, requirements_path, category):
311374 )
312375
313376
314- def _upload_build_source (tarball_path , bucket_name , project ):
377+ def _upload_build_source (
378+ tarball_path : str , bucket_name : str , project : str
379+ ) -> str :
315380 """Upload build source tarball to Cloud Storage.
316381
317382 Args:
0 commit comments