Skip to content

Commit a88e71e

Browse files
Migrate serialization tests from pytest to absl testing
1 parent 36a011c commit a88e71e

File tree

3 files changed

+85
-39
lines changed

3 files changed

+85
-39
lines changed
Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,43 @@
11
"""Tests for keras_remote.utils.packager — zip and payload serialization."""
22

33
import os
4+
import pathlib
5+
import tempfile
46
import zipfile
57

68
import cloudpickle
79
import numpy as np
8-
import pytest
10+
from absl.testing import absltest
911

1012
from keras_remote.utils.packager import save_payload, zip_working_dir
1113

1214

13-
class TestZipWorkingDir:
15+
def _make_temp_path(test_case):
16+
"""Create a temp directory that is cleaned up after the test."""
17+
td = tempfile.TemporaryDirectory()
18+
test_case.addCleanup(td.cleanup)
19+
return pathlib.Path(td.name)
20+
21+
22+
class TestZipWorkingDir(absltest.TestCase):
1423
def _zip_and_list(self, src, tmp_path):
1524
"""Zip src directory and return the set of archive member names."""
1625
out = tmp_path / "context.zip"
1726
zip_working_dir(str(src), str(out))
1827
with zipfile.ZipFile(str(out)) as zf:
1928
return set(zf.namelist())
2029

21-
def test_contains_all_files(self, tmp_path):
30+
def test_contains_all_files(self):
31+
tmp_path = _make_temp_path(self)
2232
src = tmp_path / "src"
2333
src.mkdir()
2434
(src / "a.py").write_text("a")
2535
(src / "b.txt").write_text("b")
2636

27-
assert self._zip_and_list(src, tmp_path) == {"a.py", "b.txt"}
37+
self.assertEqual(self._zip_and_list(src, tmp_path), {"a.py", "b.txt"})
2838

29-
def test_excludes_git_directory(self, tmp_path):
39+
def test_excludes_git_directory(self):
40+
tmp_path = _make_temp_path(self)
3041
src = tmp_path / "src"
3142
src.mkdir()
3243
git_dir = src / ".git"
@@ -35,10 +46,11 @@ def test_excludes_git_directory(self, tmp_path):
3546
(src / "main.py").write_text("code")
3647

3748
names = self._zip_and_list(src, tmp_path)
38-
assert all(".git" not in n for n in names)
39-
assert "main.py" in names
49+
self.assertTrue(all(".git" not in n for n in names))
50+
self.assertIn("main.py", names)
4051

41-
def test_excludes_pycache_directory(self, tmp_path):
52+
def test_excludes_pycache_directory(self):
53+
tmp_path = _make_temp_path(self)
4254
src = tmp_path / "src"
4355
src.mkdir()
4456
cache_dir = src / "__pycache__"
@@ -47,28 +59,30 @@ def test_excludes_pycache_directory(self, tmp_path):
4759
(src / "mod.py").write_text("code")
4860

4961
names = self._zip_and_list(src, tmp_path)
50-
assert all("__pycache__" not in n for n in names)
51-
assert "mod.py" in names
62+
self.assertTrue(all("__pycache__" not in n for n in names))
63+
self.assertIn("mod.py", names)
5264

53-
def test_preserves_nested_structure(self, tmp_path):
65+
def test_preserves_nested_structure(self):
66+
tmp_path = _make_temp_path(self)
5467
src = tmp_path / "src"
5568
sub = src / "pkg" / "sub"
5669
sub.mkdir(parents=True)
5770
(sub / "deep.py").write_text("deep")
5871
(src / "top.py").write_text("top")
5972

6073
names = self._zip_and_list(src, tmp_path)
61-
assert "top.py" in names
62-
assert os.path.join("pkg", "sub", "deep.py") in names
74+
self.assertIn("top.py", names)
75+
self.assertIn(os.path.join("pkg", "sub", "deep.py"), names)
6376

64-
def test_empty_directory(self, tmp_path):
77+
def test_empty_directory(self):
78+
tmp_path = _make_temp_path(self)
6579
src = tmp_path / "empty"
6680
src.mkdir()
6781

68-
assert self._zip_and_list(src, tmp_path) == set()
82+
self.assertEqual(self._zip_and_list(src, tmp_path), set())
6983

7084

71-
class TestSavePayload:
85+
class TestSavePayload(absltest.TestCase):
7286
def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None):
7387
"""Save a payload and load it back, returning the deserialized dict."""
7488
if kwargs is None:
@@ -80,20 +94,24 @@ def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None):
8094
with open(str(out), "rb") as f:
8195
return cloudpickle.load(f)
8296

83-
def test_roundtrip_simple_function(self, tmp_path):
97+
def test_roundtrip_simple_function(self):
98+
tmp_path = _make_temp_path(self)
99+
84100
def add(a, b):
85101
return a + b
86102

87103
payload = self._save_and_load(
88104
tmp_path, add, args=(2, 3), env_vars={"KEY": "val"}
89105
)
90106

91-
assert payload["func"](2, 3) == 5
92-
assert payload["args"] == (2, 3)
93-
assert payload["kwargs"] == {}
94-
assert payload["env_vars"] == {"KEY": "val"}
107+
self.assertEqual(payload["func"](2, 3), 5)
108+
self.assertEqual(payload["args"], (2, 3))
109+
self.assertEqual(payload["kwargs"], {})
110+
self.assertEqual(payload["env_vars"], {"KEY": "val"})
111+
112+
def test_roundtrip_with_kwargs(self):
113+
tmp_path = _make_temp_path(self)
95114

