Skip to content

Commit 6ace678

Browse files
Merge pull request #27 from SaridakisStamatisChristos/codex/add-batch-controls-and-json-output
Add CLI bounds, JSON outputs, and sampling controls
2 parents 9b0f613 + 8bd1425 commit 6ace678

File tree

5 files changed

+183
-53
lines changed

5 files changed

+183
-53
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ sudoku-dlx gen-batch --out puzzles.txt --count 1000 --givens 30 --symmetry rot18
5656
sudoku-dlx rate-file --in puzzles.txt --csv ratings.csv
5757
python bench/bench_file.py --in puzzles.txt
5858

59+
## Batch controls & parallel
60+
Generate with bounds and multiple processes:
61+
```bash
62+
sudoku-dlx gen-batch --out puzzles.txt --count 1000 --givens 30 \
63+
--min-givens 28 --max-givens 40 --parallel 8
64+
```
65+
66+
JSON output for ratings, plus `--limit` / `--sample` controls for stats:
67+
```bash
68+
sudoku-dlx rate-file --in puzzles.txt --json > scores.ndjson
69+
sudoku-dlx stats-file --in puzzles.txt --limit 5000 --sample 1000 --json stats.json
70+
```
71+
5972
# Dedupe a file of puzzles (fast)
6073
sudoku-dlx dedupe --in puzzles.txt --out unique.txt
6174

src/sudoku_dlx/cli.py

Lines changed: 109 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
from __future__ import annotations
22

3-
import argparse
4-
import csv
5-
import json
6-
import pathlib
7-
import random
8-
import sys
9-
import time
3+
import argparse, sys, pathlib, csv, random, json, time, multiprocessing as mp
104
from typing import Optional
115

126
from .api import analyze, build_reveal_trace, from_string, is_valid, solve, to_string
@@ -16,6 +10,26 @@
1610
from statistics import mean
1711

1812

13+
def _count_givens(grid) -> int:
    """Return how many of the 9x9 cells in ``grid`` are filled.

    ``grid`` is indexed as ``grid[r][c]`` for ``r`` and ``c`` in ``0..8``;
    a cell that compares equal to ``0`` is treated as empty.
    """
    return sum(1 for r in range(9) for c in range(9) if grid[r][c] != 0)
15+
16+
17+
def _gen_batch_worker(args: tuple[int, int, bool, str, int, int]) -> str:
    """Generate one puzzle whose givens count lies inside the requested bounds.

    The parameters arrive packed in a single tuple so the function can be
    mapped over a ``multiprocessing`` pool:
    ``(seed_i, target_givens, minimal, symmetry, min_givens, max_givens)``.

    A private ``random.Random(seed_i)`` supplies the seed for every
    ``generate`` attempt, so the result depends only on ``args``.

    Returns:
        ``canonical_form`` of the first generated grid for which
        ``min_givens <= givens <= max_givens``.

    Raises:
        ValueError: if ``min_givens > max_givens``. No grid can satisfy an
            empty range, so the retry loop would otherwise never return.

    Note:
        There is no attempt cap: bounds the generator never hits for the
        given settings still keep this function looping.
    """
    seed_i, target_givens, minimal, symmetry, min_givens, max_givens = args
    if min_givens > max_givens:
        # Fail fast instead of spinning forever on an unsatisfiable range.
        raise ValueError(
            f"min_givens ({min_givens}) must not exceed max_givens ({max_givens})"
        )
    local_rng = random.Random(seed_i)
    while True:
        g = generate(
            seed=local_rng.randrange(2**31 - 1),
            target_givens=target_givens,
            minimal=minimal,
            symmetry=symmetry,
        )
        if min_givens <= _count_givens(g) <= max_givens:
            return canonical_form(g)
30+
31+
32+
1933
def _read_grid_arg(ns: argparse.Namespace) -> str:
2034
if ns.grid:
2135
return ns.grid
@@ -110,31 +124,49 @@ def cmd_canon(ns: argparse.Namespace) -> int:
110124

