diff --git a/keras_remote/utils/test_packager.py b/keras_remote/utils/test_packager.py new file mode 100644 index 0000000..519da62 --- /dev/null +++ b/keras_remote/utils/test_packager.py @@ -0,0 +1,174 @@ +"""Tests for keras_remote.utils.packager — zip and payload serialization.""" + +import os +import pathlib +import tempfile +import zipfile + +import cloudpickle +import numpy as np +from absl.testing import absltest + +from keras_remote.utils.packager import save_payload, zip_working_dir + + +def _make_temp_path(test_case): + """Create a temp directory that is cleaned up after the test.""" + td = tempfile.TemporaryDirectory() + test_case.addCleanup(td.cleanup) + return pathlib.Path(td.name) + + +class TestZipWorkingDir(absltest.TestCase): + def _zip_and_list(self, src, tmp_path): + """Zip src directory and return the set of archive member names.""" + out = tmp_path / "context.zip" + zip_working_dir(str(src), str(out)) + with zipfile.ZipFile(str(out)) as zf: + return set(zf.namelist()) + + def test_contains_all_files(self): + tmp_path = _make_temp_path(self) + src = tmp_path / "src" + src.mkdir() + (src / "a.py").write_text("a") + (src / "b.txt").write_text("b") + + self.assertEqual(self._zip_and_list(src, tmp_path), {"a.py", "b.txt"}) + + def test_excludes_git_directory(self): + tmp_path = _make_temp_path(self) + src = tmp_path / "src" + src.mkdir() + git_dir = src / ".git" + git_dir.mkdir() + (git_dir / "config").write_text("git config") + (src / "main.py").write_text("code") + + names = self._zip_and_list(src, tmp_path) + self.assertTrue(all(".git" not in n for n in names)) + self.assertIn("main.py", names) + + def test_excludes_pycache_directory(self): + tmp_path = _make_temp_path(self) + src = tmp_path / "src" + src.mkdir() + cache_dir = src / "__pycache__" + cache_dir.mkdir() + (cache_dir / "mod.cpython-312.pyc").write_bytes(b"\x00") + (src / "mod.py").write_text("code") + + names = self._zip_and_list(src, tmp_path) + self.assertTrue(all("__pycache__" not in n for n in names)) + self.assertIn("mod.py", names) + + def test_preserves_nested_structure(self): + tmp_path = _make_temp_path(self) + src = tmp_path / "src" + sub = src / "pkg" / "sub" + sub.mkdir(parents=True) + (sub / "deep.py").write_text("deep") + (src / "top.py").write_text("top") + + names = self._zip_and_list(src, tmp_path) + self.assertIn("top.py", names) + self.assertIn(os.path.join("pkg", "sub", "deep.py"), names) + + def test_empty_directory(self): + tmp_path = _make_temp_path(self) + src = tmp_path / "empty" + src.mkdir() + + self.assertEqual(self._zip_and_list(src, tmp_path), set()) + + +class TestSavePayload(absltest.TestCase): + def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None): + """Save a payload and load it back, returning the deserialized dict.""" + if kwargs is None: + kwargs = {} + if env_vars is None: + env_vars = {} + out = tmp_path / "payload.pkl" + save_payload(func, args, kwargs, env_vars, str(out)) + with open(str(out), "rb") as f: + return cloudpickle.load(f) + + def test_roundtrip_simple_function(self): + tmp_path = _make_temp_path(self) + + def add(a, b): + return a + b + + payload = self._save_and_load( + tmp_path, add, args=(2, 3), env_vars={"KEY": "val"} + ) + + self.assertEqual(payload["func"](2, 3), 5) + self.assertEqual(payload["args"], (2, 3)) + self.assertEqual(payload["kwargs"], {}) + self.assertEqual(payload["env_vars"], {"KEY": "val"}) + + def test_roundtrip_with_kwargs(self): + tmp_path = _make_temp_path(self) + + def greet(name, greeting="Hello"): + return f"{greeting}, {name}" + + payload = self._save_and_load( + tmp_path, greet, args=("World",), kwargs={"greeting": "Hi"} + ) + + result = payload["func"](*payload["args"], **payload["kwargs"]) + self.assertEqual(result, "Hi, World") + + def test_roundtrip_lambda(self): + tmp_path = _make_temp_path(self) + payload = self._save_and_load(tmp_path, lambda x: x * 2, args=(5,)) + + self.assertEqual(payload["func"](*payload["args"]), 10) + + def test_roundtrip_closure(self): + tmp_path = _make_temp_path(self) + multiplier = 7 + + def make_closure(x): + return x * multiplier + + payload = self._save_and_load(tmp_path, make_closure, args=(6,)) + + self.assertEqual(payload["func"](*payload["args"]), 42) + + def test_roundtrip_numpy_args(self): + tmp_path = _make_temp_path(self) + + def dot(a, b): + return np.dot(a, b) + + arr_a = np.array([1.0, 2.0, 3.0]) + arr_b = np.array([4.0, 5.0, 6.0]) + + payload = self._save_and_load(tmp_path, dot, args=(arr_a, arr_b)) + + result = payload["func"](*payload["args"]) + self.assertAlmostEqual(result, 32.0) + + def test_roundtrip_complex_args(self): + tmp_path = _make_temp_path(self) + + def identity(x): + return x + + complex_arg = { + "key": [1, 2, 3], + "nested": {"a": True, "b": None}, + "tuple": (1, "two", 3.0), + } + + payload = self._save_and_load(tmp_path, identity, args=(complex_arg,)) + + self.assertEqual(payload["func"](*payload["args"]), complex_arg) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_packager_roundtrip.py b/tests/integration/test_packager_roundtrip.py new file mode 100644 index 0000000..b9dba91 --- /dev/null +++ b/tests/integration/test_packager_roundtrip.py @@ -0,0 +1,66 @@ +"""Integration tests for the full serialization roundtrip.""" + +import pathlib +import sys +import tempfile +import zipfile + +from absl.testing import absltest + +from keras_remote.utils.packager import zip_working_dir + + +def _make_temp_path(test_case): + """Create a temp directory that is cleaned up after the test.""" + td = tempfile.TemporaryDirectory() + test_case.addCleanup(td.cleanup) + return pathlib.Path(td.name) + + +class TestZipAndExtract(absltest.TestCase): + def _zip_and_extract(self, src, tmp_path): + """Zip a directory and extract it, returning the extraction path.""" + zip_path = str(tmp_path / "context.zip") + zip_working_dir(str(src), zip_path) + extract_dir = tmp_path / "extracted" + extract_dir.mkdir() + with zipfile.ZipFile(zip_path) as zf: + zf.extractall(str(extract_dir)) + return extract_dir + + def test_zip_extract_preserves_files(self): + """Zip -> extract roundtrip preserves file content.""" + tmp_path = _make_temp_path(self) + src = tmp_path / "project" + src.mkdir() + (src / "main.py").write_text("x = 1") + (src / "config.json").write_text('{"key": "val"}') + + extract_dir = self._zip_and_extract(src, tmp_path) + + self.assertEqual((extract_dir / "main.py").read_text(), "x = 1") + self.assertEqual( + (extract_dir / "config.json").read_text(), '{"key": "val"}' + ) + + def test_zip_extract_enables_imports(self): + """Extracted workspace can be added to sys.path for imports.""" + tmp_path = _make_temp_path(self) + src = tmp_path / "project" + src.mkdir() + (src / "helper.py").write_text("def greet(name):\n return f'Hi {name}'") + + extract_dir = self._zip_and_extract(src, tmp_path) + + sys.path.insert(0, str(extract_dir)) + try: + import helper + + self.assertEqual(helper.greet("World"), "Hi World") + finally: + sys.path.remove(str(extract_dir)) + sys.modules.pop("helper", None) + + +if __name__ == "__main__": + absltest.main()