96-
def test_roundtrip_with_kwargs(self, tmp_path):
97115
def greet(name, greeting="Hello"):
98116
return f"{greeting}, {name}"
99117

@@ -102,24 +120,28 @@ def greet(name, greeting="Hello"):
102120
)
103121

104122
result = payload["func"](*payload["args"], **payload["kwargs"])
105-
assert result == "Hi, World"
123+
self.assertEqual(result, "Hi, World")
106124

107-
def test_roundtrip_lambda(self, tmp_path):
125+
def test_roundtrip_lambda(self):
126+
tmp_path = _make_temp_path(self)
108127
payload = self._save_and_load(tmp_path, lambda x: x * 2, args=(5,))
109128

110-
assert payload["func"](*payload["args"]) == 10
129+
self.assertEqual(payload["func"](*payload["args"]), 10)
111130

112-
def test_roundtrip_closure(self, tmp_path):
131+
def test_roundtrip_closure(self):
132+
tmp_path = _make_temp_path(self)
113133
multiplier = 7
114134

115135
def make_closure(x):
116136
return x * multiplier
117137

118138
payload = self._save_and_load(tmp_path, make_closure, args=(6,))
119139

120-
assert payload["func"](*payload["args"]) == 42
140+
self.assertEqual(payload["func"](*payload["args"]), 42)
141+
142+
def test_roundtrip_numpy_args(self):
143+
tmp_path = _make_temp_path(self)
121144

122-
def test_roundtrip_numpy_args(self, tmp_path):
123145
def dot(a, b):
124146
return np.dot(a, b)
125147

@@ -129,9 +151,11 @@ def dot(a, b):
129151
payload = self._save_and_load(tmp_path, dot, args=(arr_a, arr_b))
130152

131153
result = payload["func"](*payload["args"])
132-
assert result == pytest.approx(32.0)
154+
self.assertAlmostEqual(result, 32.0)
155+
156+
def test_roundtrip_complex_args(self):
157+
tmp_path = _make_temp_path(self)
133158

134-
def test_roundtrip_complex_args(self, tmp_path):
135159
def identity(x):
136160
return x
137161

@@ -143,4 +167,8 @@ def identity(x):
143167

144168
payload = self._save_and_load(tmp_path, identity, args=(complex_arg,))
145169

146-
assert payload["func"](*payload["args"]) == complex_arg
170+
self.assertEqual(payload["func"](*payload["args"]), complex_arg)
171+
172+
173+
if __name__ == "__main__":
174+
absltest.main()

tests/integration/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
"""Integration tests for the full serialization roundtrip."""
22

3+
import pathlib
34
import sys
5+
import tempfile
46
import zipfile
57

8+
from absl.testing import absltest
9+
610
from keras_remote.utils.packager import zip_working_dir
711

812

9-
class TestZipAndExtract:
13+
def _make_temp_path(test_case):
14+
"""Create a temp directory that is cleaned up after the test."""
15+
td = tempfile.TemporaryDirectory()
16+
test_case.addCleanup(td.cleanup)
17+
return pathlib.Path(td.name)
18+
19+
20+
class TestZipAndExtract(absltest.TestCase):
1021
def _zip_and_extract(self, src, tmp_path):
1122
"""Zip a directory and extract it, returning the extraction path."""
1223
zip_path = str(tmp_path / "context.zip")
@@ -17,20 +28,24 @@ def _zip_and_extract(self, src, tmp_path):
1728
zf.extractall(str(extract_dir))
1829
return extract_dir
1930

20-
def test_zip_extract_preserves_files(self, tmp_path):
21-
"""Zip → extract roundtrip preserves file content."""
31+
def test_zip_extract_preserves_files(self):
32+
"""Zip -> extract roundtrip preserves file content."""
33+
tmp_path = _make_temp_path(self)
2234
src = tmp_path / "project"
2335
src.mkdir()
2436
(src / "main.py").write_text("x = 1")
2537
(src / "config.json").write_text('{"key": "val"}')
2638

2739
extract_dir = self._zip_and_extract(src, tmp_path)
2840

29-
assert (extract_dir / "main.py").read_text() == "x = 1"
30-
assert (extract_dir / "config.json").read_text() == '{"key": "val"}'
41+
self.assertEqual((extract_dir / "main.py").read_text(), "x = 1")
42+
self.assertEqual(
43+
(extract_dir / "config.json").read_text(), '{"key": "val"}'
44+
)
3145

32-
def test_zip_extract_enables_imports(self, tmp_path):
46+
def test_zip_extract_enables_imports(self):
3347
"""Extracted workspace can be added to sys.path for imports."""
48+
tmp_path = _make_temp_path(self)
3449
src = tmp_path / "project"
3550
src.mkdir()
3651
(src / "helper.py").write_text("def greet(name):\n return f'Hi {name}'")
@@ -41,8 +56,11 @@ def test_zip_extract_enables_imports(self, tmp_path):
4156
try:
4257
import helper
4358

44-
assert helper.greet("World") == "Hi World"
59+
self.assertEqual(helper.greet("World"), "Hi World")
4560
finally:
4661
sys.path.remove(str(extract_dir))
47-
# Clean up imported module
4862
sys.modules.pop("helper", None)
63+
64+
65+
if __name__ == "__main__":
66+
absltest.main()

0 commit comments

Comments
 (0)