Skip to content

Commit 25a8d99

Browse files
added input validation for all sensible parameters with tests
1 parent a992bb7 commit 25a8d99

2 files changed

Lines changed: 30 additions & 6 deletions

File tree

pyaptamer/mcts/_algorithm.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ class MCTS(BaseObject):
2424
----------
2525
states : list[str], optional, default=None
2626
Possible values for the nodes. Underscores indicate whether the values are
27-
supposed to be prepended or appended to the sequence.
27+
supposed to be prepended or appended to the sequence. Must be non-empty and
28+
contain unique entries.
2829
depth : int, optional, default=20
29-
Maximum depth of the search tree, also the length of the generated sequences.
30+
Maximum depth of the search tree, also the length of the generated
31+
sequences. Must be >= 1.
3032
n_iterations : int, optional, default=1000
31-
Number of iterations per round for the MCTS algorithm.
33+
Number of iterations per round for the MCTS algorithm. Must be >= 1.
3234
experiment : BaseAptamerEval, optional, default=None
3335
An instance of an experiment class definingthe goal function for the algorithm.
3436
@@ -74,17 +76,23 @@ def __init__(
7476
n_iterations: int = 1000,
7577
experiment=None,
7678
) -> None:
79+
if depth < 1:
80+
raise ValueError(f"`depth` must be >= 1, got {depth}.")
7781
if n_iterations < 1:
7882
raise ValueError(f"`n_iterations` must be >= 1, got {n_iterations}.")
7983

84+
if states is None:
85+
states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
86+
elif not states:
87+
raise ValueError("`states` must contain at least one entry.")
88+
elif len(states) != len(set(states)):
89+
raise ValueError("`states` must contain unique entries.")
90+
8091
self.experiment = experiment
8192
self.depth = depth
8293
self.n_iterations = n_iterations
8394

8495
super().__init__()
85-
86-
if states is None:
87-
states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
8896
self.states = states
8997

9098
self.root = TreeNode(

pyaptamer/mcts/tests/test_mcts.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,28 @@ def mcts(request):
241241
class TestMCTS:
242242
"""Tests for the MCTS() class."""
243243

244+
@pytest.mark.parametrize("depth", [0, -1])
245+
def test_init_invalid_depth(self, depth):
246+
"""Check invalid depths are rejected early."""
247+
with pytest.raises(ValueError, match=r"`depth` must be >= 1"):
248+
MCTS(depth=depth)
249+
244250
@pytest.mark.parametrize("n_iterations", [0, -1])
245251
def test_init_invalid_iterations(self, n_iterations):
246252
"""Check invalid iteration counts are rejected early."""
247253
with pytest.raises(ValueError, match=r"`n_iterations` must be >= 1"):
248254
MCTS(n_iterations=n_iterations)
249255

256+
def test_init_empty_states(self):
257+
"""Check an empty search space is rejected early."""
258+
with pytest.raises(ValueError, match=r"`states` must contain at least one"):
259+
MCTS(states=[])
260+
261+
def test_init_duplicate_states(self):
262+
"""Check duplicate states are rejected early."""
263+
with pytest.raises(ValueError, match=r"`states` must contain unique entries"):
264+
MCTS(states=["A_", "A_"])
265+
250266
def test_reset(self, mcts):
251267
"""Check correct reset of the inner state."""
252268
# modify its inner state

0 commit comments

Comments
 (0)