Skip to content

Commit dd4d9c1

Browse files
committed
Fix ep resolution in chess.gaviota.PythonTablebase (fixes #1132)
1 parent 636e95f commit dd4d9c1

File tree

2 files changed

+26
-121
lines changed

2 files changed

+26
-121
lines changed

chess/gaviota.py

+21-121
Original file line numberDiff line numberDiff line change
@@ -1356,51 +1356,9 @@ def split_index(i: int) -> Tuple[int, int]:
13561356
iBMATEt = tb_BMATE | 4
13571357

13581358

1359-
def removepiece(ys: List[int], yp: List[int], j: int) -> None:
1360-
del ys[j]
1361-
del yp[j]
1362-
13631359
def opp(side: int) -> int:
13641360
return 1 if side == 0 else 0
13651361

1366-
def adjust_up(dist: int) -> int:
1367-
udist = dist
1368-
sw = udist & INFOMASK
1369-
1370-
if sw in [iWMATE, iWMATEt, iBMATE, iBMATEt]:
1371-
udist += (1 << PLYSHIFT)
1372-
1373-
return udist
1374-
1375-
def bestx(side: int, a: int, b: int) -> int:
1376-
# 0 = selectfirst
1377-
# 1 = selectlowest
1378-
# 2 = selecthighest
1379-
# 3 = selectsecond
1380-
comparison = [
1381-
# draw, wmate, bmate, forbid
1382-
[0, 3, 0, 0], # draw
1383-
[0, 1, 0, 0], # wmate
1384-
[3, 3, 2, 0], # bmate
1385-
[3, 3, 3, 0], # forbid
1386-
]
1387-
1388-
xorkey = [0, 3]
1389-
1390-
if a == iFORBID:
1391-
return b
1392-
if b == iFORBID:
1393-
return a
1394-
1395-
retu = [a, a, b, b]
1396-
1397-
if b < a:
1398-
retu[1] = b
1399-
retu[2] = a
1400-
1401-
key = comparison[a & 3][b & 3] ^ xorkey[side]
1402-
return retu[key]
1403-
14041362
def unpackdist(d: int) -> Tuple[int, int]:
14051363
return d >> PLYSHIFT, d & INFOMASK
14061364

@@ -1492,12 +1450,11 @@ class Request:
14921450
black_piece_types: List[int]
14931451
is_reversed: bool
14941452

1495-
def __init__(self, white_squares: List[int], white_types: List[chess.PieceType], black_squares: List[int], black_types: List[chess.PieceType], side: int, epsq: int):
1453+
def __init__(self, white_squares: List[int], white_types: List[chess.PieceType], black_squares: List[int], black_types: List[chess.PieceType], side: int):
14961454
self.white_squares, self.white_types = sortlists(white_squares, white_types)
14971455
self.black_squares, self.black_types = sortlists(black_squares, black_types)
14981456
self.realside = side
14991457
self.side = side
1500-
self.epsq = epsq
15011458

15021459

15031460
@dataclasses.dataclass
@@ -1569,17 +1526,34 @@ def probe_dtm(self, board: chess.Board) -> int:
15691526
if board.occupied == board.kings:
15701527
return 0
15711528

1529+
# Resolve en passant.
1530+
dtm = self._probe_dtm_no_ep(board)
1531+
for move in board.generate_legal_ep():
1532+
try:
1533+
board.push(move)
1534+
1535+
child_dtm = -self._probe_dtm_no_ep(board)
1536+
if child_dtm > 0:
1537+
child_dtm += 1
1538+
elif child_dtm < 0:
1539+
child_dtm -= 1
1540+
1541+
dtm = min(dtm, child_dtm) if dtm * child_dtm > 0 else max(dtm, child_dtm)
1542+
finally:
1543+
board.pop()
1544+
return dtm
1545+
1546+
def _probe_dtm_no_ep(self, board: chess.Board) -> int:
15721547
# Prepare the tablebase request.
15731548
white_squares = list(chess.SquareSet(board.occupied_co[chess.WHITE]))
15741549
white_types = [typing.cast(chess.PieceType, board.piece_type_at(sq)) for sq in white_squares]
15751550
black_squares = list(chess.SquareSet(board.occupied_co[chess.BLACK]))
15761551
black_types = [typing.cast(chess.PieceType, board.piece_type_at(sq)) for sq in black_squares]
15771552
side = 0 if (board.turn == chess.WHITE) else 1
1578-
epsq = board.ep_square if board.ep_square else NOSQUARE
1579-
req = Request(white_squares, white_types, black_squares, black_types, side, epsq)
1553+
req = Request(white_squares, white_types, black_squares, black_types, side)
15801554

15811555
# Probe.
1582-
dtm = self.egtb_get_dtm(req)
1556+
dtm = self._tb_probe(req)
15831557
ply, res = unpackdist(dtm)
15841558

15851559
if res == iWMATE:
@@ -1675,10 +1649,7 @@ def _setup_tablebase(self, req: Request) -> BinaryIO:
16751649
req.white_piece_types = req.black_types
16761650
req.black_piece_squares = [flip_ns(s) for s in req.white_squares]
16771651
req.black_piece_types = req.white_types
1678-
16791652
req.side = opp(req.side)
1680-
if req.epsq != NOSQUARE:
1681-
req.epsq = flip_ns(req.epsq)
16821653
else:
16831654
raise MissingTableError(f"no gaviota table available for: {white_letters.upper()}v{black_letters.upper()}")
16841655

@@ -1708,77 +1679,6 @@ def close(self) -> None:
17081679
_, stream = self.streams.popitem()
17091680
stream.close()
17101681

1711-
def egtb_get_dtm(self, req: Request) -> int:
1712-
dtm = self._tb_probe(req)
1713-
1714-
if req.epsq != NOSQUARE:
1715-
capturer_a = 0
1716-
capturer_b = 0
1717-
xed = 0
1718-
1719-
# Flip for move generation.
1720-
if req.side == 0:
1721-
xs = list(req.white_piece_squares)
1722-
xp = list(req.white_piece_types)
1723-
ys = list(req.black_piece_squares)
1724-
yp = list(req.black_piece_types)
1725-
else:
1726-
xs = list(req.black_piece_squares)
1727-
xp = list(req.black_piece_types)
1728-
ys = list(req.white_piece_squares)
1729-
yp = list(req.white_piece_types)
1730-
1731-
# Captured pawn trick: from ep square to captured.
1732-
xed = req.epsq ^ (1 << 3)
1733-
1734-
# Find captured index (j).
1735-
try:
1736-
j = ys.index(xed)
1737-
except ValueError:
1738-
j = -1
1739-
1740-
# Try first possible ep capture.
1741-
if 0 == (0x88 & (map88(xed) + 1)):
1742-
capturer_a = xed + 1
1743-
1744-
# Try second possible ep capture.
1745-
if 0 == (0x88 & (map88(xed) - 1)):
1746-
capturer_b = xed - 1
1747-
1748-
if (j > -1) and (ys[j] == xed):
1749-
# Find capturers (i).
1750-
for i in range(len(xs)):
1751-
if xp[i] == chess.PAWN and (xs[i] == capturer_a or xs[i] == capturer_b):
1752-
epscore = iFORBID
1753-
1754-
# Copy position.
1755-
xs_after = xs[:]
1756-
ys_after = ys[:]
1757-
xp_after = xp[:]
1758-
yp_after = yp[:]
1759-
1760-
# Execute capture.
1761-
xs_after[i] = req.epsq
1762-
removepiece(ys_after, yp_after, j)
1763-
1764-
# Flip back.
1765-
if req.side == 1:
1766-
xs_after, ys_after = ys_after, xs_after
1767-
xp_after, yp_after = yp_after, xp_after
1768-
1769-
# Make subrequest.
1770-
subreq = Request(xs_after, xp_after, ys_after, yp_after, opp(req.side), NOSQUARE)
1771-
try:
1772-
epscore = self._tb_probe(subreq)
1773-
epscore = adjust_up(epscore)
1774-
1775-
# Choose to ep or not.
1776-
dtm = bestx(req.side, epscore, dtm)
1777-
except IndexError:
1778-
break
1779-
1780-
return dtm
1781-
17821682
def egtb_block_getnumber(self, req: Request, idx: int) -> int:
17831683
maxindex = EGKEY[req.egkey].maxindex
17841684

test.py

+5
Original file line numberDiff line numberDiff line change
@@ -4348,6 +4348,11 @@ def test_two_ep(self):
43484348
board = chess.Board("K7/8/8/6k1/5pPp/8/8/8 b - g3 0 61")
43494349
self.assertEqual(self.tablebase.probe_dtm(board), 17)
43504350

4351+
@catchAndSkip(chess.gaviota.MissingTableError, "need KPvKP.gtb.cp4")
4352+
def test_ep_is_best(self):
4353+
board = chess.Board("8/8/7k/8/1pP5/7K/8/8 b - c3 0 1")
4354+
self.assertEqual(self.tablebase.probe_dtm(board), 19)
4355+
43514356

43524357
class SvgTestCase(unittest.TestCase):
43534358

0 commit comments

Comments
 (0)