|
1 | 1 | # Copyright (c) Microsoft Corporation.
|
2 | 2 | # Licensed under the MIT license.
|
3 | 3 |
|
| 4 | +from unittest.mock import patch |
| 5 | + |
4 | 6 | from pyrit.common.utils import combine_dict, select_word_indices
|
5 | 7 |
|
6 | 8 |
|
@@ -36,5 +38,12 @@ def test_combine_dict_same_keys():
|
36 | 38 |
|
37 | 39 | def test_word_indices_selection():
|
38 | 40 | assert select_word_indices(words=["word1", "word2", "word3"], mode="all") == [0, 1, 2]
|
| 41 | + assert select_word_indices(words=["word1", "word2", "word3"], mode="custom", indices=[0, 2]) == [0, 2] |
39 | 42 | assert select_word_indices(words=["word1", "word2", "pyrit", "word4"], mode="keywords", keywords=["pyrit"]) == [2]
|
40 | 43 | assert select_word_indices(words=["word1", "word2", "pyrit", "word4"], mode="regex", regex=r"word\d") == [0, 1, 3]
|
| 44 | + |
| 45 | + with patch("random.sample", return_value=[0, 2]): |
| 46 | + result = select_word_indices(words=["word1", "word2", "word3", "word4"], mode="random", sample_ratio=0.5) |
| 47 | + assert result == [0, 2] |
| 48 | + |
| 49 | + assert select_word_indices(words=["word1", "word2"], mode="invalid_mode") == [0, 1] |
0 commit comments