111125

112126
def cmd_gen_batch(ns: argparse.Namespace) -> int:
    """Generate many canonicalized, unique puzzles and write them to ``ns.out``.

    Reads from ``ns``: ``out``, ``count``, ``givens``, ``minimal``,
    ``symmetry``, ``seed``, ``min_givens``, ``max_givens`` and ``parallel``.
    Candidates come from ``_gen_batch_worker`` using consecutive seeds
    starting at ``ns.seed`` (or a random base when it is ``None``);
    duplicates are dropped until ``ns.count`` distinct canonical strings
    have been collected. One string is written per line and a summary line
    is printed to stderr.

    With ``ns.parallel > 1`` the work is spread over a process pool. Results
    are consumed in submission order, so a given seed yields the same file
    as the single-process path regardless of scheduling.

    Returns:
        0 on success; 2 when ``min_givens > max_givens`` (an empty range
        that would otherwise make generation loop forever).
    """
    if ns.min_givens > ns.max_givens:
        print(
            f"--min-givens ({ns.min_givens}) must not exceed "
            f"--max-givens ({ns.max_givens})",
            file=sys.stderr,
        )
        return 2

    outp = pathlib.Path(ns.out)
    seen: set[str] = set()
    uniq: list[str] = []
    base_seed = ns.seed if ns.seed is not None else random.randrange(2**31 - 1)

    def worker_args(offset):
        # One picklable argument tuple per candidate; the seed is base + offset.
        return (
            base_seed + offset,
            ns.givens,
            ns.minimal,
            ns.symmetry,
            ns.min_givens,
            ns.max_givens,
        )

    next_offset = 0
    if ns.parallel <= 1:
        while len(uniq) < ns.count:
            c = _gen_batch_worker(worker_args(next_offset))
            next_offset += 1
            if c in seen:
                continue
            seen.add(c)
            uniq.append(c)
    else:
        with mp.Pool(processes=ns.parallel) as pool:
            # Produce candidates in batches until we reach count.
            while len(uniq) < ns.count:
                batch_n = max(4 * ns.parallel, ns.count - len(uniq))
                batch = [worker_args(next_offset + k) for k in range(batch_n)]
                next_offset += batch_n
                # Ordered imap (not imap_unordered) keeps the accepted set
                # and its order independent of process scheduling.
                for c in pool.imap(_gen_batch_worker, batch):
                    if c in seen:
                        continue
                    seen.add(c)
                    uniq.append(c)
                    if len(uniq) >= ns.count:
                        # Leaving the ``with`` block terminates leftover work.
                        break

    outp.parent.mkdir(parents=True, exist_ok=True)
    with outp.open("w", encoding="utf-8") as f:
        for u in uniq:
            f.write(u + "\n")
    print(f"# generated: {len(uniq)}", file=sys.stderr)
    return 0
138170

139171
def cmd_rate_file(ns: argparse.Namespace) -> int:
140172
inp = pathlib.Path(ns.in_path)
@@ -146,7 +178,10 @@ def cmd_rate_file(ns: argparse.Namespace) -> int:
146178
continue
147179
score = rate(from_string(s))
148180
rows.append((s, score))
149-
print(f"{score:.1f}")
181+
if ns.json:
182+
print(json.dumps({"grid": s, "score": round(score, 1)}, separators=(",", ":")))
183+
else:
184+
print(f"{score:.1f}")
150185
if ns.csv_path:
151186
with open(ns.csv_path, "w", newline="", encoding="utf-8") as csv_handle:
152187
writer = csv.writer(csv_handle)
@@ -169,41 +204,57 @@ def _percentile(xs: list[float], p: float) -> float:
169204

