Skip to content

Commit 01930d0

Browse files
changes requested
1 parent 25a8d99 commit 01930d0

2 files changed

Lines changed: 17 additions & 9 deletions

File tree

pyaptamer/mcts/_algorithm.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class MCTS(BaseObject):
2424
----------
2525
states : list[str], optional, default=None
2626
Possible values for the nodes. Underscores indicate whether the values are
27-
supposed to be prepended or appended to the sequence. Must be non-empty and
28-
contain unique entries.
27+
supposed to be prepended or appended to the sequence. If None or empty,
28+
defaults to the standard RNA nucleotide states. Must contain unique entries.
2929
depth : int, optional, default=20
3030
Maximum depth of the search tree, also the length of the generated
3131
sequences. Must be >= 1.
@@ -34,6 +34,15 @@ class MCTS(BaseObject):
3434
experiment : BaseAptamerEval, optional, default=None
3535
An instance of an experiment class definingthe goal function for the algorithm.
3636
37+
Raises
38+
------
39+
ValueError
40+
If ``depth`` is less than 1.
41+
ValueError
42+
If ``n_iterations`` is less than 1.
43+
ValueError
44+
If ``states`` contains duplicate entries.
45+
3746
Attributes
3847
----------
3948
root : TreeNode
@@ -81,10 +90,8 @@ def __init__(
8190
if n_iterations < 1:
8291
raise ValueError(f"`n_iterations` must be >= 1, got {n_iterations}.")
8392

84-
if states is None:
93+
if not states:
8594
states = ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
86-
elif not states:
87-
raise ValueError("`states` must contain at least one entry.")
8895
elif len(states) != len(set(states)):
8996
raise ValueError("`states` must contain unique entries.")
9097

pyaptamer/mcts/tests/test_mcts.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,11 @@ def test_init_invalid_iterations(self, n_iterations):
253253
with pytest.raises(ValueError, match=r"`n_iterations` must be >= 1"):
254254
MCTS(n_iterations=n_iterations)
255255

256-
def test_init_empty_states(self):
257-
"""Check an empty search space is rejected early."""
258-
with pytest.raises(ValueError, match=r"`states` must contain at least one"):
259-
MCTS(states=[])
256+
@pytest.mark.parametrize("states", [None, []])
257+
def test_init_empty_or_none_states_defaults(self, states):
258+
"""Check that None or empty states default to the standard nucleotide set."""
259+
mcts = MCTS(states=states)
260+
assert mcts.states == ["A_", "C_", "G_", "U_", "_A", "_C", "_G", "_U"]
260261

261262
def test_init_duplicate_states(self):
262263
"""Check duplicate states are rejected early."""

0 commit comments

Comments
 (0)