Skip to content

Commit f480d55

Browse files
[BUG]: Validate n_iterations in MCTS (#361)
#### Reference Issues/PRs Fixes #358 #### What does this implement/fix? Explain your changes. 1. Expanded constructor input validation in `MCTS.__init__`: * Raises `ValueError` when:- ** a. `n_iterations < 1`. b. `depth < 1`. c. `states` is empty. d. `states` contains duplicates. 4. Added focused regression tests in `test_mcts.py`: * Verifies invalid:- a. `n_iterations` values (`0`, `-1`) are rejected. b. `depth` values (`0`, `-1`) are rejected. c. `states` inputs (empty and duplicate entries) are rejected. ```python def __init__(self, ...) -> None: if depth < 1: raise ValueError(f"`depth` must be >= 1, got {depth}.") if n_iterations < 1: raise ValueError(f"`n_iterations` must be >= 1, got {n_iterations}.") if states is None: states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"] elif not states: raise ValueError("`states` must contain at least one entry.") elif len(states) != len(set(states)): raise ValueError("`states` must contain unique entries.") ``` #### What should a reviewer concentrate their feedback on? - [ ] Please focus on whether validating `depth`, `n_iterations`, and `states` in `MCTS.__init__` is the right and consistent fix for preventing invalid runtime behavior (including the infinite-loop path in `run()` for `n_iterations <= 0`). - [ ] Please check whether `ValueError` is the correct API behavior for these invalid constructor inputs. - [ ] Please check that the added regression tests are focused and cover the relevant invalid cases without over-scoping. #### Did you add any tests for the change? yes ```python @pytest.mark.parametrize("depth", [0, -1]) def test_init_invalid_depth(self, depth): ... @pytest.mark.parametrize("n_iterations", [0, -1]) def test_init_invalid_iterations(self, n_iterations): ... def test_init_empty_states(self): ... def test_init_duplicate_states(self): ... ``` #### PR checklist - [x] The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - [x] Added/modified tests - [x] Used pre-commit hooks when committing to ensure that code is compliant with hooks. py-test <img width="1912" height="165" alt="image" src="https://github.com/user-attachments/assets/a396765d-8f98-40f4-b108-73b39be7391d" /> notebooks diff <img width="1919" height="121" alt="Screenshot 2026-04-11 004844" src="https://github.com/user-attachments/assets/541bf011-285f-4319-8389-634def1ac032" /> pre-commit <img width="1907" height="199" alt="image" src="https://github.com/user-attachments/assets/41140756-df9d-4bd9-a7cc-5c7e8b5281e3" /> run-jupyter-notebooks <img width="1914" height="300" alt="Image" src="https://github.com/user-attachments/assets/c92528ec-5a7b-4144-a155-77c90effe95b" />
1 parent f528b38 commit f480d55

2 files changed

Lines changed: 48 additions & 6 deletions

File tree

pyaptamer/mcts/_algorithm.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ class MCTS(BaseObject):
2424
----------
2525
states : list[str], optional, default=None
2626
Possible values for the nodes. Underscores indicate whether the values are
27-
supposed to be prepended or appended to the sequence.
27+
supposed to be prepended or appended to the sequence. If None or empty,
28+
defaults to the standard RNA nucleotide states. Must contain unique entries.
2829
depth : int, optional, default=20
29-
Maximum depth of the search tree, also the length of the generated sequences.
30+
Maximum depth of the search tree, also the length of the generated
31+
sequences. Must be >= 1.
3032
n_iterations : int, optional, default=1000
31-
Number of iterations per round for the MCTS algorithm.
33+
Number of iterations per round for the MCTS algorithm. Must be >= 1.
3234
experiment : BaseAptamerEval, optional, default=None
3335
An instance of an experiment class definingthe goal function for the algorithm.
3436
@@ -74,14 +76,31 @@ def __init__(
7476
n_iterations: int = 1000,
7577
experiment=None,
7678
) -> None:
79+
"""
80+
Raises
81+
------
82+
ValueError
83+
If `depth` is less than 1.
84+
ValueError
85+
If `n_iterations` is less than 1.
86+
ValueError
87+
If `states` contains duplicate entries.
88+
"""
89+
if depth < 1:
90+
raise ValueError(f"`depth` must be >= 1, got {depth}.")
91+
if n_iterations < 1:
92+
raise ValueError(f"`n_iterations` must be >= 1, got {n_iterations}.")
93+
94+
if not states:
95+
states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
96+
elif len(states) != len(set(states)):
97+
raise ValueError("`states` must contain unique entries.")
98+
7799
self.experiment = experiment
78100
self.depth = depth
79101
self.n_iterations = n_iterations
80102

81103
super().__init__()
82-
83-
if states is None:
84-
states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
85104
self.states = states
86105

87106
self.root = TreeNode(

pyaptamer/mcts/tests/test_mcts.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,29 @@ def mcts(request):
241241
class TestMCTS:
242242
"""Tests for the MCTS() class."""
243243

244+
@pytest.mark.parametrize("depth", [0, -1])
245+
def test_init_invalid_depth(self, depth):
246+
"""Check invalid depths are rejected early."""
247+
with pytest.raises(ValueError, match=r"`depth` must be >= 1"):
248+
MCTS(depth=depth)
249+
250+
@pytest.mark.parametrize("n_iterations", [0, -1])
251+
def test_init_invalid_iterations(self, n_iterations):
252+
"""Check invalid iteration counts are rejected early."""
253+
with pytest.raises(ValueError, match=r"`n_iterations` must be >= 1"):
254+
MCTS(n_iterations=n_iterations)
255+
256+
@pytest.mark.parametrize("states", [None, []])
257+
def test_init_empty_or_none_states_defaults(self, states):
258+
"""Check that None or empty states default to the standard nucleotide set."""
259+
mcts = MCTS(states=states)
260+
assert mcts.states == ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
261+
262+
def test_init_duplicate_states(self):
263+
"""Check duplicate states are rejected early."""
264+
with pytest.raises(ValueError, match=r"`states` must contain unique entries"):
265+
MCTS(states=["A_", "A_"])
266+
244267
def test_reset(self, mcts):
245268
"""Check correct reset of the inner state."""
246269
# modify its inner state

0 commit comments

Comments
 (0)