Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,19 @@ pillow
scikit-learn
```

Alternatively, dependencies can be declared in `pyproject.toml`:

```toml
[project]
dependencies = [
"tensorflow-datasets",
"pillow",
"scikit-learn",
]
```

Keras Remote automatically detects and installs dependencies on the remote worker.
If both files exist in the same directory, `requirements.txt` takes precedence.

### Prebuilt Container Images

Expand Down
18 changes: 13 additions & 5 deletions keras_remote/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class JobContext:
# Artifact paths (set during prepare phase)
payload_path: Optional[str] = None
context_path: Optional[str] = None
requirements_path: Optional[str] = None
requirements_path: Optional[str] = None # requirements.txt or pyproject.toml
image_uri: Optional[str] = None

def __post_init__(self):
Expand Down Expand Up @@ -204,12 +204,20 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:


def _find_requirements(start_dir: str) -> Optional[str]:
"""Search up directory tree for requirements.txt."""
"""Search up directory tree for requirements.txt or pyproject.toml.

At each directory level, ``requirements.txt`` is preferred over
``pyproject.toml``. The first match found while walking towards the
filesystem root is returned.
"""
search_dir = start_dir
while search_dir != "/":
req_path = os.path.join(search_dir, "requirements.txt")
if os.path.exists(req_path):
return req_path
pyproject_path = os.path.join(search_dir, "pyproject.toml")
if os.path.exists(pyproject_path):
return pyproject_path
parent_dir = os.path.dirname(search_dir)
if parent_dir == search_dir:
break
Expand Down Expand Up @@ -288,12 +296,12 @@ def _prepare_artifacts(
)
logging.info("Context packaged to %s", ctx.context_path)

# Find requirements.txt
# Find requirements.txt or pyproject.toml
ctx.requirements_path = _find_requirements(caller_path)
if ctx.requirements_path:
logging.info("Found requirements.txt: %s", ctx.requirements_path)
logging.info("Found dependency file: %s", ctx.requirements_path)
else:
logging.info("No requirements.txt found")
logging.info("No requirements.txt or pyproject.toml found")


def _build_container(ctx: JobContext) -> None:
Expand Down
52 changes: 51 additions & 1 deletion keras_remote/backend/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,62 @@ def test_finds_in_parent_dir(self):
)

def test_returns_none_when_not_found(self):
    """Returns None when no requirements.txt or pyproject.toml exists."""
    base = _make_temp_path(self)
    bare_dir = base / "empty"
    bare_dir.mkdir()
    # Neither dependency file exists anywhere on the walk to root.
    self.assertIsNone(_find_requirements(str(bare_dir)))

def test_finds_pyproject_toml(self):
    """Returns pyproject.toml path when no requirements.txt exists."""
    base = _make_temp_path(self)
    toml_file = base / "pyproject.toml"
    toml_file.write_text('[project]\ndependencies = ["numpy"]\n')
    self.assertEqual(_find_requirements(str(base)), str(toml_file))

def test_requirements_txt_preferred_over_pyproject_toml(self):
    """requirements.txt in the same directory wins over pyproject.toml."""
    base = _make_temp_path(self)
    req_file = base / "requirements.txt"
    req_file.write_text("numpy\n")
    (base / "pyproject.toml").write_text(
        '[project]\ndependencies = ["scipy"]\n'
    )
    # Both files are present; the txt file must be the one returned.
    self.assertEqual(_find_requirements(str(base)), str(req_file))

def test_parent_pyproject_toml_found_from_child(self):
    """Walks up to find pyproject.toml in parent when child has nothing."""
    parent = _make_temp_path(self)
    toml_file = parent / "pyproject.toml"
    toml_file.write_text('[project]\ndependencies = ["numpy"]\n')
    nested = parent / "subdir"
    nested.mkdir()
    # Search starts in the empty child and must climb to the parent.
    self.assertEqual(_find_requirements(str(nested)), str(toml_file))

def test_child_requirements_txt_beats_parent_pyproject_toml(self):
    """requirements.txt in child dir is found before pyproject.toml in parent."""
    parent = _make_temp_path(self)
    (parent / "pyproject.toml").write_text(
        '[project]\ndependencies = ["scipy"]\n'
    )
    nested = parent / "subdir"
    nested.mkdir()
    req_file = nested / "requirements.txt"
    req_file.write_text("numpy\n")
    # The nearer directory level is checked first, so the child file wins.
    self.assertEqual(_find_requirements(str(nested)), str(req_file))


class TestExecuteRemote(absltest.TestCase):
def _make_func(self):
Expand Down
2 changes: 1 addition & 1 deletion keras_remote/cli/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"Aligning the cloud crystals",
"Tip: Container images are content-hashed — unchanged deps skip rebuilds",
"Feeding the hamsters",
"Tip: Add a requirements.txt to auto-install deps on the remote pod",
"Tip: Add a requirements.txt or pyproject.toml to auto-install deps on the remote pod",
"Consulting the oracle",
"Tip: Use --cluster to manage multiple clusters in the same project",
"Calibrating the widgets",
Expand Down
42 changes: 38 additions & 4 deletions keras_remote/infra/container_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tarfile
import tempfile
import time
import tomllib
import uuid

from absl import logging
Expand Down Expand Up @@ -82,6 +83,29 @@ def _filter_jax_requirements(requirements_content: str) -> str:
return "".join(filtered_lines)


def _parse_pyproject_dependencies(pyproject_path: str) -> str | None:
"""Extract ``[project.dependencies]`` from a pyproject.toml file.

Reads only the core dependency list defined under the ``[project]`` table.
Optional dependency groups (``[project.optional-dependencies]``) are ignored;
users who need those should use a ``requirements.txt`` instead.

Args:
pyproject_path: Absolute path to a ``pyproject.toml`` file.

Returns:
Newline-separated dependency strings in PEP 508 format suitable for
``pip install``, or ``None`` if the file declares no dependencies.
"""
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)

deps = data.get("project", {}).get("dependencies", [])
if not deps:
return None
return "\n".join(deps) + "\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a well-implemented function. For better type consistency and to align with how empty requirements.txt files are handled (which are read as an empty string), consider changing the return type from str | None to just str. The function can return an empty string "" when no dependencies are found.

This would involve changing the function signature, implementation, and docstring. The corresponding unit tests would also need to be updated to check for "" instead of None on failure cases.

def _parse_pyproject_dependencies(pyproject_path: str) -> str:
  """Extract ``[project.dependencies]`` from a pyproject.toml file.

  Reads only the core dependency list defined under the ``[project]`` table.
  Optional dependency groups (``[project.optional-dependencies]``) are ignored;
  users who need those should use a ``requirements.txt`` instead.

  Args:
      pyproject_path: Absolute path to a ``pyproject.toml`` file.

  Returns:
      Newline-separated dependency strings in PEP 508 format suitable for
      ``pip install``, or an empty string if the file declares no dependencies.
  """
  with open(pyproject_path, "rb") as f:
    data = tomllib.load(f)

  deps = data.get("project", {}).get("dependencies", [])
  return "\n".join(deps) + "\n" if deps else ""



def get_or_build_container(
base_image: str,
requirements_path: str | None,
Expand All @@ -92,11 +116,16 @@ def get_or_build_container(
) -> str:
"""Get existing container or build if requirements changed.

Uses content-based hashing to detect requirement changes.
Uses content-based hashing to detect requirement changes. Dependencies can
be supplied via a ``requirements.txt`` or a ``pyproject.toml`` (from which
``[project.dependencies]`` are extracted).

Args:
base_image: Base Docker image (e.g., 'python:3.12-slim')
requirements_path: Path to requirements.txt (or None)
requirements_path: Path to requirements.txt or pyproject.toml (or
None). When a pyproject.toml is provided,
``[project.dependencies]`` are extracted and used as the
install list.
accelerator_type: TPU/GPU type (e.g., 'v3-8')
project: GCP project ID
zone: GCP zone for region derivation (defaults to KERAS_REMOTE_ZONE)
Expand All @@ -112,8 +141,13 @@ def get_or_build_container(
# Read and filter requirements once, reuse for hashing and building.
filtered_requirements = None
if requirements_path and os.path.exists(requirements_path):
with open(requirements_path, "r") as f:
filtered_requirements = _filter_jax_requirements(f.read())
if requirements_path.endswith(".toml"):
raw_requirements = _parse_pyproject_dependencies(requirements_path)
else:
with open(requirements_path, "r") as f:
raw_requirements = f.read()
if raw_requirements:
filtered_requirements = _filter_jax_requirements(raw_requirements)

# Generate deterministic hash from requirements + base image + category
requirements_hash = _hash_requirements(
Expand Down
41 changes: 41 additions & 0 deletions keras_remote/infra/container_builder_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""

import os
import tempfile
from unittest import mock
from unittest.mock import MagicMock

Expand All @@ -11,6 +13,7 @@
_generate_dockerfile,
_hash_requirements,
_image_exists,
_parse_pyproject_dependencies,
get_or_build_container,
)

Expand Down Expand Up @@ -73,6 +76,44 @@ def test_preserves_comments_and_blanks(self):
self.assertEqual(result, "# ML deps\nnumpy\n\n# end\n")


class TestParsePyprojectDependencies(absltest.TestCase):
    """Tests for extracting [project.dependencies] from pyproject.toml."""

    def _write_toml(self, content):
        """Write content to a temp pyproject.toml and return its path."""
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        toml_path = os.path.join(tmp_dir.name, "pyproject.toml")
        with open(toml_path, "w") as fh:
            fh.write(content)
        return toml_path

    def test_extracts_dependencies(self):
        toml_path = self._write_toml(
            '[project]\ndependencies = ["numpy>=1.20", "pandas"]\n'
        )
        self.assertEqual(
            _parse_pyproject_dependencies(toml_path), "numpy>=1.20\npandas\n"
        )

    def test_returns_none_when_no_dependencies(self):
        toml_path = self._write_toml("[project]\nname = 'foo'\n")
        self.assertIsNone(_parse_pyproject_dependencies(toml_path))

    def test_returns_none_when_no_project_table(self):
        toml_path = self._write_toml("[tool.ruff]\nline-length = 88\n")
        self.assertIsNone(_parse_pyproject_dependencies(toml_path))

    def test_returns_none_for_empty_dependencies(self):
        toml_path = self._write_toml("[project]\ndependencies = []\n")
        self.assertIsNone(_parse_pyproject_dependencies(toml_path))

    def test_ignores_optional_dependencies(self):
        toml_path = self._write_toml(
            '[project]\ndependencies = ["numpy"]\n\n'
            '[project.optional-dependencies]\ndev = ["pytest"]\n'
        )
        self.assertEqual(_parse_pyproject_dependencies(toml_path), "numpy\n")


class TestHashRequirements(parameterized.TestCase):
def test_deterministic(self):
h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim")
Expand Down
104 changes: 104 additions & 0 deletions tests/e2e/pyproject_deps_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""E2E tests for pyproject.toml dependency support.

These tests verify that dependencies declared in ``pyproject.toml`` are
correctly extracted, used for container building, and available in the
remote environment.

A temporary ``pyproject.toml`` is written to a temp directory. Only the
discovery function (``_find_requirements``) is patched to return that path;
the rest of the pipeline — parsing, JAX filtering, container building, and
remote execution — runs for real.

Set E2E_TESTS=1 to enable.
"""

import pathlib
import tempfile
from unittest import mock

from absl.testing import absltest

import keras_remote
from tests.e2e.e2e_utils import skip_unless_e2e


def _make_test_dir(test_case):
"""Create a temp directory cleaned up after the test."""
td = tempfile.TemporaryDirectory()
test_case.addCleanup(td.cleanup)
return pathlib.Path(td.name)


@skip_unless_e2e()
class TestPyprojectTomlDependencies(absltest.TestCase):
    """Verify that [project.dependencies] from pyproject.toml are installed."""

    def _create_pyproject(self, content):
        """Write a pyproject.toml in a temp directory and return its path."""
        toml_file = _make_test_dir(self) / "pyproject.toml"
        toml_file.write_text(content)
        return str(toml_file)

    def _patched_discovery(self, toml_path):
        """Patch dependency discovery to return the given pyproject.toml."""
        return mock.patch(
            "keras_remote.backend.execution._find_requirements",
            return_value=toml_path,
        )

    def test_dependency_installed_on_remote(self):
        """A dependency from pyproject.toml is importable in the remote function."""
        toml_path = self._create_pyproject(
            '[project]\nname = "test"\nversion = "0.1"\n'
            'dependencies = ["humanize>=4.0"]\n'
        )

        @keras_remote.run(accelerator="cpu")
        def use_humanize():
            import humanize

            return humanize.intcomma(1_000_000)

        with self._patched_discovery(toml_path):
            remote_result = use_humanize()

        self.assertEqual(remote_result, "1,000,000")

    def test_pyproject_without_deps_succeeds(self):
        """A pyproject.toml with no [project.dependencies] doesn't break the pipeline."""
        toml_path = self._create_pyproject("[tool.ruff]\nline-length = 88\n")

        @keras_remote.run(accelerator="cpu")
        def simple_add(a, b):
            return a + b

        with self._patched_discovery(toml_path):
            remote_result = simple_add(10, 20)

        self.assertEqual(remote_result, 30)

    def test_jax_filtered_from_pyproject_deps(self):
        """JAX packages in pyproject.toml are filtered like in requirements.txt."""
        toml_path = self._create_pyproject(
            '[project]\nname = "test"\nversion = "0.1"\n'
            'dependencies = ["jax", "humanize>=4.0"]\n'
        )

        @keras_remote.run(accelerator="cpu")
        def check_humanize():
            import humanize

            return humanize.intcomma(2_500)

        with self._patched_discovery(toml_path):
            remote_result = check_humanize()

        # humanize was installed (not filtered), jax was filtered silently
        self.assertEqual(remote_result, "2,500")


# Allow running this module directly as a standalone test binary.
if __name__ == "__main__":
    absltest.main()
Loading