|
| 1 | +"""Tests for keras_remote.utils.packager — zip and payload serialization.""" |
| 2 | + |
| 3 | +import os |
| 4 | +import zipfile |
| 5 | + |
| 6 | +import cloudpickle |
| 7 | +import numpy as np |
| 8 | +import pytest |
| 9 | + |
| 10 | +from keras_remote.utils.packager import save_payload, zip_working_dir |
| 11 | + |
| 12 | + |
| 13 | +class TestZipWorkingDir: |
| 14 | + def _zip_and_list(self, src, tmp_path): |
| 15 | + """Zip src directory and return the set of archive member names.""" |
| 16 | + out = tmp_path / "context.zip" |
| 17 | + zip_working_dir(str(src), str(out)) |
| 18 | + with zipfile.ZipFile(str(out)) as zf: |
| 19 | + return set(zf.namelist()) |
| 20 | + |
| 21 | + def test_contains_all_files(self, tmp_path): |
| 22 | + src = tmp_path / "src" |
| 23 | + src.mkdir() |
| 24 | + (src / "a.py").write_text("a") |
| 25 | + (src / "b.txt").write_text("b") |
| 26 | + |
| 27 | + assert self._zip_and_list(src, tmp_path) == {"a.py", "b.txt"} |
| 28 | + |
| 29 | + def test_excludes_git_directory(self, tmp_path): |
| 30 | + src = tmp_path / "src" |
| 31 | + src.mkdir() |
| 32 | + git_dir = src / ".git" |
| 33 | + git_dir.mkdir() |
| 34 | + (git_dir / "config").write_text("git config") |
| 35 | + (src / "main.py").write_text("code") |
| 36 | + |
| 37 | + names = self._zip_and_list(src, tmp_path) |
| 38 | + assert all(".git" not in n for n in names) |
| 39 | + assert "main.py" in names |
| 40 | + |
| 41 | + def test_excludes_pycache_directory(self, tmp_path): |
| 42 | + src = tmp_path / "src" |
| 43 | + src.mkdir() |
| 44 | + cache_dir = src / "__pycache__" |
| 45 | + cache_dir.mkdir() |
| 46 | + (cache_dir / "mod.cpython-312.pyc").write_bytes(b"\x00") |
| 47 | + (src / "mod.py").write_text("code") |
| 48 | + |
| 49 | + names = self._zip_and_list(src, tmp_path) |
| 50 | + assert all("__pycache__" not in n for n in names) |
| 51 | + assert "mod.py" in names |
| 52 | + |
| 53 | + def test_preserves_nested_structure(self, tmp_path): |
| 54 | + src = tmp_path / "src" |
| 55 | + sub = src / "pkg" / "sub" |
| 56 | + sub.mkdir(parents=True) |
| 57 | + (sub / "deep.py").write_text("deep") |
| 58 | + (src / "top.py").write_text("top") |
| 59 | + |
| 60 | + names = self._zip_and_list(src, tmp_path) |
| 61 | + assert "top.py" in names |
| 62 | + assert os.path.join("pkg", "sub", "deep.py") in names |
| 63 | + |
| 64 | + def test_empty_directory(self, tmp_path): |
| 65 | + src = tmp_path / "empty" |
| 66 | + src.mkdir() |
| 67 | + |
| 68 | + assert self._zip_and_list(src, tmp_path) == set() |
| 69 | + |
| 70 | + |
| 71 | +class TestSavePayload: |
| 72 | + def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None): |
| 73 | + """Save a payload and load it back, returning the deserialized dict.""" |
| 74 | + if kwargs is None: |
| 75 | + kwargs = {} |
| 76 | + if env_vars is None: |
| 77 | + env_vars = {} |
| 78 | + out = tmp_path / "payload.pkl" |
| 79 | + save_payload(func, args, kwargs, env_vars, str(out)) |
| 80 | + with open(str(out), "rb") as f: |
| 81 | + return cloudpickle.load(f) |
| 82 | + |
| 83 | + def test_roundtrip_simple_function(self, tmp_path): |
| 84 | + def add(a, b): |
| 85 | + return a + b |
| 86 | + |
| 87 | + payload = self._save_and_load( |
| 88 | + tmp_path, add, args=(2, 3), env_vars={"KEY": "val"} |
| 89 | + ) |
| 90 | + |
| 91 | + assert payload["func"](2, 3) == 5 |
| 92 | + assert payload["args"] == (2, 3) |
| 93 | + assert payload["kwargs"] == {} |
| 94 | + assert payload["env_vars"] == {"KEY": "val"} |
| 95 | + |
| 96 | + def test_roundtrip_with_kwargs(self, tmp_path): |
| 97 | + def greet(name, greeting="Hello"): |
| 98 | + return f"{greeting}, {name}" |
| 99 | + |
| 100 | + payload = self._save_and_load( |
| 101 | + tmp_path, greet, args=("World",), kwargs={"greeting": "Hi"} |
| 102 | + ) |
| 103 | + |
| 104 | + result = payload["func"](*payload["args"], **payload["kwargs"]) |
| 105 | + assert result == "Hi, World" |
| 106 | + |
| 107 | + def test_roundtrip_lambda(self, tmp_path): |
| 108 | + payload = self._save_and_load(tmp_path, lambda x: x * 2, args=(5,)) |
| 109 | + |
| 110 | + assert payload["func"](*payload["args"]) == 10 |
| 111 | + |
| 112 | + def test_roundtrip_closure(self, tmp_path): |
| 113 | + multiplier = 7 |
| 114 | + |
| 115 | + def make_closure(x): |
| 116 | + return x * multiplier |
| 117 | + |
| 118 | + payload = self._save_and_load(tmp_path, make_closure, args=(6,)) |
| 119 | + |
| 120 | + assert payload["func"](*payload["args"]) == 42 |
| 121 | + |
| 122 | + def test_roundtrip_numpy_args(self, tmp_path): |
| 123 | + def dot(a, b): |
| 124 | + return np.dot(a, b) |
| 125 | + |
| 126 | + arr_a = np.array([1.0, 2.0, 3.0]) |
| 127 | + arr_b = np.array([4.0, 5.0, 6.0]) |
| 128 | + |
| 129 | + payload = self._save_and_load(tmp_path, dot, args=(arr_a, arr_b)) |
| 130 | + |
| 131 | + result = payload["func"](*payload["args"]) |
| 132 | + assert result == pytest.approx(32.0) |
| 133 | + |
| 134 | + def test_roundtrip_complex_args(self, tmp_path): |
| 135 | + def identity(x): |
| 136 | + return x |
| 137 | + |
| 138 | + complex_arg = { |
| 139 | + "key": [1, 2, 3], |
| 140 | + "nested": {"a": True, "b": None}, |
| 141 | + "tuple": (1, "two", 3.0), |
| 142 | + } |
| 143 | + |
| 144 | + payload = self._save_and_load(tmp_path, identity, args=(complex_arg,)) |
| 145 | + |
| 146 | + assert payload["func"](*payload["args"]) == complex_arg |
0 commit comments