170205
def cmd_stats_file(ns: argparse.Namespace) -> int:
171206
inp = pathlib.Path(ns.in_path)
172-
total = 0
207+
processed = 0
173208
n_valid = n_solvable = n_unique = 0
174209
givens: list[int] = []
175210
diffs: list[float] = []
176211
ms_list: list[float] = []
177212
t0 = time.perf_counter()
213+
# reservoir sample if requested
214+
sample_k = ns.sample if ns.sample and ns.sample > 0 else 0
215+
rng = random.Random(1337)
216+
lines: list[str] = []
178217
with inp.open("r", encoding="utf-8") as handle:
179218
for line in handle:
180219
s = "".join(ch for ch in line.strip() if not ch.isspace())
181220
if not s:
182221
continue
183-
try:
184-
grid = from_string(s)
185-
except Exception:
186-
continue
187-
data = analyze(grid)
188-
total += 1
189-
if data["valid"]:
190-
n_valid += 1
191-
if data["solvable"]:
192-
n_solvable += 1
193-
if data["unique"]:
194-
n_unique += 1
195-
givens.append(int(data["givens"]))
196-
diffs.append(float(data["difficulty"]))
197-
ms_list.append(float(data["stats"]["ms"]))
198-
if total == 0:
222+
if ns.limit and processed >= ns.limit:
223+
break
224+
processed += 1
225+
if sample_k == 0:
226+
lines.append(s)
227+
else:
228+
if len(lines) < sample_k:
229+
lines.append(s)
230+
else:
231+
j = rng.randrange(1, processed + 1)
232+
if j <= sample_k:
233+
lines[j - 1] = s
234+
for s in lines:
235+
try:
236+
grid = from_string(s)
237+
except Exception:
238+
continue
239+
data = analyze(grid)
240+
if data["valid"]:
241+
n_valid += 1
242+
if data["solvable"]:
243+
n_solvable += 1
244+
if data["unique"]:
245+
n_unique += 1
246+
givens.append(int(data["givens"]))
247+
diffs.append(float(data["difficulty"]))
248+
ms_list.append(float(data["stats"]["ms"]))
249+
if len(lines) == 0:
199250
print("no puzzles read", file=sys.stderr)
200251
return 2
201252
elapsed = (time.perf_counter() - t0) * 1000.0
202253
report = {
203-
"count": total,
204-
"valid_pct": round(100.0 * n_valid / total, 2),
205-
"solvable_pct": round(100.0 * n_solvable / total, 2),
206-
"unique_pct": round(100.0 * n_unique / total, 2),
254+
"count": len(lines),
255+
"valid_pct": round(100.0 * n_valid / len(lines), 2),
256+
"solvable_pct": round(100.0 * n_solvable / len(lines), 2),
257+
"unique_pct": round(100.0 * n_unique / len(lines), 2),
207258
"givens_mean": round(mean(givens), 2),
208259
"givens_min": min(givens),
209260
"givens_max": max(givens),
@@ -245,7 +296,6 @@ def cmd_stats_file(ns: argparse.Namespace) -> int:
245296
)
246297
return 0
247298

248-
249299
def cmd_dedupe(ns: argparse.Namespace) -> int:
250300
inp = pathlib.Path(ns.in_path)
251301
outp = pathlib.Path(ns.out_path)
@@ -328,13 +378,17 @@ def main(argv: Optional[list[str]] = None) -> int:
328378
)
329379
genb_parser.add_argument("--minimal", action="store_true")
330380
genb_parser.add_argument("--seed", type=int, default=None)
381+
genb_parser.add_argument("--min-givens", type=int, default=0, help="keep only puzzles with >= this many givens")
382+
genb_parser.add_argument("--max-givens", type=int, default=81, help="keep only puzzles with <= this many givens")
383+
genb_parser.add_argument("--parallel", type=int, default=1, help="processes for generation (default 1)")
331384
genb_parser.set_defaults(func=cmd_gen_batch)
332385

