|
1 | 1 | # Copyright (c) Microsoft Corporation.
|
2 | 2 | # Licensed under the MIT license.
|
3 | 3 |
|
| 4 | +import pytest |
| 5 | +import re |
| 6 | + |
4 | 7 | from unittest.mock import patch
|
5 | 8 |
|
6 |
| -from pyrit.common.utils import combine_dict, select_word_indices |
| 9 | +from pyrit.common.utils import combine_dict, get_random_indices, select_word_indices |
7 | 10 |
|
8 | 11 |
|
9 | 12 | def test_combine_non_empty_dict():
|
@@ -36,14 +39,103 @@ def test_combine_dict_same_keys():
|
36 | 39 | assert combine_dict(dict1, dict2) == {"c": "d"}
|
37 | 40 |
|
38 | 41 |
|
39 |
| -def test_word_indices_selection(): |
def test_get_random_indices():
    """Exercise get_random_indices with a mocked sample, boundary proportions and bad arguments."""
    # With random.sample patched, the test expects the patched value to come back unchanged.
    with patch("random.sample", return_value=[2, 4, 6]):
        picked = get_random_indices(start=0, size=10, proportion=0.3)
        assert picked == [2, 4, 6]

    # Boundary proportions: 0 selects nothing, 1 selects every index in [start, start + size).
    assert get_random_indices(start=5, size=10, proportion=0) == []
    every_index = get_random_indices(start=27, size=10, proportion=1)
    assert sorted(every_index) == list(range(27, 37))

    # Each of these argument sets is expected to be rejected with ValueError.
    rejected_arguments = (
        {"start": -1, "size": 10, "proportion": 0.5},
        {"start": 0, "size": 0, "proportion": 0.5},
        {"start": 0, "size": 10, "proportion": -1},
        {"start": 0, "size": 10, "proportion": 1.01},
    )
    for arguments in rejected_arguments:
        with pytest.raises(ValueError):
            get_random_indices(**arguments)
| 58 | + |
| 59 | + |
def test_word_indices_all_mode():
    """Check that mode "all" returns the index of every word, including for an empty list."""
    three_words = ["word1", "word2", "word3"]
    assert select_word_indices(words=three_words, mode="all") == [0, 1, 2]
    assert select_word_indices(words=[], mode="all") == []

    # Same expectation on a larger input.
    many_words = [f"word{n}" for n in range(1000)]
    all_positions = list(range(1000))
    assert select_word_indices(words=many_words, mode="all") == all_positions
| 66 | + |
| 67 | + |
def test_word_indices_custom_mode():
    """Check mode "custom": supplied indices are returned, and an invalid index list raises."""
    base_words = ("word1", "word2", "word3")

    assert select_word_indices(words=list(base_words), mode="custom", indices=[0, 2]) == [0, 2]
    # An empty or omitted index list selects nothing.
    assert select_word_indices(words=list(base_words), mode="custom", indices=[]) == []
    assert select_word_indices(words=list(base_words), mode="custom") == []
    # With no words the expected result is empty even though indices are supplied.
    assert select_word_indices(words=[], mode="custom", indices=[0, 1]) == []

    # Presumably rejected because 3, -1 and 5 fall outside a three-word list.
    with pytest.raises(ValueError):
        select_word_indices(words=list(base_words), mode="custom", indices=[0, 3, -1, 5])

    many_words = [f"word{n}" for n in range(1000)]
    every_tenth = list(range(0, 1000, 10))  # every 10th index
    assert select_word_indices(words=many_words, mode="custom", indices=every_tenth) == every_tenth
