-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpyproject_deps_test.py
More file actions
104 lines (77 loc) · 2.96 KB
/
pyproject_deps_test.py
File metadata and controls
104 lines (77 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""E2E tests for pyproject.toml dependency support.
These tests verify that dependencies declared in ``pyproject.toml`` are
correctly extracted, used for container building, and available in the
remote environment.
A temporary ``pyproject.toml`` is written to a temp directory. Only the
discovery function (``_find_requirements``) is patched to return that path;
the rest of the pipeline — parsing, JAX filtering, container building, and
remote execution — runs for real.
Set E2E_TESTS=1 to enable.
"""
import pathlib
import tempfile
from unittest import mock
from absl.testing import absltest
import keras_remote
from tests.e2e.e2e_utils import skip_unless_e2e
def _make_test_dir(test_case):
    """Return a fresh temporary directory that is removed with the test.

    Args:
      test_case: Test instance whose ``addCleanup`` is used to register
        removal of the directory.

    Returns:
      A ``pathlib.Path`` pointing at the newly created directory.
    """
    scratch = tempfile.TemporaryDirectory()
    scratch_path = pathlib.Path(scratch.name)
    # Hand deletion over to the test case instead of relying on the caller.
    test_case.addCleanup(scratch.cleanup)
    return scratch_path
@skip_unless_e2e()
class TestPyprojectTomlDependencies(absltest.TestCase):
    """Verify that [project.dependencies] from pyproject.toml are installed.

    Every test writes a throwaway ``pyproject.toml``, patches
    ``_find_requirements`` so that it returns that file's path, and then
    calls a function decorated with ``keras_remote.run``.
    """

    # Dotted path of the discovery function replaced in every test.
    _FIND_REQUIREMENTS = "keras_remote.backend.execution._find_requirements"

    def _create_pyproject(self, content):
        """Write a pyproject.toml in a temp directory and return its path.

        Args:
          content: Full text of the TOML file to write.

        Returns:
          The path of the written file as a ``str``.
        """
        pyproject = _make_test_dir(self) / "pyproject.toml"
        # TOML documents must be UTF-8, so do not depend on the locale's
        # default encoding when writing the file.
        pyproject.write_text(content, encoding="utf-8")
        return str(pyproject)

    def _patch_find_requirements(self, path):
        """Return a patcher that makes ``_find_requirements`` yield ``path``.

        Args:
          path: Value the patched function should return.

        Returns:
          The ``mock.patch`` object, intended to be used as a context manager.
        """
        return mock.patch(self._FIND_REQUIREMENTS, return_value=path)

    def test_dependency_installed_on_remote(self):
        """A dependency from pyproject.toml is importable in the remote function."""
        path = self._create_pyproject(
            '[project]\nname = "test"\nversion = "0.1"\n'
            'dependencies = ["humanize>=4.0"]\n'
        )

        @keras_remote.run(accelerator="cpu")
        def use_humanize():
            # Imported inside the function body so the import happens where
            # the function runs.
            import humanize

            return humanize.intcomma(1_000_000)

        with self._patch_find_requirements(path):
            result = use_humanize()

        self.assertEqual(result, "1,000,000")

    def test_pyproject_without_deps_succeeds(self):
        """A pyproject.toml with no [project.dependencies] doesn't break the pipeline."""
        path = self._create_pyproject("[tool.ruff]\nline-length = 88\n")

        @keras_remote.run(accelerator="cpu")
        def simple_add(a, b):
            return a + b

        with self._patch_find_requirements(path):
            result = simple_add(10, 20)

        self.assertEqual(result, 30)

    def test_jax_filtered_from_pyproject_deps(self):
        """JAX packages in pyproject.toml are filtered like in requirements.txt."""
        path = self._create_pyproject(
            '[project]\nname = "test"\nversion = "0.1"\n'
            'dependencies = ["jax", "humanize>=4.0"]\n'
        )

        @keras_remote.run(accelerator="cpu")
        def check_humanize():
            import humanize

            return humanize.intcomma(2_500)

        with self._patch_find_requirements(path):
            result = check_humanize()

        # humanize was installed (not filtered), jax was filtered silently.
        # NOTE(review): the only assertion is on humanize's output, so this
        # shows the job succeeds with "jax" declared; it does not inspect the
        # remote environment for jax itself.
        self.assertEqual(result, "2,500")
if __name__ == "__main__":
    # Executed as a script: run this module's tests through absltest.
    absltest.main()