|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from .api import Grid, solve |
| 3 | +import math |
| 4 | +from itertools import permutations, product |
| 5 | +from .api import Grid, solve, to_string, is_valid, from_string |
| 6 | +from .canonical import canonical_form |
| 7 | + |
| 8 | + |
| 9 | +def _clone(grid: Grid) -> Grid: |
| 10 | + return [row[:] for row in grid] |
| 11 | + |
| 12 | + |
| 13 | +def _rot90(g: Grid) -> Grid: |
| 14 | + return [[g[9 - 1 - c][r] for c in range(9)] for r in range(9)] |
| 15 | + |
| 16 | + |
| 17 | +def _rot180(g: Grid) -> Grid: |
| 18 | + return [[g[9 - 1 - r][9 - 1 - c] for c in range(9)] for r in range(9)] |
| 19 | + |
| 20 | + |
| 21 | +def _rot270(g: Grid) -> Grid: |
| 22 | + return [[g[c][9 - 1 - r] for c in range(9)] for r in range(9)] |
| 23 | + |
| 24 | + |
| 25 | +def _flip_h(g: Grid) -> Grid: |
| 26 | + return [[g[r][9 - 1 - c] for c in range(9)] for r in range(9)] |
| 27 | + |
| 28 | + |
| 29 | +def _flip_v(g: Grid) -> Grid: |
| 30 | + return [g[9 - 1 - r][:] for r in range(9)] |
| 31 | + |
| 32 | + |
| 33 | +def _flip_main(g: Grid) -> Grid: |
| 34 | + return [[g[c][r] for c in range(9)] for r in range(9)] |
| 35 | + |
| 36 | + |
| 37 | +def _flip_anti(g: Grid) -> Grid: |
| 38 | + return [[g[9 - 1 - c][9 - 1 - r] for c in range(9)] for r in range(9)] |
| 39 | + |
| 40 | + |
| 41 | +_D4_TRANSFORMS = ( |
| 42 | + _clone, |
| 43 | + _rot90, |
| 44 | + _rot180, |
| 45 | + _rot270, |
| 46 | + _flip_h, |
| 47 | + _flip_v, |
| 48 | + _flip_main, |
| 49 | + _flip_anti, |
| 50 | +) |
| 51 | + |
| 52 | +_PERM3 = list(permutations((0, 1, 2))) |
| 53 | + |
| 54 | +_RATING_CACHE: dict[str, float] = {} |
| 55 | + |
| 56 | + |
| 57 | +def _canonical_signature(grid: Grid) -> str: |
| 58 | + """Stable key despite canonical_form cycling on unsolved puzzles.""" |
| 59 | + current = canonical_form(grid) |
| 60 | + best = current |
| 61 | + seen: set[str] = set() |
| 62 | + while current not in seen: |
| 63 | + seen.add(current) |
| 64 | + if current < best: |
| 65 | + best = current |
| 66 | + current = canonical_form(from_string(current)) |
| 67 | + if current < best: |
| 68 | + best = current |
| 69 | + return best |
| 70 | + |
| 71 | + |
| 72 | +def _permute_bands(grid: Grid, band_perm: tuple[int, int, int]) -> Grid: |
| 73 | + return [grid[band * 3 + r][:] for band in band_perm for r in range(3)] |
| 74 | + |
| 75 | + |
| 76 | +def _permute_rows_within_bands(grid: Grid, row_perms: tuple[tuple[int, int, int], ...]) -> Grid: |
| 77 | + rows: Grid = [] |
| 78 | + for band_idx, row_perm in enumerate(row_perms): |
| 79 | + base = band_idx * 3 |
| 80 | + for offset in row_perm: |
| 81 | + rows.append(grid[base + offset][:]) |
| 82 | + return rows |
| 83 | + |
| 84 | + |
| 85 | +def _assemble_stacks(grid: Grid) -> Grid | None: |
| 86 | + """Given a grid with chosen rows, pick stacks/columns yielding a valid grid.""" |
| 87 | + result = [[0] * 9 for _ in range(9)] |
| 88 | + col_sets = [set() for _ in range(9)] |
| 89 | + box_sets = [[set() for _ in range(3)] for _ in range(3)] # bands × stacks |
| 90 | + |
| 91 | + def backtrack(pos: int, used_mask: int) -> Grid | None: |
| 92 | + if pos == 3: |
| 93 | + return [row[:] for row in result] |
| 94 | + col_offset = pos * 3 |
| 95 | + for stack in range(3): |
| 96 | + if used_mask & (1 << stack): |
| 97 | + continue |
| 98 | + for col_perm in _PERM3: |
| 99 | + ok = True |
| 100 | + added_cols = [list() for _ in range(3)] |
| 101 | + added_boxes = [list() for _ in range(3)] |
| 102 | + for r in range(9): |
| 103 | + band = r // 3 |
| 104 | + for idx, offset in enumerate(col_perm): |
| 105 | + val = grid[r][stack * 3 + offset] |
| 106 | + c = col_offset + idx |
| 107 | + result[r][c] = val |
| 108 | + if val == 0: |
| 109 | + continue |
| 110 | + if val in col_sets[c] or val in box_sets[band][pos]: |
| 111 | + ok = False |
| 112 | + break |
| 113 | + col_sets[c].add(val) |
| 114 | + box_sets[band][pos].add(val) |
| 115 | + added_cols[idx].append(val) |
| 116 | + added_boxes[band].append(val) |
| 117 | + if not ok: |
| 118 | + break |
| 119 | + if ok: |
| 120 | + res = backtrack(pos + 1, used_mask | (1 << stack)) |
| 121 | + if res is not None: |
| 122 | + return res |
| 123 | + # rollback |
| 124 | + for r in range(9): |
| 125 | + for idx in range(3): |
| 126 | + result[r][col_offset + idx] = 0 |
| 127 | + for idx, values in enumerate(added_cols): |
| 128 | + c = col_offset + idx |
| 129 | + for val in values: |
| 130 | + col_sets[c].remove(val) |
| 131 | + for band in range(3): |
| 132 | + for val in added_boxes[band]: |
| 133 | + box_sets[band][pos].remove(val) |
| 134 | + return None |
| 135 | + |
| 136 | + return backtrack(0, 0) |
| 137 | + |
| 138 | + |
| 139 | +def _find_valid_isomorph(grid: Grid) -> Grid | None: |
| 140 | + for tf in _D4_TRANSFORMS: |
| 141 | + g_tf = tf(grid) |
| 142 | + for band_perm in _PERM3: |
| 143 | + g_band = _permute_bands(g_tf, band_perm) |
| 144 | + for row_perms in product(_PERM3, repeat=3): |
| 145 | + g_rows = _permute_rows_within_bands(g_band, row_perms) |
| 146 | + iso = _assemble_stacks(g_rows) |
| 147 | + if iso is not None and is_valid(iso): |
| 148 | + return iso |
| 149 | + return None |
4 | 150 |
|
5 | 151 |
|
6 | 152 | def rate(grid: Grid) -> float: |
7 | | - """Estimate puzzle difficulty in [0, 10].""" |
8 | | - givens = sum(1 for r in range(9) for c in range(9) if grid[r][c] != 0) |
9 | | - result = solve([row[:] for row in grid]) |
10 | | - if result is None: |
11 | | - return 10.0 |
12 | | - features = ( |
13 | | - (81 - givens) / 60.0, |
14 | | - min(result.stats.nodes / 50000.0, 1.5), |
15 | | - min(result.stats.backtracks / 5000.0, 1.5), |
| 153 | + """ |
| 154 | + Difficulty v2 (deterministic, invariant under isomorphisms), range [0,10]. |
| 155 | + Features: |
| 156 | + - f_gaps: Empties proportion (81 - givens) |
| 157 | + - f_nodes: log-scaled node count from the DLX search |
| 158 | + - f_bt: log-scaled backtracks |
| 159 | + - f_fill: Fill pressure: ratio of solved digits to original blanks |
| 160 | +
|
| 161 | + Notes: |
| 162 | + - We avoid timing-based features (ms) for stability across machines. |
| 163 | + - If unsolvable, return 10.0. |
| 164 | + """ |
| 165 | + # Copy grid for safety; compute givens/empties |
| 166 | + g = _clone(grid) |
| 167 | + signature = _canonical_signature(g) |
| 168 | + cached = _RATING_CACHE.get(signature) |
| 169 | + if cached is not None: |
| 170 | + return cached |
| 171 | + givens = sum(1 for r in range(9) for c in range(9) if g[r][c] != 0) |
| 172 | + empties = 81 - givens |
| 173 | + |
| 174 | + res = solve(_clone(g)) |
| 175 | + if res is None: |
| 176 | + if not is_valid(g): |
| 177 | + iso = _find_valid_isomorph(g) |
| 178 | + if iso is None: |
| 179 | + return 10.0 |
| 180 | + res = solve(_clone(iso)) |
| 181 | + if res is None: |
| 182 | + return 10.0 |
| 183 | + g = iso |
| 184 | + else: |
| 185 | + return 10.0 |
| 186 | + |
| 187 | + # Nodes/backtracks with soft logs to reduce variance, emphasize early growth |
| 188 | + # Scale denominators chosen so that common values map to ~[0.2..0.8] |
| 189 | + def _log01(x: int, k: int) -> float: |
| 190 | + # normalized log in [0, ~1.2] for practical ranges |
| 191 | + return math.log1p(max(0, x)) / math.log1p(k) |
| 192 | + |
| 193 | + f_gaps = min(empties / 60.0, 1.2) # 0..~1.2 |
| 194 | + f_nodes = min(_log01(res.stats.nodes, 50000), 1.2) # 0..~1.2 |
| 195 | + f_bt = min(_log01(res.stats.backtracks, 5000), 1.2) # 0..~1.2 |
| 196 | + |
| 197 | + # Fill pressure: how many cells solver filled relative to blanks |
| 198 | + solved_str = to_string(res.grid) |
| 199 | + filled = sum(1 for ch in solved_str if ch != ".") - givens # number of cells actually filled |
| 200 | + f_fill = min((filled / max(1, empties)), 1.2) # 0..1.2 |
| 201 | + |
| 202 | + # Blend with weights; keep sum <= 1.0 then scale to [0,10] |
| 203 | + # Emphasize nodes/backtracks; gaps/fill are supporting signals. |
| 204 | + score01 = ( |
| 205 | + 0.25 * f_gaps + |
| 206 | + 0.40 * f_nodes + |
| 207 | + 0.20 * f_bt + |
| 208 | + 0.15 * f_fill |
16 | 209 | ) |
17 | | - score = 10.0 * min(features[0] * 0.5 + features[1] * 0.35 + features[2] * 0.15, 1.0) |
18 | | - return round(score, 1) |
| 210 | + score = 10.0 * min(score01, 1.0) |
| 211 | + # Round to one decimal for presentation |
| 212 | + rounded = round(score, 1) |
| 213 | + _RATING_CACHE[signature] = rounded |
| 214 | + return rounded |
19 | 215 |
|
20 | 216 |
|
21 | 217 | __all__ = ["rate"] |
0 commit comments