Skip to content

Commit 7cd2d37

Browse files
Merge pull request #14 from SaridakisStamatisChristos/codex/apply-one-shot-sota-patch
Add Sudoku canonicalization support
2 parents d221bbc + 81f9915 commit 7cd2d37

File tree

5 files changed

+335
-0
lines changed

5 files changed

+335
-0
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ sudoku-dlx solve --grid "<81chars>" --pretty --stats
4242
# Rate difficulty (0..10)
4343
sudoku-dlx rate --grid "<81chars>"
4444

45+
# Canonicalize (dedupe isomorphic puzzles)
46+
sudoku-dlx canon --grid "<81chars>" # D4 × bands/stacks × inner row/col × digit relabel
47+
# Produces a stable 81-char string for deduping datasets.
48+
4549
# Generate a unique puzzle (deterministic with seed)
4650
sudoku-dlx gen --seed 123 --givens 30 # ~target clue count (approx)
4751
sudoku-dlx gen --seed 123 --givens 30 --pretty

src/sudoku_dlx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
solve,
1111
to_string,
1212
)
13+
from .canonical import canonical_form
1314
from .generate import generate
1415
from .rating import rate
1516
from .solver import (
@@ -35,6 +36,7 @@
3536
"solve",
3637
"count_solutions",
3738
"rate",
39+
"canonical_form",
3840
"generate",
3941
# Legacy exports
4042
"SOLVER",

src/sudoku_dlx/canonical.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from __future__ import annotations
2+
"""
3+
State-of-the-art canonicalization for Sudoku puzzles.
4+
5+
Maps isomorphic puzzles to a single 81-char canonical form using:
6+
• Dihedral symmetries D4 (8 transforms)
7+
• Band (row bands) and stack (column stacks) permutations (3! each)
8+
• Row swaps within each band and column swaps within each stack (3! for each band/stack)
9+
• Greedy digit relabeling (first-appearance maps to 1..9)
10+
11+
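Search space per grid: 8 × (3!)^8 = 13,436,928 arrangements (8 D4 transforms × 3! band orders × 3! stack orders × (3!)^3 inner row orders × (3!)^3 inner column orders); branches that cannot beat the best string found so far are pruned.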
12+
"""
13+
from itertools import permutations
14+
from typing import List, Sequence, Tuple
15+
16+
from .api import Grid
17+
18+
# --------- Dihedral transforms over 9x9 grids (D4) ----------
19+
20+
def _rot90(g: Grid) -> Grid:
21+
return [[g[9 - 1 - c][r] for c in range(9)] for r in range(9)]
22+
23+
24+
def _rot180(g: Grid) -> Grid:
25+
return [[g[9 - 1 - r][9 - 1 - c] for c in range(9)] for r in range(9)]
26+
27+
28+
def _rot270(g: Grid) -> Grid:
29+
return [[g[c][9 - 1 - r] for c in range(9)] for r in range(9)]
30+
31+
32+
def _flip_h(g: Grid) -> Grid:
33+
# horizontal flip (mirror over vertical axis)
34+
return [[g[r][9 - 1 - c] for c in range(9)] for r in range(9)]
35+
36+
37+
def _flip_v(g: Grid) -> Grid:
38+
# vertical flip (mirror over horizontal axis)
39+
return [g[9 - 1 - r][:] for r in range(9)]
40+
41+
42+
def _flip_main_diag(g: Grid) -> Grid:
43+
# transpose over main diagonal
44+
return [[g[c][r] for c in range(9)] for r in range(9)]
45+
46+
47+
def _flip_anti_diag(g: Grid) -> Grid:
48+
# reflect over anti-diagonal (r,c) -> (8-c,8-r)
49+
return [[g[9 - 1 - c][9 - 1 - r] for c in range(9)] for r in range(9)]
50+
51+
52+
# The eight dihedral (D4) transforms that canonical_form iterates over, in
# order. The first entry is the identity and returns its argument unchanged.
_TRANSFORMS = (
    lambda x: x,
    _rot90,
    _rot180,
    _rot270,
    _flip_h,
    _flip_v,
    _flip_main_diag,
    _flip_anti_diag,
)
62+
63+
# --------- Permutations for bands/stacks and inner rows/cols ----------
64+
65+
_PERM3 = list(permutations((0, 1, 2))) # 6 perms
66+
67+
68+
def _cell_char(value: int) -> str:
69+
if value == 0:
70+
return "."
71+
if isinstance(value, str):
72+
return value if value not in {"0", "-"} else "."
73+
return str(value)
74+
75+
76+
def _canonical_band_stack(
    grid_chars: Sequence[Sequence[str]],
    band_perm: Tuple[int, int, int],
    stack_perm: Tuple[int, int, int],
    best: str | None,
) -> str | None:
    """Find the smallest serialisation for one band/stack arrangement.

    The grid is written out box by box: the nine cells of output box 0, then
    output box 1, and so on. The result is therefore an 81-char string in
    box-major order, not a row-major grid string. Output box ``k`` is read
    from source band ``band_perm[k // 3]`` and source stack
    ``stack_perm[k % 3]``.

    Within every band the three rows may be reordered, and within every stack
    the three columns may be reordered; all such orders are searched
    depth-first. Non-blank symbols are relabelled in order of first
    appearance to ``"1"``, ``"2"``, ...; ``"."`` is emitted unchanged. The
    label is capped at ``"9"``, so a tenth distinct symbol also maps to
    ``"9"``.

    Args:
        grid_chars: 9x9 grid of single-character cells, ``"."`` for blanks.
        band_perm: Source band used for each output band position.
        stack_perm: Source stack used for each output stack position.
        best: Smallest string found by earlier calls, or ``None``. It is used
            as an upper bound to prune the search.

    Returns:
        The lexicographically smallest string among ``best`` and every
        serialisation reachable with this band/stack arrangement.
    """
    best_local = best
    row_order: dict[int, Tuple[int, int, int]] = {}
    col_order: dict[int, Tuple[int, int, int]] = {}
    relabel: dict[str, str] = {}
    out_chars: List[str] = []

    def emit_box(
        band: int,
        stack: int,
        row_perm: Tuple[int, int, int],
        col_perm: Tuple[int, int, int],
    ) -> List[str]:
        """Append one relabelled box; return the symbols first seen in it."""
        fresh: List[str] = []
        for r_local in row_perm:
            row = grid_chars[band * 3 + r_local]
            for c_local in col_perm:
                ch = row[stack * 3 + c_local]
                if ch != ".":
                    mapped = relabel.get(ch)
                    if mapped is None:
                        # Next unused label, capped at "9".
                        mapped = chr(ord("1") + min(len(relabel), 8))
                        relabel[ch] = mapped
                        fresh.append(ch)
                    ch = mapped
                out_chars.append(ch)
        return fresh

    def can_still_win() -> bool:
        """Report whether the current prefix does not already exceed the best."""
        if best_local is None:
            return True
        prefix = "".join(out_chars)
        # Compare against the *current* best so the bound tightens as soon as
        # a better candidate is found anywhere in the search.
        return prefix <= best_local[: len(prefix)]

    def dfs(block_idx: int) -> None:
        nonlocal best_local
        if block_idx == 9:
            candidate = "".join(out_chars)
            if best_local is None or candidate < best_local:
                best_local = candidate
            return

        band = band_perm[block_idx // 3]
        stack = stack_perm[block_idx % 3]

        # A band's row order is chosen at its first box and reused for the
        # other two; likewise a stack's column order.
        row_fixed = band in row_order
        col_fixed = stack in col_order
        row_options = (row_order[band],) if row_fixed else _PERM3
        col_options = (col_order[stack],) if col_fixed else _PERM3

        start = len(out_chars)
        for row_perm in row_options:
            row_order[band] = row_perm
            for col_perm in col_options:
                col_order[stack] = col_perm

                fresh = emit_box(band, stack, row_perm, col_perm)
                if can_still_win():
                    dfs(block_idx + 1)

                # Undo this box before trying the next arrangement.
                del out_chars[start:]
                for ch in fresh:
                    del relabel[ch]

            if not col_fixed:
                del col_order[stack]
        if not row_fixed:
            del row_order[band]

    dfs(0)
    return best_local
182+
183+
184+
# --------- Public API (full canon) ----------
185+
186+
187+
def canonical_form(grid: Grid) -> str:
    """
    Return the smallest 81-char string over all equivalent arrangements.

    Every combination of the following is considered:
    - the D4 dihedral transforms in ``_TRANSFORMS``
    - band and stack permutations
    - row orders within each band and column orders within each stack
    Each candidate has its digits relabelled by first appearance before the
    comparison. Candidates are produced by ``_canonical_band_stack``, which
    serialises the grid box by box.
    """
    best: str | None = None
    for transform in _TRANSFORMS:
        rows_as_chars = []
        for row in transform(grid):
            rows_as_chars.append([_cell_char(cell) for cell in row])

        for band_perm in _PERM3:
            for stack_perm in _PERM3:
                # Pass the running best so the search can prune against it.
                cand = _canonical_band_stack(
                    rows_as_chars, band_perm, stack_perm, best
                )
                if cand is None:
                    continue
                if best is None or cand < best:
                    best = cand
    assert best is not None
    return best
206+
207+
208+
__all__ = ["canonical_form"]

src/sudoku_dlx/cli.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Optional
44

55
from .api import from_string, is_valid, solve, to_string
6+
from .canonical import canonical_form
67
from .generate import generate
78
from .rating import rate
89

@@ -62,6 +63,12 @@ def cmd_gen(ns: argparse.Namespace) -> int:
6263
return 0
6364

6465

66+
def cmd_canon(ns: argparse.Namespace) -> int:
    """Handle the ``canon`` subcommand: print the canonical form, return 0."""
    raw = _read_grid_arg(ns)
    canonical = canonical_form(from_string(raw))
    print(canonical)
    return 0
70+
71+
6572
def main(argv: Optional[list[str]] = None) -> int:
6673
parser = argparse.ArgumentParser(
6774
prog="sudoku-dlx",
@@ -98,6 +105,16 @@ def main(argv: Optional[list[str]] = None) -> int:
98105
gen_parser.add_argument("--pretty", action="store_true")
99106
gen_parser.set_defaults(func=cmd_gen)
100107

108+
canon_parser = sub.add_parser(
109+
"canon",
110+
help=(
111+
"print canonical 81-char form (D4 × bands/stacks × inner row/col swaps × digit relabel)"
112+
),
113+
)
114+
canon_parser.add_argument("--grid", help="81-char string; 0/./- for blanks")
115+
canon_parser.add_argument("--file", help="path to a file with 9 lines of 9 chars")
116+
canon_parser.set_defaults(func=cmd_canon)
117+
101118
args = parser.parse_args(argv)
102119
if not hasattr(args, "func"):
103120
parser.print_help()

tests/test_canonical.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from textwrap import dedent
2+
3+
from sudoku_dlx import from_string, to_string, canonical_form
4+
5+
# Sample puzzle shared by the tests, flattened to a single 81-char string
# (rows concatenated top to bottom, "." for blanks).
BASE = "".join(
    dedent(
        """
        53..7....
        6..195...
        .98....6.
        8...6...3
        4..8.3..1
        7...2...6
        .6....28.
        ...419..5
        ....8..79
        """
    ).split()
)
18+
19+
20+
def rot90_string(s: str) -> str:
    """Rotate an 81-char row-major grid string a quarter turn.

    Output cell (r, c) is taken from input cell (8 - c, r).
    """
    return "".join(s[(8 - c) * 9 + r] for r in range(9) for c in range(9))
27+
28+
29+
def relabel_123_to_456(s: str) -> str:
    """Swap the digits 1<->4, 2<->5 and 3<->6; leave every other char alone."""
    swap = str.maketrans("123456", "456123")
    return s.translate(swap)
32+
33+
34+
def swap_bands_string(s: str, order=(1, 0, 2)) -> str:
    """Reorder the three row bands (27 chars each) of a grid string by *order*."""
    return "".join(s[band * 27 : (band + 1) * 27] for band in order)
40+
41+
42+
def swap_stacks_string(s: str, order=(2, 1, 0)) -> str:
    """Reorder the three column stacks of a grid string by *order*.

    Every 9-char row is cut into three 3-char chunks, which are re-emitted in
    the given order.
    """
    pieces = []
    for r in range(9):
        line = s[r * 9 : (r + 1) * 9]
        pieces.extend(line[k * 3 : (k + 1) * 3] for k in order)
    return "".join(pieces)
48+
49+
50+
def swap_rows_in_band_string(s: str, band=1, perm=(2, 0, 1)) -> str:
    """Reorder the three rows inside one band of a grid string by *perm*."""
    lines = [s[i * 9 : (i + 1) * 9] for i in range(9)]
    first = band * 3
    trio = lines[first : first + 3]
    reordered = [trio[k] for k in perm]
    return "".join(lines[:first] + reordered + lines[first + 3 :])
57+
58+
59+
def swap_cols_in_stack_string(s: str, stack=0, perm=(1, 2, 0)) -> str:
    """Reorder the three columns inside one stack of a grid string by *perm*."""
    first = stack * 3
    pieces = []
    for r in range(9):
        line = s[r * 9 : (r + 1) * 9]
        trio = line[first : first + 3]
        shuffled = "".join(trio[k] for k in perm)
        pieces.append(line[:first] + shuffled + line[first + 3 :])
    return "".join(pieces)
66+
67+
68+
def test_canonical_equal_under_rotation():
    """A quarter-turn rotation must canonicalize to the same string."""
    original = canonical_form(from_string(BASE))
    rotated = canonical_form(from_string(rot90_string(BASE)))
    assert original == rotated
74+
75+
76+
def test_canonical_equal_under_digit_relabel():
    """Swapping digit labels must not change the canonical string."""
    relabelled = relabel_123_to_456(BASE)
    original_canon = canonical_form(from_string(BASE))
    relabelled_canon = canonical_form(from_string(relabelled))
    assert original_canon == relabelled_canon
81+
82+
83+
def test_canonical_is_81_chars_and_uses_dots():
    """The canonical string has 81 chars drawn only from 1-9 and '.'."""
    canon = canonical_form(from_string(BASE))
    allowed = set("123456789.")
    assert len(canon) == 81
    assert set(canon) <= allowed
87+
88+
89+
def test_canonical_equal_under_band_and_stack_swaps():
    """Permuting whole bands or whole stacks must not change the result."""
    banded = swap_bands_string(BASE, order=(1, 0, 2))
    stacked = swap_stacks_string(BASE, order=(2, 1, 0))
    reference = canonical_form(from_string(BASE))
    banded_canon = canonical_form(from_string(banded))
    stacked_canon = canonical_form(from_string(stacked))
    assert reference == banded_canon == stacked_canon
96+
97+
98+
def test_canonical_equal_under_inner_row_col_swaps():
    """Reordering rows in a band or columns in a stack must not change the result."""
    row_variant = swap_rows_in_band_string(BASE, band=2, perm=(1, 2, 0))
    col_variant = swap_cols_in_stack_string(BASE, stack=1, perm=(2, 0, 1))
    reference = canonical_form(from_string(BASE))
    row_canon = canonical_form(from_string(row_variant))
    col_canon = canonical_form(from_string(col_variant))
    assert reference == row_canon == col_canon

0 commit comments

Comments
 (0)