333386
ratef_parser = sub.add_parser("rate-file", help="rate each puzzle in a file")
334387
ratef_parser.add_argument("--in", dest="in_path", required=True)
335388
ratef_parser.add_argument(
336389
"--csv", dest="csv_path", help="optional CSV output path"
337390
)
391+
ratef_parser.add_argument("--json", action="store_true", help="print one JSON object per line to stdout")
338392
ratef_parser.set_defaults(func=cmd_rate_file)
339393

340394
stats_parser = sub.add_parser("stats-file", help="summarize a file of puzzles")
@@ -348,6 +402,8 @@ def main(argv: Optional[list[str]] = None) -> int:
348402
stats_parser.add_argument(
349403
"--bins", type=int, default=11, help="histogram bins (default 11 for 0..10)"
350404
)
405+
stats_parser.add_argument("--limit", type=int, default=0, help="process at most N lines (0 = no limit)")
406+
stats_parser.add_argument("--sample", type=int, default=0, help="reservoir sample K lines (0 = no sampling)")
351407
stats_parser.set_defaults(func=cmd_stats_file)
352408

353409
gen_parser = sub.add_parser("gen", help="generate a puzzle")
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sudoku_dlx import cli, from_string
2+
3+
4+
def _givens(s: str) -> int:
    """Count the given (pre-filled) cells in a puzzle string.

    Only the digits ``1``-``9`` count as givens, so blanks written as ``.``
    or ``0`` (and any stray separator characters) are ignored.
    """
    return sum(1 for ch in s if ch in "123456789")
6+
7+
8+
def test_gen_batch_bounds_parallel(tmp_path):
    """gen-batch with --parallel honours the count, uniqueness and givens bounds."""
    out = tmp_path / "p.txt"
    rc = cli.main([
        "gen-batch",
        "--out", str(out),
        "--count", "6",
        "--givens", "34",
        "--min-givens", "30",
        "--max-givens", "40",
        "--parallel", "2",
        "--symmetry", "none",
    ])
    assert rc == 0
    lines = [ln.strip() for ln in out.read_text(encoding="utf-8").splitlines() if ln.strip()]
    assert len(lines) == 6
    # gen-batch de-duplicates canonical forms, so every line must be distinct.
    assert len(set(lines)) == len(lines)
    for s in lines:
        assert 30 <= _givens(s) <= 40

tests/test_rate_file_json.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import json
2+
from sudoku_dlx import cli
3+
4+
5+
def test_rate_file_json(tmp_path, capsys):
    """rate-file --json prints one JSON object with grid and score per puzzle."""
    p = tmp_path / "p.txt"
    grid = "53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79"
    grids = [grid, grid]
    p.write_text("\n".join(grids) + "\n", encoding="utf-8")
    rc = cli.main(["rate-file", "--in", str(p), "--json"])
    assert rc == 0
    out = capsys.readouterr().out.strip().splitlines()
    assert len(out) == len(grids)
    # Every emitted line must be a standalone JSON object, not just the first.
    for line in out:
        j = json.loads(line)
        assert "grid" in j and "score" in j
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import json
2+
from sudoku_dlx import cli
3+
4+
5+
def test_stats_file_limit_and_sample(tmp_path, capsys):
    """stats-file reports a count capped by --limit and by --sample."""
    # Prepare 10 puzzles to summarize.
    puzzles = tmp_path / "p.txt"
    gen_argv = [
        "gen-batch",
        "--out", str(puzzles),
        "--count", "10",
        "--givens", "34",
        "--symmetry", "none",
    ]
    assert cli.main(gen_argv) == 0

    # First limit=5, then sample=4 (no limit); each caps the reported count.
    for flag, value in (("--limit", "5"), ("--sample", "4")):
        rc = cli.main(["stats-file", "--in", str(puzzles), flag, value])
        assert rc == 0
        report = json.loads(capsys.readouterr().out.strip())
        assert report["count"] == int(value)

0 commit comments

Comments
 (0)