-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_packager.py
More file actions
174 lines (129 loc) · 4.86 KB
/
test_packager.py
File metadata and controls
174 lines (129 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Tests for keras_remote.utils.packager — zip and payload serialization."""
import os
import pathlib
import tempfile
import zipfile
import cloudpickle
import numpy as np
from absl.testing import absltest
from keras_remote.utils.packager import save_payload, zip_working_dir
def _make_temp_path(test_case):
  """Return a fresh temporary directory as a ``pathlib.Path``.

  Deletion of the directory is registered via ``test_case.addCleanup`` so it
  is removed automatically once the test finishes.

  Args:
    test_case: Object exposing ``addCleanup`` (e.g. an ``absltest.TestCase``).

  Returns:
    A ``pathlib.Path`` pointing at the newly created directory.
  """
  holder = tempfile.TemporaryDirectory()
  remove_dir = holder.cleanup
  test_case.addCleanup(remove_dir)
  root = pathlib.Path(holder.name)
  return root
class TestZipWorkingDir(absltest.TestCase):
  """Tests for ``zip_working_dir``: archived contents, exclusions, layout."""

  def _zip_and_list(self, src, tmp_path):
    """Zip ``src`` into ``tmp_path/context.zip`` and return its member names.

    Args:
      src: Directory to archive.
      tmp_path: Directory in which ``context.zip`` is written.

    Returns:
      Set of member names as reported by ``ZipFile.namelist()``.
    """
    out = tmp_path / "context.zip"
    zip_working_dir(str(src), str(out))
    with zipfile.ZipFile(str(out)) as zf:
      return set(zf.namelist())

  def test_contains_all_files(self):
    tmp_path = _make_temp_path(self)
    src = tmp_path / "src"
    src.mkdir()
    (src / "a.py").write_text("a")
    (src / "b.txt").write_text("b")
    self.assertEqual(self._zip_and_list(src, tmp_path), {"a.py", "b.txt"})

  def test_excludes_git_directory(self):
    tmp_path = _make_temp_path(self)
    src = tmp_path / "src"
    src.mkdir()
    git_dir = src / ".git"
    git_dir.mkdir()
    (git_dir / "config").write_text("git config")
    (src / "main.py").write_text("code")
    names = self._zip_and_list(src, tmp_path)
    # Substring check: no member may live under (or be named) ".git".
    self.assertTrue(all(".git" not in n for n in names))
    self.assertIn("main.py", names)

  def test_excludes_pycache_directory(self):
    tmp_path = _make_temp_path(self)
    src = tmp_path / "src"
    src.mkdir()
    cache_dir = src / "__pycache__"
    cache_dir.mkdir()
    (cache_dir / "mod.cpython-312.pyc").write_bytes(b"\x00")
    (src / "mod.py").write_text("code")
    names = self._zip_and_list(src, tmp_path)
    self.assertTrue(all("__pycache__" not in n for n in names))
    self.assertIn("mod.py", names)

  def test_preserves_nested_structure(self):
    tmp_path = _make_temp_path(self)
    src = tmp_path / "src"
    sub = src / "pkg" / "sub"
    sub.mkdir(parents=True)
    (sub / "deep.py").write_text("deep")
    (src / "top.py").write_text("top")
    names = self._zip_and_list(src, tmp_path)
    self.assertIn("top.py", names)
    # ZIP member names always use "/" as the separator regardless of the host
    # OS (zipfile rewrites os.sep to "/"), so compare against a POSIX-style
    # name; os.path.join would produce "pkg\\sub\\deep.py" on Windows.
    self.assertIn("pkg/sub/deep.py", names)

  def test_empty_directory(self):
    tmp_path = _make_temp_path(self)
    src = tmp_path / "empty"
    src.mkdir()
    self.assertEqual(self._zip_and_list(src, tmp_path), set())
class TestSavePayload(absltest.TestCase):
  """Round-trip tests: ``save_payload`` output read back with cloudpickle."""

  def _save_and_load(
      self, tmp_path, func, args=(), kwargs=None, env_vars=None
  ):
    """Persist a payload via ``save_payload`` and unpickle it again.

    Args:
      tmp_path: Directory in which ``payload.pkl`` is written.
      func: Callable to serialize.
      args: Positional arguments stored alongside ``func``.
      kwargs: Keyword arguments stored alongside ``func``; ``None`` -> ``{}``.
      env_vars: Environment variables stored in the payload; ``None`` ->
        ``{}``.

    Returns:
      Whatever ``cloudpickle.load`` reads back from the written file.
    """
    call_kwargs = {} if kwargs is None else kwargs
    environment = {} if env_vars is None else env_vars
    payload_file = str(tmp_path / "payload.pkl")
    save_payload(func, args, call_kwargs, environment, payload_file)
    with open(payload_file, "rb") as handle:
      loaded = cloudpickle.load(handle)
    return loaded

  def test_roundtrip_simple_function(self):
    root = _make_temp_path(self)

    def summed(left, right):
      return left + right

    loaded = self._save_and_load(
        root, summed, args=(2, 3), env_vars={"KEY": "val"}
    )
    self.assertEqual(loaded["func"](2, 3), 5)
    expectations = (
        ("args", (2, 3)),
        ("kwargs", {}),
        ("env_vars", {"KEY": "val"}),
    )
    for field, expected in expectations:
      self.assertEqual(loaded[field], expected)

  def test_roundtrip_with_kwargs(self):
    root = _make_temp_path(self)

    def salute(name, greeting="Hello"):
      return f"{greeting}, {name}"

    loaded = self._save_and_load(
        root, salute, args=("World",), kwargs={"greeting": "Hi"}
    )
    fn, pos, named = loaded["func"], loaded["args"], loaded["kwargs"]
    self.assertEqual(fn(*pos, **named), "Hi, World")

  def test_roundtrip_lambda(self):
    root = _make_temp_path(self)
    doubler = lambda value: value * 2  # noqa: E731
    loaded = self._save_and_load(root, doubler, args=(5,))
    self.assertEqual(loaded["func"](*loaded["args"]), 10)

  def test_roundtrip_closure(self):
    root = _make_temp_path(self)
    factor = 7

    def scaled(value):
      # Captures ``factor`` from the enclosing scope.
      return value * factor

    loaded = self._save_and_load(root, scaled, args=(6,))
    self.assertEqual(loaded["func"](*loaded["args"]), 42)

  def test_roundtrip_numpy_args(self):
    root = _make_temp_path(self)

    def inner_product(u, v):
      return np.dot(u, v)

    first = np.array([1.0, 2.0, 3.0])
    second = np.array([4.0, 5.0, 6.0])
    loaded = self._save_and_load(root, inner_product, args=(first, second))
    self.assertAlmostEqual(loaded["func"](*loaded["args"]), 32.0)

  def test_roundtrip_complex_args(self):
    root = _make_temp_path(self)

    def passthrough(obj):
      return obj

    structure = {
        "key": [1, 2, 3],
        "nested": {"a": True, "b": None},
        "tuple": (1, "two", 3.0),
    }
    loaded = self._save_and_load(root, passthrough, args=(structure,))
    echoed = loaded["func"](*loaded["args"])
    self.assertEqual(echoed, structure)
if __name__ == "__main__":
  # Allow running this module directly; absltest handles flags and discovery.
  absltest.main()