11"""Tests for keras_remote.utils.packager — zip and payload serialization."""
22
33import os
4+ import pathlib
5+ import tempfile
46import zipfile
57
68import cloudpickle
79import numpy as np
8- import pytest
10+ from absl . testing import absltest
911
1012from keras_remote .utils .packager import save_payload , zip_working_dir
1113
1214
13- class TestZipWorkingDir :
15+ def _make_temp_path (test_case ):
16+ """Create a temp directory that is cleaned up after the test."""
17+ td = tempfile .TemporaryDirectory ()
18+ test_case .addCleanup (td .cleanup )
19+ return pathlib .Path (td .name )
20+
21+
22+ class TestZipWorkingDir (absltest .TestCase ):
1423 def _zip_and_list (self , src , tmp_path ):
1524 """Zip src directory and return the set of archive member names."""
1625 out = tmp_path / "context.zip"
1726 zip_working_dir (str (src ), str (out ))
1827 with zipfile .ZipFile (str (out )) as zf :
1928 return set (zf .namelist ())
2029
21- def test_contains_all_files (self , tmp_path ):
30+ def test_contains_all_files (self ):
31+ tmp_path = _make_temp_path (self )
2232 src = tmp_path / "src"
2333 src .mkdir ()
2434 (src / "a.py" ).write_text ("a" )
2535 (src / "b.txt" ).write_text ("b" )
2636
27- assert self ._zip_and_list (src , tmp_path ) == {"a.py" , "b.txt" }
37+ self .assertEqual ( self . _zip_and_list (src , tmp_path ), {"a.py" , "b.txt" })
2838
29- def test_excludes_git_directory (self , tmp_path ):
39+ def test_excludes_git_directory (self ):
40+ tmp_path = _make_temp_path (self )
3041 src = tmp_path / "src"
3142 src .mkdir ()
3243 git_dir = src / ".git"
@@ -35,10 +46,11 @@ def test_excludes_git_directory(self, tmp_path):
3546 (src / "main.py" ).write_text ("code" )
3647
3748 names = self ._zip_and_list (src , tmp_path )
38- assert all (".git" not in n for n in names )
39- assert "main.py" in names
49+ self . assertTrue ( all (".git" not in n for n in names ) )
50+ self . assertIn ( "main.py" , names )
4051
41- def test_excludes_pycache_directory (self , tmp_path ):
52+ def test_excludes_pycache_directory (self ):
53+ tmp_path = _make_temp_path (self )
4254 src = tmp_path / "src"
4355 src .mkdir ()
4456 cache_dir = src / "__pycache__"
@@ -47,28 +59,30 @@ def test_excludes_pycache_directory(self, tmp_path):
4759 (src / "mod.py" ).write_text ("code" )
4860
4961 names = self ._zip_and_list (src , tmp_path )
50- assert all ("__pycache__" not in n for n in names )
51- assert "mod.py" in names
62+ self . assertTrue ( all ("__pycache__" not in n for n in names ) )
63+ self . assertIn ( "mod.py" , names )
5264
53- def test_preserves_nested_structure (self , tmp_path ):
65+ def test_preserves_nested_structure (self ):
66+ tmp_path = _make_temp_path (self )
5467 src = tmp_path / "src"
5568 sub = src / "pkg" / "sub"
5669 sub .mkdir (parents = True )
5770 (sub / "deep.py" ).write_text ("deep" )
5871 (src / "top.py" ).write_text ("top" )
5972
6073 names = self ._zip_and_list (src , tmp_path )
61- assert "top.py" in names
62- assert os .path .join ("pkg" , "sub" , "deep.py" ) in names
74+ self . assertIn ( "top.py" , names )
75+ self . assertIn ( os .path .join ("pkg" , "sub" , "deep.py" ), names )
6376
64- def test_empty_directory (self , tmp_path ):
77+ def test_empty_directory (self ):
78+ tmp_path = _make_temp_path (self )
6579 src = tmp_path / "empty"
6680 src .mkdir ()
6781
68- assert self ._zip_and_list (src , tmp_path ) == set ()
82+ self .assertEqual ( self . _zip_and_list (src , tmp_path ), set () )
6983
7084
71- class TestSavePayload :
85+ class TestSavePayload ( absltest . TestCase ) :
7286 def _save_and_load (self , tmp_path , func , args = (), kwargs = None , env_vars = None ):
7387 """Save a payload and load it back, returning the deserialized dict."""
7488 if kwargs is None :
@@ -80,20 +94,24 @@ def _save_and_load(self, tmp_path, func, args=(), kwargs=None, env_vars=None):
8094 with open (str (out ), "rb" ) as f :
8195 return cloudpickle .load (f )
8296
83- def test_roundtrip_simple_function (self , tmp_path ):
97+ def test_roundtrip_simple_function (self ):
98+ tmp_path = _make_temp_path (self )
99+
84100 def add (a , b ):
85101 return a + b
86102
87103 payload = self ._save_and_load (
88104 tmp_path , add , args = (2 , 3 ), env_vars = {"KEY" : "val" }
89105 )
90106
91- assert payload ["func" ](2 , 3 ) == 5
92- assert payload ["args" ] == (2 , 3 )
93- assert payload ["kwargs" ] == {}
94- assert payload ["env_vars" ] == {"KEY" : "val" }
107+ self .assertEqual (payload ["func" ](2 , 3 ), 5 )
108+ self .assertEqual (payload ["args" ], (2 , 3 ))
109+ self .assertEqual (payload ["kwargs" ], {})
110+ self .assertEqual (payload ["env_vars" ], {"KEY" : "val" })
111+
112+ def test_roundtrip_with_kwargs (self ):
113+ tmp_path = _make_temp_path (self )
95114
96- def test_roundtrip_with_kwargs (self , tmp_path ):
97115 def greet (name , greeting = "Hello" ):
98116 return f"{ greeting } , { name } "
99117
@@ -102,24 +120,28 @@ def greet(name, greeting="Hello"):
102120 )
103121
104122 result = payload ["func" ](* payload ["args" ], ** payload ["kwargs" ])
105- assert result == "Hi, World"
123+ self . assertEqual ( result , "Hi, World" )
106124
107- def test_roundtrip_lambda (self , tmp_path ):
125+ def test_roundtrip_lambda (self ):
126+ tmp_path = _make_temp_path (self )
108127 payload = self ._save_and_load (tmp_path , lambda x : x * 2 , args = (5 ,))
109128
110- assert payload ["func" ](* payload ["args" ]) == 10
129+ self . assertEqual ( payload ["func" ](* payload ["args" ]), 10 )
111130
112- def test_roundtrip_closure (self , tmp_path ):
131+ def test_roundtrip_closure (self ):
132+ tmp_path = _make_temp_path (self )
113133 multiplier = 7
114134
115135 def make_closure (x ):
116136 return x * multiplier
117137
118138 payload = self ._save_and_load (tmp_path , make_closure , args = (6 ,))
119139
120- assert payload ["func" ](* payload ["args" ]) == 42
140+ self .assertEqual (payload ["func" ](* payload ["args" ]), 42 )
141+
142+ def test_roundtrip_numpy_args (self ):
143+ tmp_path = _make_temp_path (self )
121144
122- def test_roundtrip_numpy_args (self , tmp_path ):
123145 def dot (a , b ):
124146 return np .dot (a , b )
125147
@@ -129,9 +151,11 @@ def dot(a, b):
129151 payload = self ._save_and_load (tmp_path , dot , args = (arr_a , arr_b ))
130152
131153 result = payload ["func" ](* payload ["args" ])
132- assert result == pytest .approx (32.0 )
154+ self .assertAlmostEqual (result , 32.0 )
155+
156+ def test_roundtrip_complex_args (self ):
157+ tmp_path = _make_temp_path (self )
133158
134- def test_roundtrip_complex_args (self , tmp_path ):
135159 def identity (x ):
136160 return x
137161
@@ -143,4 +167,8 @@ def identity(x):
143167
144168 payload = self ._save_and_load (tmp_path , identity , args = (complex_arg ,))
145169
146- assert payload ["func" ](* payload ["args" ]) == complex_arg
170+ self .assertEqual (payload ["func" ](* payload ["args" ]), complex_arg )
171+
172+
173+ if __name__ == "__main__" :
174+ absltest .main ()
0 commit comments