Skip to content

Commit d63b363

Browse files
committed
refactor: extract atomic_json_write helper, add 24 checkpoint tests
Extract the duplicated temp-file + fsync + os.replace pattern from batch_runner.py (1 instance) and process_registry.py (2 instances) into a shared utils.atomic_json_write() function. Add 12 tests for atomic_json_write covering: valid JSON, parent dir creation, overwrite, crash safety (original preserved on error), no temp file leaks, string paths, unicode, custom indent, concurrent writes. Add 12 tests for batch_runner checkpoint behavior covering: _save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock, parent dirs, no temp leaks), _load_checkpoint (missing file, existing data, corrupt JSON), and resume logic (preserves prior progress, different run_name starts fresh).
1 parent c05c606 commit d63b363

File tree

5 files changed

+340
-64
lines changed

5 files changed

+340
-64
lines changed

batch_runner.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
from datetime import datetime
3030
from multiprocessing import Pool, Lock
3131
import traceback
32-
import tempfile
33-
3432
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
3533
from rich.console import Console
3634
import fire
def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None) -> None:
    """Persist checkpoint state to ``self.checkpoint_file`` atomically.

    Stamps the payload with a ``last_updated`` ISO-8601 timestamp, then
    delegates the temp-file + fsync + ``os.replace`` dance to
    ``utils.atomic_json_write`` so a crash mid-write can never leave a
    corrupt checkpoint behind.

    Args:
        checkpoint_data: JSON-serializable checkpoint payload. Mutated in
            place: a ``last_updated`` key is added/overwritten.
        lock: Optional multiprocessing lock serializing concurrent writers;
            when None the write proceeds unguarded.
    """
    from contextlib import nullcontext

    from utils import atomic_json_write

    checkpoint_data["last_updated"] = datetime.now().isoformat()

    # nullcontext lets a single code path serve both the locked and the
    # unlocked case, removing the duplicated write call from each branch.
    with (lock if lock is not None else nullcontext()):
        atomic_json_write(self.checkpoint_file, checkpoint_data)

