Skip to content

Commit b18cc9d

Browse files
Adds serialization and packaging tests
1 parent 67477a8 commit b18cc9d

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Tests for keras_remote.utils.packager — zip and payload serialization."""
2+
3+
import os
4+
import zipfile
5+
6+
import cloudpickle
7+
import numpy as np
8+
import pytest
9+
10+
from keras_remote.utils.packager import save_payload, zip_working_dir
11+
12+
13+
class TestZipWorkingDir:
14+
def _zip_and_list(self, src, tmp_path):
15+
"""Zip src directory and return the set of archive member names."""
16+
out = tmp_path / "context.zip"
17+
zip_working_dir(str(src), str(out))
18+
with zipfile.ZipFile(str(out)) as zf:
19+
return set(zf.namelist())
20+
21+
def test_contains_all_files(self, tmp_path):
22+
src = tmp_path / "src"
23+
src.mkdir()
24+
(src / "a.py").write_text("a")
25+
(src / "b.txt").write_text("b")
26+
27+
assert self._zip_and_list(src, tmp_path) == {"a.py", "b.txt"}
28+
29+
def test_excludes_git_directory(self, tmp_path):
30+
src = tmp_path / "src"
31+
src.mkdir()
32+
git_dir = src / ".git"
33+
git_dir.mkdir()
34+
(git_dir / "config").write_text("git config")
35+
(src / "main.py").write_text("code")
36+
37+
names = self._zip_and_list(src, tmp_path)
38+
assert all(".git" not in n for n in names)
39+
assert "main.py" in names
40+
41+
def test_excludes_pycache_directory(self, tmp_path):
42+
src = tmp_path / "src"
43+
src.mkdir()
44+
cache_dir = src / "__pycache__"
45+
cache_dir.mkdir()
46+
(cache_dir / "mod.cpython-312.pyc").write_bytes(b"\x00")
47+
(src / "mod.py").write_text("code")
48+
49+
names = self._zip_and_list(src, tmp_path)
50+
assert all("__pycache__" not in n for n in names)
51+
assert "mod.py" in names
52+
53+
def test_preserves_nested_structure(self, tmp_path):
54+
src = tmp_path / "src"
55+
sub = src / "pkg" / "sub"
56+
sub.mkdir(parents=True)
57+
(sub / "deep.py").write_text("deep")
58+
(src / "top.py").write_text("top")
59+
60+
names = self._zip_and_list(src, tmp_path)
61+
assert "top.py" in names
62+
assert os.path.join("pkg", "sub", "deep.py") in names
63+
64+
def test_empty_directory(self, tmp_path):
65+
src = tmp_path / "empty"
66+
src.mkdir()
67+
68+
assert self._zip_and_list(src, tmp_path) == set()
69+
70+
71+
class TestSavePayload:
72+
def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None):
73+
"""Save a payload and load it back, returning the deserialized dict."""
74+
if kwargs is None:
75+
kwargs = {}
76+
if env_vars is None:
77+
env_vars = {}
78+
out = tmp_path / "payload.pkl"
79+
save_payload(func, args, kwargs, env_vars, str(out))
80+
with open(str(out), "rb") as f:
81+
return cloudpickle.load(f)
82+
83+
def test_roundtrip_simple_function(self, tmp_path):
84+
def add(a, b):
85+
return a + b
86+
87+
payload = self._save_and_load(
88+
tmp_path, add, args=(2, 3), env_vars={"KEY": "val"}
89+
)
90+
91+
assert payload["func"](2, 3) == 5
92+
assert payload["args"] == (2, 3)
93+
assert payload["kwargs"] == {}
94+
assert payload["env_vars"] == {"KEY": "val"}
95+
96+
def test_roundtrip_with_kwargs(self, tmp_path):
97+
def greet(name, greeting="Hello"):
98+
return f"{greeting}, {name}"
99+
100+
payload = self._save_and_load(
101+
tmp_path, greet, args=("World",), kwargs={"greeting": "Hi"}
102+
)
103+
104+
result = payload["func"](*payload["args"], **payload["kwargs"])
105+
assert result == "Hi, World"
106+
107+
def test_roundtrip_lambda(self, tmp_path):
108+
payload = self._save_and_load(tmp_path, lambda x: x * 2, args=(5,))
109+
110+
assert payload["func"](*payload["args"]) == 10
111+
112+
def test_roundtrip_closure(self, tmp_path):
113+
multiplier = 7
114+
115+
def make_closure(x):
116+
return x * multiplier
117+
118+
payload = self._save_and_load(tmp_path, make_closure, args=(6,))
119+
120+
assert payload["func"](*payload["args"]) == 42
121+
122+
def test_roundtrip_numpy_args(self, tmp_path):
123+
def dot(a, b):
124+
return np.dot(a, b)
125+
126+
arr_a = np.array([1.0, 2.0, 3.0])
127+
arr_b = np.array([4.0, 5.0, 6.0])
128+
129+
payload = self._save_and_load(tmp_path, dot, args=(arr_a, arr_b))
130+
131+
result = payload["func"](*payload["args"])
132+
assert result == pytest.approx(32.0)
133+
134+
def test_roundtrip_complex_args(self, tmp_path):
135+
def identity(x):
136+
return x
137+
138+
complex_arg = {
139+
"key": [1, 2, 3],
140+
"nested": {"a": True, "b": None},
141+
"tuple": (1, "two", 3.0),
142+
}
143+
144+
payload = self._save_and_load(tmp_path, identity, args=(complex_arg,))
145+
146+
assert payload["func"](*payload["args"]) == complex_arg
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Integration tests for the full serialization roundtrip."""
2+
3+
import sys
4+
import zipfile
5+
6+
from keras_remote.utils.packager import zip_working_dir
7+
8+
9+
class TestZipAndExtract:
10+
def _zip_and_extract(self, src, tmp_path):
11+
"""Zip a directory and extract it, returning the extraction path."""
12+
zip_path = str(tmp_path / "context.zip")
13+
zip_working_dir(str(src), zip_path)
14+
extract_dir = tmp_path / "extracted"
15+
extract_dir.mkdir()
16+
with zipfile.ZipFile(zip_path) as zf:
17+
zf.extractall(str(extract_dir))
18+
return extract_dir
19+
20+
def test_zip_extract_preserves_files(self, tmp_path):
21+
"""Zip → extract roundtrip preserves file content."""
22+
src = tmp_path / "project"
23+
src.mkdir()
24+
(src / "main.py").write_text("x = 1")
25+
(src / "config.json").write_text('{"key": "val"}')
26+
27+
extract_dir = self._zip_and_extract(src, tmp_path)
28+
29+
assert (extract_dir / "main.py").read_text() == "x = 1"
30+
assert (extract_dir / "config.json").read_text() == '{"key": "val"}'
31+
32+
def test_zip_extract_enables_imports(self, tmp_path):
33+
"""Extracted workspace can be added to sys.path for imports."""
34+
src = tmp_path / "project"
35+
src.mkdir()
36+
(src / "helper.py").write_text("def greet(name):\n return f'Hi {name}'")
37+
38+
extract_dir = self._zip_and_extract(src, tmp_path)
39+
40+
sys.path.insert(0, str(extract_dir))
41+
try:
42+
import helper
43+
44+
assert helper.greet("World") == "Hi World"
45+
finally:
46+
sys.path.remove(str(extract_dir))
47+
# Clean up imported module
48+
sys.modules.pop("helper", None)

0 commit comments

Comments
 (0)