Skip to content

Commit 54c574e

Browse files
committed
make tests more exhaustive
1 parent d223f35 commit 54c574e

File tree

2 files changed

+102
-7
lines changed

2 files changed

+102
-7
lines changed

pyrit/common/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,8 @@ def select_word_indices(
138138
valid_indices = [i for i in custom_indices if 0 <= i < len(words)]
139139
invalid_indices = [i for i in custom_indices if i < 0 or i >= len(words)]
140140
if invalid_indices:
141-
logger.warning(f"Ignoring out-of-bounds indices: {invalid_indices}")
141+
raise ValueError(
142+
f"Invalid indices {invalid_indices} provided for custom selection. "
143+
f"Valid range is 0 to {len(words) - 1}."
144+
)
142145
return valid_indices

tests/unit/common/test_helper_functions.py

+98-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
import pytest
5+
import re
6+
47
from unittest.mock import patch
58

6-
from pyrit.common.utils import combine_dict, select_word_indices
9+
from pyrit.common.utils import combine_dict, get_random_indices, select_word_indices
710

811

912
def test_combine_non_empty_dict():
@@ -36,14 +39,103 @@ def test_combine_dict_same_keys():
3639
assert combine_dict(dict1, dict2) == {"c": "d"}
3740

3841

39-
def test_word_indices_selection():
42+
def test_get_random_indices():
43+
with patch("random.sample", return_value=[2, 4, 6]):
44+
result = get_random_indices(start=0, size=10, proportion=0.3)
45+
assert result == [2, 4, 6]
46+
47+
assert get_random_indices(start=5, size=10, proportion=0) == []
48+
assert sorted(get_random_indices(start=27, size=10, proportion=1)) == list(range(27, 37))
49+
50+
with pytest.raises(ValueError):
51+
get_random_indices(start=-1, size=10, proportion=0.5)
52+
with pytest.raises(ValueError):
53+
get_random_indices(start=0, size=0, proportion=0.5)
54+
with pytest.raises(ValueError):
55+
get_random_indices(start=0, size=10, proportion=-1)
56+
with pytest.raises(ValueError):
57+
get_random_indices(start=0, size=10, proportion=1.01)
58+
59+
60+
def test_word_indices_all_mode():
4061
assert select_word_indices(words=["word1", "word2", "word3"], mode="all") == [0, 1, 2]
62+
assert select_word_indices(words=[], mode="all") == []
63+
64+
large_word_list = [f"word{i}" for i in range(1000)]
65+
assert select_word_indices(words=large_word_list, mode="all") == list(range(1000))
66+
67+
68+
def test_word_indices_custom_mode():
4169
assert select_word_indices(words=["word1", "word2", "word3"], mode="custom", indices=[0, 2]) == [0, 2]
70+
assert select_word_indices(words=["word1", "word2", "word3"], mode="custom", indices=[]) == []
71+
assert select_word_indices(words=["word1", "word2", "word3"], mode="custom") == []
72+
assert select_word_indices(words=[], mode="custom", indices=[0, 1]) == []
73+
74+
with pytest.raises(ValueError):
75+
select_word_indices(words=["word1", "word2", "word3"], mode="custom", indices=[0, 3, -1, 5])
76+
77+
large_word_list = [f"word{i}" for i in range(1000)]
78+
custom_indices = list(range(0, 1000, 10)) # every 10th index
79+
assert select_word_indices(words=large_word_list, mode="custom", indices=custom_indices) == custom_indices
80+
81+
82+
def test_word_indices_keywords_mode():
4283
assert select_word_indices(words=["word1", "word2", "pyrit", "word4"], mode="keywords", keywords=["pyrit"]) == [2]
84+
assert select_word_indices(
85+
words=["word1", "pyrit", "word3", "test"], mode="keywords", keywords=["pyrit", "test"]
86+
) == [1, 3]
87+
88+
assert select_word_indices(words=[], mode="keywords", keywords=["pyrit"]) == []
89+
assert select_word_indices(words=["word1", "word2", "word3"], mode="keywords") == []
90+
assert select_word_indices(words=["word1", "word2", "word3"], mode="keywords", keywords=[]) == []
91+
assert select_word_indices(words=["word1", "word2", "word3"], mode="keywords", keywords=["pyrit"]) == []
92+
93+
large_word_list = [f"word{i}" for i in range(1000)]
94+
large_word_list[123] = "pyrit"
95+
large_word_list[456] = "pyrit"
96+
large_word_list[789] = "test"
97+
assert select_word_indices(words=large_word_list, mode="keywords", keywords=["pyrit", "test"]) == [123, 456, 789]
98+
99+
100+
def test_word_indices_regex_mode():
43101
assert select_word_indices(words=["word1", "word2", "pyrit", "word4"], mode="regex", regex=r"word\d") == [0, 1, 3]
102+
assert select_word_indices(words=["word1", "word2", "word3"], mode="regex") == [0, 1, 2] # default pattern is "."
103+
assert select_word_indices(words=["word1", "word2", "word3"], mode="regex", regex=r"pyrit") == []
104+
assert select_word_indices(words=[], mode="regex", regex=r"word\d") == []
44105

45-
with patch("random.sample", return_value=[0, 2]):
46-
result = select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random", percentage=50)
47-
assert result == [0, 2]
106+
pattern = re.compile(r"word\d")
107+
assert select_word_indices(words=["word1", "word2", "pyrit", "word4"], mode="regex", regex=pattern) == [0, 1, 3]
108+
109+
large_word_list = [f"word{i}" for i in range(1000)]
110+
large_word_list[123] = "don't"
111+
large_word_list[456] = "match"
112+
large_word_list[789] = "these"
113+
regex_results = select_word_indices(words=large_word_list, mode="regex", regex=r"word\d+")
114+
assert len(regex_results) == 997 # 1000 - 3 (123, 456, 789 don't match)
115+
assert 123 not in regex_results
116+
assert 456 not in regex_results
117+
assert 789 not in regex_results
48118

49-
assert select_word_indices(words=["word1", "word2"], mode="invalid_mode") == [0, 1]
119+
120+
def test_word_indices_random_mode():
121+
with patch("random.sample", return_value=[0, 2]):
122+
result = select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random")
123+
assert result == [0, 2]
124+
result = select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random", proportion=0.5)
125+
assert result == [0, 2]
126+
127+
assert select_word_indices(words=[], mode="random", proportion=0.5) == []
128+
assert select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random", proportion=0) == []
129+
assert len(select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random", proportion=1)) == 4
130+
131+
# Test with actual randomness but verify length is correct
132+
large_word_list = [f"word{i}" for i in range(1000)]
133+
random_results = select_word_indices(words=large_word_list, mode="random", proportion=0.43)
134+
assert len(random_results) == 430 # 43% of 1000
135+
136+
137+
def test_word_indices_invalid_mode():
138+
# Should default to "all" mode with warning
139+
assert select_word_indices(words=["word1", "word2"], mode="invalid") == [0, 1] # type: ignore
140+
assert select_word_indices(words=["word1", "word2", "word3"], mode="invalid") == [0, 1, 2] # type: ignore
141+
assert select_word_indices(words=[], mode="invalid") == [] # type: ignore

0 commit comments

Comments
 (0)