| 80 | + |
| 81 | + |
def test_word_indices_keywords_mode():
    """Check mode "keywords": the indices of words matching the given keywords are returned."""
    single_hit = select_word_indices(
        words=["word1", "word2", "pyrit", "word4"], mode="keywords", keywords=["pyrit"]
    )
    assert single_hit == [2]

    two_hits = select_word_indices(
        words=["word1", "pyrit", "word3", "test"], mode="keywords", keywords=["pyrit", "test"]
    )
    assert two_hits == [1, 3]

    # Empty words, omitted keywords, empty keywords and a keyword that is absent all select nothing.
    plain_words = ("word1", "word2", "word3")
    assert select_word_indices(words=[], mode="keywords", keywords=["pyrit"]) == []
    assert select_word_indices(words=list(plain_words), mode="keywords") == []
    assert select_word_indices(words=list(plain_words), mode="keywords", keywords=[]) == []
    assert select_word_indices(words=list(plain_words), mode="keywords", keywords=["pyrit"]) == []

    # Plant keywords at known positions in a larger list.
    many_words = [f"word{n}" for n in range(1000)]
    planted = {123: "pyrit", 456: "pyrit", 789: "test"}
    for position, keyword in planted.items():
        many_words[position] = keyword
    found = select_word_indices(words=many_words, mode="keywords", keywords=["pyrit", "test"])
    assert found == [123, 456, 789]
| 98 | + |
| 99 | + |
def test_word_indices_regex_mode():
    """Check mode "regex" with string patterns, a compiled pattern and the default pattern."""
    mixed_words = ("word1", "word2", "pyrit", "word4")
    plain_words = ("word1", "word2", "word3")

    assert select_word_indices(words=list(mixed_words), mode="regex", regex=r"word\d") == [0, 1, 3]
    # default pattern is "."
    assert select_word_indices(words=list(plain_words), mode="regex") == [0, 1, 2]
    assert select_word_indices(words=list(plain_words), mode="regex", regex=r"pyrit") == []
    assert select_word_indices(words=[], mode="regex", regex=r"word\d") == []

    # A pre-compiled pattern is accepted in place of a pattern string.
    compiled = re.compile(r"word\d")
    assert select_word_indices(words=list(mixed_words), mode="regex", regex=compiled) == [0, 1, 3]

    # Replace three entries with words the pattern should not match.
    many_words = [f"word{n}" for n in range(1000)]
    non_matching = {123: "don't", 456: "match", 789: "these"}
    for position, replacement in non_matching.items():
        many_words[position] = replacement

    regex_results = select_word_indices(words=many_words, mode="regex", regex=r"word\d+")
    assert len(regex_results) == 997  # 1000 - 3 (123, 456, 789 don't match)
    for position in (123, 456, 789):
        assert position not in regex_results
48 | 118 |
|
49 |
| - assert select_word_indices(words=["word1", "word2"], mode="invalid_mode") == [0, 1] |
| 119 | + |
def test_word_indices_random_mode():
    """Check mode "random" with a mocked sample, boundary proportions and a real random draw."""
    four_words = ("word1", "word2", "word3", "word4")

    # With random.sample patched, both the default and an explicit proportion yield the patched value.
    with patch("random.sample", return_value=[0, 2]):
        default_pick = select_word_indices(words=list(four_words), mode="random")
        assert default_pick == [0, 2]
        half_pick = select_word_indices(words=list(four_words), mode="random", proportion=0.5)
        assert half_pick == [0, 2]

    assert select_word_indices(words=[], mode="random", proportion=0.5) == []
    assert select_word_indices(words=list(four_words), mode="random", proportion=0) == []
    everything = select_word_indices(words=list(four_words), mode="random", proportion=1)
    assert len(everything) == 4

    # Real randomness: only the number of selected indices is checked.
    many_words = [f"word{n}" for n in range(1000)]
    random_results = select_word_indices(words=many_words, mode="random", proportion=0.43)
    assert len(random_results) == 430  # 43% of 1000
| 135 | + |
| 136 | + |
def test_word_indices_invalid_mode():
    """Check that an unrecognized mode returns every index, as mode "all" would.

    The original comment says the fallback comes with a warning; only the
    returned indices are asserted here.
    """
    # Should default to "all" mode with warning
    cases = (
        (["word1", "word2"], [0, 1]),
        (["word1", "word2", "word3"], [0, 1, 2]),
        ([], []),
    )
    for words, expected in cases:
        assert select_word_indices(words=words, mode="invalid") == expected  # type: ignore
0 commit comments