Skip to content

Commit a992bb7

Browse files
Validate n_iterations in MCTS
1 parent 3c8b502 commit a992bb7

2 files changed

Lines changed: 9 additions & 0 deletions

File tree

pyaptamer/mcts/_algorithm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def __init__(
7474
n_iterations: int = 1000,
7575
experiment=None,
7676
) -> None:
77+
if n_iterations < 1:
78+
raise ValueError(f"`n_iterations` must be >= 1, got {n_iterations}.")
79+
7780
self.experiment = experiment
7881
self.depth = depth
7982
self.n_iterations = n_iterations

pyaptamer/mcts/tests/test_mcts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ def mcts(request):
241241
class TestMCTS:
242242
"""Tests for the MCTS() class."""
243243

244+
@pytest.mark.parametrize("n_iterations", [0, -1])
245+
def test_init_invalid_iterations(self, n_iterations):
246+
"""Check invalid iteration counts are rejected early."""
247+
with pytest.raises(ValueError, match=r"`n_iterations` must be >= 1"):
248+
MCTS(n_iterations=n_iterations)
249+
244250
def test_reset(self, mcts):
245251
"""Check correct reset of the inner state."""
246252
# modify its inner state

0 commit comments

Comments
 (0)