-
Notifications
You must be signed in to change notification settings - Fork 1
Filter JAX packages from user requirements to prevent accelerator backend override #80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import hashlib | ||
| import os | ||
| import re | ||
| import shutil | ||
| import string | ||
| import tarfile | ||
|
|
@@ -28,15 +29,67 @@ | |
| ) | ||
| _RUNNER_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "runner") | ||
|
|
||
| # JAX-related packages managed by the Dockerfile template. | ||
| # User requirements containing these are filtered out to prevent overriding | ||
| # the accelerator-specific JAX installation (e.g., jax[tpu], jax[cuda12]). | ||
| _JAX_PACKAGE_NAMES = frozenset({"jax", "jaxlib", "libtpu", "libtpu-nightly"}) | ||
| _PACKAGE_NAME_RE = re.compile(r"^([a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?)") | ||
| _KEEP_MARKER = "# kr:keep" | ||
|
|
||
|
|
||
| def _filter_jax_requirements(requirements_content: str) -> str: | ||
| """Remove JAX-related packages from requirements content. | ||
|
|
||
| Strips lines that would override the accelerator-specific JAX installation | ||
| managed by the Dockerfile template. Logs a warning for each filtered line. | ||
|
|
||
| To preserve a JAX line, append ``# kr:keep`` to it in requirements.txt. | ||
|
|
||
| Args: | ||
| requirements_content: Raw text of a requirements.txt file. | ||
|
|
||
| Returns: | ||
| Filtered requirements text with JAX-related lines removed. | ||
| """ | ||
| filtered_lines = [] | ||
| for line in requirements_content.splitlines(keepends=True): | ||
| stripped = line.strip() | ||
| # Preserve blanks, comments, and pip flags (-e, --index-url, etc.) | ||
| if not stripped or stripped.startswith("#") or stripped.startswith("-"): | ||
| filtered_lines.append(line) | ||
| continue | ||
|
|
||
| # Allow users to bypass the filter with an inline marker. | ||
| if _KEEP_MARKER in line: | ||
| filtered_lines.append(line) | ||
| continue | ||
|
|
||
| m = _PACKAGE_NAME_RE.match(stripped) | ||
| if m: | ||
| # PEP 503 normalization: lowercase, collapse [-_.] to '-' | ||
| normalized = re.sub(r"[-_.]+", "-", m.group(1)).lower() | ||
| if normalized in _JAX_PACKAGE_NAMES: | ||
| logging.warning( | ||
| "Filtered '%s' from requirements — JAX is installed " | ||
| "automatically with the correct accelerator backend. " | ||
| "To override, add '# kr:keep' to the line.", | ||
| stripped, | ||
| ) | ||
JyotinderSingh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| continue | ||
|
|
||
| filtered_lines.append(line) | ||
|
|
||
| return "".join(filtered_lines) | ||
|
|
||
|
|
||
| def get_or_build_container( | ||
| base_image, | ||
| requirements_path, | ||
| accelerator_type, | ||
| project, | ||
| zone=None, | ||
| cluster_name=None, | ||
| ): | ||
| base_image: str, | ||
| requirements_path: str | None, | ||
| accelerator_type: str, | ||
| project: str, | ||
| zone: str | None = None, | ||
| cluster_name: str | None = None, | ||
| ) -> str: | ||
| """Get existing container or build if requirements changed. | ||
|
|
||
| Uses content-based hashing to detect requirement changes. | ||
|
|
@@ -93,7 +146,9 @@ def get_or_build_container( | |
| ) | ||
|
|
||
|
|
||
| def _hash_requirements(requirements_path, category, base_image): | ||
| def _hash_requirements( | ||
| requirements_path: str | None, category: str, base_image: str | ||
| ) -> str: | ||
| """Create deterministic hash from requirements + category + remote_runner + base image. | ||
|
|
||
| Args: | ||
|
|
@@ -108,7 +163,7 @@ def _hash_requirements(requirements_path, category, base_image): | |
|
|
||
| if requirements_path and os.path.exists(requirements_path): | ||
| with open(requirements_path, "r") as f: | ||
| content += f.read() | ||
| content += _filter_jax_requirements(f.read()) | ||
|
|
||
| # Include remote_runner.py in the hash so container rebuilds when it changes | ||
| remote_runner_path = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME) | ||
|
|
@@ -125,7 +180,7 @@ def _hash_requirements(requirements_path, category, base_image): | |
| return hashlib.sha256(content.encode()).hexdigest() | ||
|
|
||
|
|
||
| def _image_exists(image_uri, project): | ||
| def _image_exists(image_uri: str, project: str) -> bool: | ||
| """Check if image exists in Artifact Registry. | ||
|
|
||
| Args: | ||
|
|
@@ -162,14 +217,14 @@ def _image_exists(image_uri, project): | |
|
|
||
|
|
||
| def _build_and_push( | ||
| base_image, | ||
| requirements_path, | ||
| category, | ||
| project, | ||
| image_uri, | ||
| ar_location="us", | ||
| cluster_name=None, | ||
| ): | ||
| base_image: str, | ||
| requirements_path: str | None, | ||
| category: str, | ||
| project: str, | ||
| image_uri: str, | ||
| ar_location: str = "us", | ||
| cluster_name: str | None = None, | ||
| ) -> str: | ||
| """Build and push Docker image using Cloud Build. | ||
|
|
||
| Args: | ||
|
|
@@ -195,9 +250,12 @@ def _build_and_push( | |
| with open(dockerfile_path, "w") as f: | ||
| f.write(dockerfile_content) | ||
|
|
||
| # Copy requirements.txt if it exists | ||
| # Copy requirements.txt (with JAX-related packages filtered out) | ||
| if requirements_path and os.path.exists(requirements_path): | ||
| shutil.copy(requirements_path, os.path.join(tmpdir, "requirements.txt")) | ||
| with open(requirements_path, "r") as f: | ||
| filtered = _filter_jax_requirements(f.read()) | ||
|
||
| with open(os.path.join(tmpdir, "requirements.txt"), "w") as f: | ||
| f.write(filtered) | ||
|
|
||
| # Copy remote_runner.py | ||
| remote_runner_src = os.path.join(_RUNNER_DIR, REMOTE_RUNNER_FILE_NAME) | ||
|
|
@@ -271,7 +329,9 @@ def _build_and_push( | |
| raise RuntimeError(f"Build failed with status: {result.status}") | ||
|
|
||
|
|
||
| def _generate_dockerfile(base_image, requirements_path, category): | ||
| def _generate_dockerfile( | ||
| base_image: str, requirements_path: str | None, category: str | ||
| ) -> str: | ||
| """Generate Dockerfile content based on configuration. | ||
|
|
||
| Args: | ||
|
|
@@ -311,7 +371,9 @@ def _generate_dockerfile(base_image, requirements_path, category): | |
| ) | ||
|
|
||
|
|
||
| def _upload_build_source(tarball_path, bucket_name, project): | ||
| def _upload_build_source( | ||
| tarball_path: str, bucket_name: str, project: str | ||
| ) -> str: | ||
| """Upload build source tarball to Cloud Storage. | ||
|
|
||
| Args: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will we also support requirements specified in pyproject.toml?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we should add support for that in subsequent PRs. Created #81 for tracking.