Skip to content

Commit 814b7da

Browse files
Adds support for pyproject.toml dependencies
1 parent b57bf26 commit 814b7da

File tree

7 files changed

+260
-11
lines changed

7 files changed

+260
-11
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,19 @@ pillow
187187
scikit-learn
188188
```
189189

190+
Alternatively, dependencies declared in `pyproject.toml` are also supported:
191+
192+
```toml
193+
[project]
194+
dependencies = [
195+
"tensorflow-datasets",
196+
"pillow",
197+
"scikit-learn",
198+
]
199+
```
200+
190201
Keras Remote automatically detects and installs dependencies on the remote worker.
202+
If both files exist in the same directory, `requirements.txt` takes precedence.
191203

192204
### Prebuilt Container Images
193205

keras_remote/backend/execution.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class JobContext:
5959
# Artifact paths (set during prepare phase)
6060
payload_path: Optional[str] = None
6161
context_path: Optional[str] = None
62-
requirements_path: Optional[str] = None
62+
requirements_path: Optional[str] = None # requirements.txt or pyproject.toml
6363
image_uri: Optional[str] = None
6464

6565
def __post_init__(self):
@@ -204,12 +204,20 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
204204

205205

206206
def _find_requirements(start_dir: str) -> Optional[str]:
207-
"""Search up directory tree for requirements.txt."""
207+
"""Search up directory tree for requirements.txt or pyproject.toml.
208+
209+
At each directory level, ``requirements.txt`` is preferred over
210+
``pyproject.toml``. The first match found while walking towards the
211+
filesystem root is returned.
212+
"""
208213
search_dir = start_dir
209214
while search_dir != "/":
210215
req_path = os.path.join(search_dir, "requirements.txt")
211216
if os.path.exists(req_path):
212217
return req_path
218+
pyproject_path = os.path.join(search_dir, "pyproject.toml")
219+
if os.path.exists(pyproject_path):
220+
return pyproject_path
213221
parent_dir = os.path.dirname(search_dir)
214222
if parent_dir == search_dir:
215223
break
@@ -288,12 +296,12 @@ def _prepare_artifacts(
288296
)
289297
logging.info("Context packaged to %s", ctx.context_path)
290298

291-
# Find requirements.txt
299+
# Find requirements.txt or pyproject.toml
292300
ctx.requirements_path = _find_requirements(caller_path)
293301
if ctx.requirements_path:
294-
logging.info("Found requirements.txt: %s", ctx.requirements_path)
302+
logging.info("Found dependency file: %s", ctx.requirements_path)
295303
else:
296-
logging.info("No requirements.txt found")
304+
logging.info("No requirements.txt or pyproject.toml found")
297305

298306

299307
def _build_container(ctx: JobContext) -> None:

keras_remote/backend/execution_test.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,62 @@ def test_finds_in_parent_dir(self):
148148
)
149149

150150
def test_returns_none_when_not_found(self):
151-
"""Returns None when no requirements.txt exists in any ancestor."""
151+
"""Returns None when no requirements.txt or pyproject.toml exists."""
152152
tmp_path = _make_temp_path(self)
153153
empty = tmp_path / "empty"
154154
empty.mkdir()
155155
self.assertIsNone(_find_requirements(str(empty)))
156156

157+
def test_finds_pyproject_toml(self):
158+
"""Returns pyproject.toml path when no requirements.txt exists."""
159+
tmp_path = _make_temp_path(self)
160+
(tmp_path / "pyproject.toml").write_text(
161+
'[project]\ndependencies = ["numpy"]\n'
162+
)
163+
self.assertEqual(
164+
_find_requirements(str(tmp_path)),
165+
str(tmp_path / "pyproject.toml"),
166+
)
167+
168+
def test_requirements_txt_preferred_over_pyproject_toml(self):
169+
"""requirements.txt in the same directory wins over pyproject.toml."""
170+
tmp_path = _make_temp_path(self)
171+
(tmp_path / "requirements.txt").write_text("numpy\n")
172+
(tmp_path / "pyproject.toml").write_text(
173+
'[project]\ndependencies = ["scipy"]\n'
174+
)
175+
self.assertEqual(
176+
_find_requirements(str(tmp_path)),
177+
str(tmp_path / "requirements.txt"),
178+
)
179+
180+
def test_parent_pyproject_toml_found_from_child(self):
181+
"""Walks up to find pyproject.toml in parent when child has nothing."""
182+
tmp_path = _make_temp_path(self)
183+
(tmp_path / "pyproject.toml").write_text(
184+
'[project]\ndependencies = ["numpy"]\n'
185+
)
186+
child = tmp_path / "subdir"
187+
child.mkdir()
188+
self.assertEqual(
189+
_find_requirements(str(child)),
190+
str(tmp_path / "pyproject.toml"),
191+
)
192+
193+
def test_child_requirements_txt_beats_parent_pyproject_toml(self):
194+
"""requirements.txt in child dir is found before pyproject.toml in parent."""
195+
tmp_path = _make_temp_path(self)
196+
(tmp_path / "pyproject.toml").write_text(
197+
'[project]\ndependencies = ["scipy"]\n'
198+
)
199+
child = tmp_path / "subdir"
200+
child.mkdir()
201+
(child / "requirements.txt").write_text("numpy\n")
202+
self.assertEqual(
203+
_find_requirements(str(child)),
204+
str(child / "requirements.txt"),
205+
)
206+
157207

158208
class TestExecuteRemote(absltest.TestCase):
159209
def _make_func(self):

keras_remote/cli/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
"Aligning the cloud crystals",
3131
"Tip: Container images are content-hashed — unchanged deps skip rebuilds",
3232
"Feeding the hamsters",
33-
"Tip: Add a requirements.txt to auto-install deps on the remote pod",
33+
"Tip: Add a requirements.txt or pyproject.toml to auto-install deps on the remote pod",
3434
"Consulting the oracle",
3535
"Tip: Use --cluster to manage multiple clusters in the same project",
3636
"Calibrating the widgets",

keras_remote/infra/container_builder.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tarfile
99
import tempfile
1010
import time
11+
import tomllib
1112
import uuid
1213

1314
from absl import logging
@@ -82,6 +83,29 @@ def _filter_jax_requirements(requirements_content: str) -> str:
8283
return "".join(filtered_lines)
8384

8485

86+
def _parse_pyproject_dependencies(pyproject_path: str) -> str | None:
87+
"""Extract ``[project.dependencies]`` from a pyproject.toml file.
88+
89+
Reads only the core dependency list defined under the ``[project]`` table.
90+
Optional dependency groups (``[project.optional-dependencies]``) are ignored;
91+
users who need those should use a ``requirements.txt`` instead.
92+
93+
Args:
94+
pyproject_path: Absolute path to a ``pyproject.toml`` file.
95+
96+
Returns:
97+
Newline-separated dependency strings in PEP 508 format suitable for
98+
``pip install``, or ``None`` if the file declares no dependencies.
99+
"""
100+
with open(pyproject_path, "rb") as f:
101+
data = tomllib.load(f)
102+
103+
deps = data.get("project", {}).get("dependencies", [])
104+
if not deps:
105+
return None
106+
return "\n".join(deps) + "\n"
107+
108+
85109
def get_or_build_container(
86110
base_image: str,
87111
requirements_path: str | None,
@@ -92,11 +116,16 @@ def get_or_build_container(
92116
) -> str:
93117
"""Get existing container or build if requirements changed.
94118
95-
Uses content-based hashing to detect requirement changes.
119+
Uses content-based hashing to detect requirement changes. Dependencies can
120+
be supplied via a ``requirements.txt`` or a ``pyproject.toml`` (from which
121+
``[project.dependencies]`` are extracted).
96122
97123
Args:
98124
base_image: Base Docker image (e.g., 'python:3.12-slim')
99-
requirements_path: Path to requirements.txt (or None)
125+
requirements_path: Path to requirements.txt or pyproject.toml (or
126+
None). When a pyproject.toml is provided,
127+
``[project.dependencies]`` are extracted and used as the
128+
install list.
100129
accelerator_type: TPU/GPU type (e.g., 'v3-8')
101130
project: GCP project ID
102131
zone: GCP zone for region derivation (defaults to KERAS_REMOTE_ZONE)
@@ -112,8 +141,13 @@ def get_or_build_container(
112141
# Read and filter requirements once, reuse for hashing and building.
113142
filtered_requirements = None
114143
if requirements_path and os.path.exists(requirements_path):
115-
with open(requirements_path, "r") as f:
116-
filtered_requirements = _filter_jax_requirements(f.read())
144+
if requirements_path.endswith(".toml"):
145+
raw_requirements = _parse_pyproject_dependencies(requirements_path)
146+
else:
147+
with open(requirements_path, "r") as f:
148+
raw_requirements = f.read()
149+
if raw_requirements:
150+
filtered_requirements = _filter_jax_requirements(raw_requirements)
117151

118152
# Generate deterministic hash from requirements + base image + category
119153
requirements_hash = _hash_requirements(

keras_remote/infra/container_builder_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for keras_remote.infra.container_builder — hashing, Dockerfile gen, caching."""
22

3+
import os
4+
import tempfile
35
from unittest import mock
46
from unittest.mock import MagicMock
57

@@ -11,6 +13,7 @@
1113
_generate_dockerfile,
1214
_hash_requirements,
1315
_image_exists,
16+
_parse_pyproject_dependencies,
1417
get_or_build_container,
1518
)
1619

@@ -73,6 +76,44 @@ def test_preserves_comments_and_blanks(self):
7376
self.assertEqual(result, "# ML deps\nnumpy\n\n# end\n")
7477

7578

79+
class TestParsePyprojectDependencies(absltest.TestCase):
80+
def _write_toml(self, content):
81+
"""Write content to a temp pyproject.toml and return its path."""
82+
td = tempfile.TemporaryDirectory()
83+
self.addCleanup(td.cleanup)
84+
path = os.path.join(td.name, "pyproject.toml")
85+
with open(path, "w") as f:
86+
f.write(content)
87+
return path
88+
89+
def test_extracts_dependencies(self):
90+
path = self._write_toml(
91+
'[project]\ndependencies = ["numpy>=1.20", "pandas"]\n'
92+
)
93+
result = _parse_pyproject_dependencies(path)
94+
self.assertEqual(result, "numpy>=1.20\npandas\n")
95+
96+
def test_returns_none_when_no_dependencies(self):
97+
path = self._write_toml("[project]\nname = 'foo'\n")
98+
self.assertIsNone(_parse_pyproject_dependencies(path))
99+
100+
def test_returns_none_when_no_project_table(self):
101+
path = self._write_toml("[tool.ruff]\nline-length = 88\n")
102+
self.assertIsNone(_parse_pyproject_dependencies(path))
103+
104+
def test_returns_none_for_empty_dependencies(self):
105+
path = self._write_toml("[project]\ndependencies = []\n")
106+
self.assertIsNone(_parse_pyproject_dependencies(path))
107+
108+
def test_ignores_optional_dependencies(self):
109+
path = self._write_toml(
110+
'[project]\ndependencies = ["numpy"]\n\n'
111+
'[project.optional-dependencies]\ndev = ["pytest"]\n'
112+
)
113+
result = _parse_pyproject_dependencies(path)
114+
self.assertEqual(result, "numpy\n")
115+
116+
76117
class TestHashRequirements(parameterized.TestCase):
77118
def test_deterministic(self):
78119
h1 = _hash_requirements("numpy==1.26\n", "gpu", "python:3.12-slim")

tests/e2e/pyproject_deps_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""E2E tests for pyproject.toml dependency support.
2+
3+
These tests verify that dependencies declared in ``pyproject.toml`` are
4+
correctly extracted, used for container building, and available in the
5+
remote environment.
6+
7+
A temporary ``pyproject.toml`` is written to a temp directory. Only the
8+
discovery function (``_find_requirements``) is patched to return that path;
9+
the rest of the pipeline — parsing, JAX filtering, container building, and
10+
remote execution — runs for real.
11+
12+
Set E2E_TESTS=1 to enable.
13+
"""
14+
15+
import pathlib
16+
import tempfile
17+
from unittest import mock
18+
19+
from absl.testing import absltest
20+
21+
import keras_remote
22+
from tests.e2e.e2e_utils import skip_unless_e2e
23+
24+
25+
def _make_test_dir(test_case):
26+
"""Create a temp directory cleaned up after the test."""
27+
td = tempfile.TemporaryDirectory()
28+
test_case.addCleanup(td.cleanup)
29+
return pathlib.Path(td.name)
30+
31+
32+
@skip_unless_e2e()
33+
class TestPyprojectTomlDependencies(absltest.TestCase):
34+
"""Verify that [project.dependencies] from pyproject.toml are installed."""
35+
36+
def _create_pyproject(self, content):
37+
"""Write a pyproject.toml in a temp directory and return its path."""
38+
tmp = _make_test_dir(self)
39+
pyproject = tmp / "pyproject.toml"
40+
pyproject.write_text(content)
41+
return str(pyproject)
42+
43+
def test_dependency_installed_on_remote(self):
44+
"""A dependency from pyproject.toml is importable in the remote function."""
45+
path = self._create_pyproject(
46+
'[project]\nname = "test"\nversion = "0.1"\n'
47+
'dependencies = ["humanize>=4.0"]\n'
48+
)
49+
50+
@keras_remote.run(accelerator="cpu")
51+
def use_humanize():
52+
import humanize
53+
54+
return humanize.intcomma(1_000_000)
55+
56+
with mock.patch(
57+
"keras_remote.backend.execution._find_requirements",
58+
return_value=path,
59+
):
60+
result = use_humanize()
61+
62+
self.assertEqual(result, "1,000,000")
63+
64+
def test_pyproject_without_deps_succeeds(self):
65+
"""A pyproject.toml with no [project.dependencies] doesn't break the pipeline."""
66+
path = self._create_pyproject("[tool.ruff]\nline-length = 88\n")
67+
68+
@keras_remote.run(accelerator="cpu")
69+
def simple_add(a, b):
70+
return a + b
71+
72+
with mock.patch(
73+
"keras_remote.backend.execution._find_requirements",
74+
return_value=path,
75+
):
76+
result = simple_add(10, 20)
77+
78+
self.assertEqual(result, 30)
79+
80+
def test_jax_filtered_from_pyproject_deps(self):
81+
"""JAX packages in pyproject.toml are filtered like in requirements.txt."""
82+
path = self._create_pyproject(
83+
'[project]\nname = "test"\nversion = "0.1"\n'
84+
'dependencies = ["jax", "humanize>=4.0"]\n'
85+
)
86+
87+
@keras_remote.run(accelerator="cpu")
88+
def check_humanize():
89+
import humanize
90+
91+
return humanize.intcomma(2_500)
92+
93+
with mock.patch(
94+
"keras_remote.backend.execution._find_requirements",
95+
return_value=path,
96+
):
97+
result = check_humanize()
98+
99+
# humanize was installed (not filtered), jax was filtered silently
100+
self.assertEqual(result, "2,500")
101+
102+
103+
if __name__ == "__main__":
104+
absltest.main()

0 commit comments

Comments
 (0)