1717from zephyr .execution import ZephyrContext
1818from zephyr .writers import write_parquet_file
1919
20- from .conftest import CallCounter
21-
2220
2321@pytest .fixture
2422def sample_data ():
@@ -192,24 +190,14 @@ def test_chaining_operations(zephyr_ctx):
192190
193191def test_lazy_evaluation ():
194192 """Test that operations are lazy until backend executes."""
195- call_count = 0
196-
197- def counting_fn (x ):
198- nonlocal call_count
199- call_count += 1
200- return x * 2
193+ ds = Dataset .from_list ([1 , 2 , 3 ]).map (lambda x : x * 2 )
201194
202- # Create dataset with map - should not execute yet
203- ds = Dataset .from_list ([1 , 2 , 3 ]).map (counting_fn )
204- assert call_count == 0
205-
206- # Now execute - should call function
195+ # Now execute - should call function and produce results
207196 client = LocalClient ()
208197 ctx = ZephyrContext (client = client , max_workers = 1 , resources = ResourceConfig (cpu = 1 , ram = "512m" ), name = "test-dataset" )
209198 try :
210199 result = list (ctx .execute (ds ))
211- assert result == [2 , 4 , 6 ]
212- assert call_count == 3
200+ assert sorted (result ) == [2 , 4 , 6 ]
213201 finally :
214202 ctx .shutdown ()
215203
@@ -992,21 +980,21 @@ def test_skip_existing_clean_run(tmp_path, sample_input_files):
992980 output_dir = tmp_path / "output"
993981 output_dir .mkdir ()
994982
995- counter = CallCounter ()
996983 ds = (
997984 Dataset .from_files (f"{ sample_input_files } /*.jsonl" )
998- .flat_map (lambda x : counter . counting_flat_map ( x ) )
999- .map (lambda x : counter . counting_map ( x ) )
985+ .flat_map (load_file )
986+ .map (lambda x : { ** x , "processed" : True } )
1000987 .write_jsonl (str (output_dir / "output-{shard:05d}.jsonl" ), skip_existing = True )
1001988 )
1002989
1003990 try :
1004991 result = list (ctx .execute (ds ))
1005992 assert len (result ) == 3
1006993 assert all (Path (p ).exists () for p in result )
1007- assert counter .flat_map_count == 3 # All files loaded
1008- assert counter .map_count == 3 # All items mapped
1009- assert sorted (counter .processed_ids ) == [0 , 1 , 2 ] # All shards ran
994+ # All shards ran -- each output has "processed" flag
995+ for p in result :
996+ records = [json .loads (line ) for line in Path (p ).read_text ().strip ().splitlines ()]
997+ assert all (r .get ("processed" ) for r in records )
1010998 finally :
1011999 ctx .shutdown ()
10121000
@@ -1018,25 +1006,28 @@ def test_skip_existing_one_file_exists(tmp_path, sample_input_files):
10181006 output_dir = tmp_path / "output"
10191007 output_dir .mkdir ()
10201008
1021- # Manually create one output file (shard 1)
1009+ # Manually create one output file (shard 1) -- no "processed" flag
10221010 with open (output_dir / "output-00001.jsonl" , "w" ) as f :
1023- f .write ('{"id": 1, "processed ": true}\n ' )
1011+ f .write ('{"id": 1, "skipped ": true}\n ' )
10241012
1025- counter = CallCounter ()
10261013 ds = (
10271014 Dataset .from_files (f"{ sample_input_files } /*.jsonl" )
1028- .flat_map (lambda x : counter . counting_flat_map ( x ) )
1029- .map (lambda x : counter . counting_map ( x ) )
1015+ .flat_map (load_file )
1016+ .map (lambda x : { ** x , "processed" : True } )
10301017 .write_jsonl (str (output_dir / "output-{shard:05d}.jsonl" ), skip_existing = True )
10311018 )
10321019
10331020 try :
10341021 result = list (ctx .execute (ds ))
10351022 assert len (result ) == 3
10361023 assert all (Path (p ).exists () for p in result )
1037- assert counter .flat_map_count == 2 # Only 2 files loaded (shard 1 skipped)
1038- assert counter .map_count == 2 # Only 2 items mapped
1039- assert sorted (counter .processed_ids ) == [0 , 2 ] # Only shards 0 and 2 ran
1024+ # Shard 1 was skipped -- its file still has the pre-existing content
1025+ shard1 = [json .loads (line ) for line in (output_dir / "output-00001.jsonl" ).read_text ().strip ().splitlines ()]
1026+ assert shard1 == [{"id" : 1 , "skipped" : True }]
1027+ # Shards 0 and 2 ran -- they have "processed" flag
1028+ for shard_file in ["output-00000.jsonl" , "output-00002.jsonl" ]:
1029+ records = [json .loads (line ) for line in (output_dir / shard_file ).read_text ().strip ().splitlines ()]
1030+ assert all (r .get ("processed" ) for r in records )
10401031 finally :
10411032 ctx .shutdown ()
10421033
@@ -1048,36 +1039,38 @@ def test_skip_existing_all_files_exist(tmp_path, sample_input_files):
10481039 output_dir = tmp_path / "output"
10491040 output_dir .mkdir ()
10501041
1051- counter = CallCounter ()
10521042 ds = (
10531043 Dataset .from_files (f"{ sample_input_files } /*.jsonl" )
1054- .flat_map (lambda x : counter . counting_flat_map ( x ) )
1055- .map (lambda x : counter . counting_map ( x ) )
1044+ .flat_map (load_file )
1045+ .map (lambda x : { ** x , "processed" : True } )
10561046 .write_jsonl (str (output_dir / "output-{shard:05d}.jsonl" ), skip_existing = True )
10571047 )
10581048
10591049 try :
10601050 # First run: create all output files
10611051 result = list (ctx .execute (ds ))
10621052 assert len (result ) == 3
1063- assert counter .flat_map_count == 3
1064- assert counter .map_count == 3
1065- assert sorted (counter .processed_ids ) == [0 , 1 , 2 ] # All shards ran
1053+ assert all (Path (p ).exists () for p in result )
1054+ for p in result :
1055+ records = [json .loads (line ) for line in Path (p ).read_text ().strip ().splitlines ()]
1056+ assert all (r .get ("processed" ) for r in records )
10661057
1067- # Second run: all files exist, nothing should process
1068- counter .reset ()
1069- ds = (
1058+ # Record modification times
1059+ mtimes = {p : Path (p ).stat ().st_mtime for p in result }
1060+
1061+ # Second run: all files exist, nothing should be rewritten
1062+ ds2 = (
10701063 Dataset .from_files (f"{ sample_input_files } /*.jsonl" )
1071- .flat_map (counter . counting_flat_map )
1072- .map (counter . counting_map )
1064+ .flat_map (load_file )
1065+ .map (lambda x : { ** x , "processed" : True } )
10731066 .write_jsonl (str (output_dir / "output-{shard:05d}.jsonl" ), skip_existing = True )
10741067 )
10751068
1076- result = list (ctx .execute (ds ))
1077- assert len (result ) == 3
1078- assert counter . flat_map_count == 0 # Nothing loaded
1079- assert counter . map_count == 0 # Nothing mapped
1080- assert counter . processed_ids == [] # No shards ran
1069+ result2 = list (ctx .execute (ds2 ))
1070+ assert len (result2 ) == 3
1071+ # Files should be untouched -- same mtime
1072+ for p in result2 :
1073+ assert Path ( p ). stat (). st_mtime == mtimes [ p ]
10811074 finally :
10821075 ctx .shutdown ()
10831076
0 commit comments