733711
def _scan_completed_prompts_by_content(self) -> set:
734712
"""

tests/test_atomic_json_write.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""
2+
3+
import json
4+
import os
5+
from pathlib import Path
6+
from unittest.mock import patch
7+
8+
import pytest
9+
10+
from utils import atomic_json_write
11+
12+
13+
class TestAtomicJsonWrite:
    """Behavioral tests for the crash-safe atomic JSON writer."""

    def test_writes_valid_json(self, tmp_path):
        target = tmp_path / "data.json"
        payload = {"key": "value", "nested": {"a": 1}}
        atomic_json_write(target, payload)

        assert json.loads(target.read_text(encoding="utf-8")) == payload

    def test_creates_parent_directories(self, tmp_path):
        target = tmp_path / "deep" / "nested" / "dir" / "data.json"
        atomic_json_write(target, {"ok": True})

        assert target.exists()
        assert json.loads(target.read_text())["ok"] is True

    def test_overwrites_existing_file(self, tmp_path):
        target = tmp_path / "data.json"
        target.write_text('{"old": true}')

        atomic_json_write(target, {"new": True})
        assert json.loads(target.read_text()) == {"new": True}

    def test_preserves_original_on_serialization_error(self, tmp_path):
        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original))

        # A non-serializable payload must raise without touching the file.
        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        assert json.loads(target.read_text()) == original

    def test_no_leftover_temp_files_on_success(self, tmp_path):
        target = tmp_path / "data.json"
        atomic_json_write(target, [1, 2, 3])

        leftovers = [p for p in tmp_path.iterdir() if ".tmp" in p.name]
        assert not leftovers
        assert target.exists()

    def test_no_leftover_temp_files_on_failure(self, tmp_path):
        target = tmp_path / "data.json"

        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        leftovers = [p for p in tmp_path.iterdir() if ".tmp" in p.name]
        assert not leftovers

    def test_accepts_string_path(self, tmp_path):
        target = str(tmp_path / "string_path.json")
        atomic_json_write(target, {"string": True})

        assert json.loads(Path(target).read_text()) == {"string": True}

    def test_writes_list_data(self, tmp_path):
        target = tmp_path / "list.json"
        payload = [1, "two", {"three": 3}]
        atomic_json_write(target, payload)

        assert json.loads(target.read_text()) == payload

    def test_empty_list(self, tmp_path):
        target = tmp_path / "empty.json"
        atomic_json_write(target, [])

        assert json.loads(target.read_text()) == []

    def test_custom_indent(self, tmp_path):
        target = tmp_path / "custom.json"
        atomic_json_write(target, {"a": 1}, indent=4)

        assert '    "a"' in target.read_text()  # 4-space indent

    def test_unicode_content(self, tmp_path):
        target = tmp_path / "unicode.json"
        payload = {"emoji": "🎉", "japanese": "日本語"}
        atomic_json_write(target, payload)

        loaded = json.loads(target.read_text(encoding="utf-8"))
        assert loaded["emoji"] == "🎉"
        assert loaded["japanese"] == "日本語"

    def test_concurrent_writes_dont_corrupt(self, tmp_path):
        """Ten racing writers must each leave behind valid JSON."""
        import threading

        target = tmp_path / "concurrent.json"
        errors = []

        def writer(n):
            try:
                atomic_json_write(target, {"writer": n, "data": list(range(100))})
            except Exception as exc:
                errors.append(exc)

        workers = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
        for t in workers:
            t.start()
        for t in workers:
            t.join()

        assert not errors
        # File must hold complete, valid JSON from one of the writers.
        loaded = json.loads(target.read_text())
        assert "writer" in loaded
        assert len(loaded["data"]) == 100
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
2+
3+
import json
4+
import os
5+
from pathlib import Path
6+
from multiprocessing import Lock
7+
from unittest.mock import patch, MagicMock
8+
9+
import pytest
10+
11+
# batch_runner uses relative imports, ensure project root is on path
12+
import sys
13+
sys.path.insert(0, str(Path(__file__).parent.parent))
14+
15+
from batch_runner import BatchRunner
16+
17+
18+
@pytest.fixture
def runner(tmp_path):
    """BatchRunner whose prompt/output/checkpoint files all live under tmp_path.

    Built via __new__ to skip __init__ (and any I/O it performs); only the
    attributes the checkpoint code paths read are populated.
    """
    prompts = tmp_path / "prompts.jsonl"
    prompts.write_text("")

    instance = BatchRunner.__new__(BatchRunner)
    instance.run_name = "test_run"
    instance.checkpoint_file = tmp_path / "checkpoint.json"
    instance.output_file = tmp_path / "output.jsonl"
    instance.prompts_file = prompts
    return instance
32+
33+
class TestSaveCheckpoint:
    """_save_checkpoint must produce valid JSON via an atomic write."""

    def test_writes_valid_json(self, runner):
        payload = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(payload)

        saved = json.loads(runner.checkpoint_file.read_text())
        assert saved["run_name"] == "test"
        assert saved["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        saved = json.loads(runner.checkpoint_file.read_text())
        assert "last_updated" in saved
        assert saved["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})

        saved = json.loads(runner.checkpoint_file.read_text())
        assert saved["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        runner._save_checkpoint(
            {"run_name": "test", "completed_prompts": [42]}, lock=Lock()
        )

        saved = json.loads(runner.checkpoint_file.read_text())
        assert saved["completed_prompts"] == [42]

    def test_without_lock(self, runner):
        runner._save_checkpoint(
            {"run_name": "test", "completed_prompts": [99]}, lock=None
        )

        saved = json.loads(runner.checkpoint_file.read_text())
        assert saved["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        # Bare instance: _save_checkpoint only touches checkpoint_file.
        deep_runner = BatchRunner.__new__(BatchRunner)
        deep_runner.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"

        deep_runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        assert deep_runner.checkpoint_file.exists()

    def test_no_temp_files_left(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        leftovers = [
            p for p in runner.checkpoint_file.parent.iterdir() if ".tmp" in p.name
        ]
        assert not leftovers
91+
class TestLoadCheckpoint:
    """_load_checkpoint must return saved data, or safe defaults."""

    def test_returns_empty_when_no_file(self, runner):
        loaded = runner._load_checkpoint()
        assert loaded.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        saved = {
            "run_name": "test_run",
            "completed_prompts": [5, 10, 15],
            "batch_stats": {"0": {"processed": 3}},
        }
        runner.checkpoint_file.write_text(json.dumps(saved))

        loaded = runner._load_checkpoint()
        assert loaded["completed_prompts"] == [5, 10, 15]
        assert loaded["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        runner.checkpoint_file.write_text("{broken json!!")

        # Corruption must yield an empty/default dict, never an exception.
        loaded = runner._load_checkpoint()
        assert isinstance(loaded, dict)
115+
class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads prior checkpoint."""

    @staticmethod
    def _resume(runner):
        """Replicate run()'s resume logic: keep the loaded checkpoint only
        when its run_name matches the runner's, otherwise start fresh.

        Extracted because both tests below previously copy-pasted this
        block verbatim; one helper keeps them in lockstep with run().
        """
        checkpoint_data = runner._load_checkpoint()
        if checkpoint_data.get("run_name") != runner.run_name:
            checkpoint_data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }
        return checkpoint_data

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4.
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = self._resume(runner)
        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = self._resume(runner)
        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"

0 commit comments

Comments
 (0)