|
27 | 27 | ContextCodeSnippet, |
28 | 28 | CodeSnippetKey, |
29 | 29 | ) |
| 30 | + |
| 31 | +from buttercup.program_model.utils.common import ( |
| 32 | + Function, |
| 33 | + FunctionBody, |
| 34 | + TypeDefinition, |
| 35 | + TypeDefinitionType, |
| 36 | +) |
30 | 37 | from buttercup.patcher.patcher import PatchInput |
31 | 38 | from buttercup.patcher.utils import PatchInputPoV |
32 | 39 | from langgraph.types import Command |
@@ -113,11 +120,36 @@ def mock_llm_functions( |
113 | 120 | mock_test_instructions_prompt: MagicMock, |
114 | 121 | ): |
115 | 122 | """Mock LLM creation functions and environment variables.""" |
| 123 | + # Create a mock CodeQueryPersistentRest that returns expected data |
| 124 | + mock_codequery = MagicMock() |
| 125 | + mock_codequery.get_functions.return_value = [ |
| 126 | + Function( |
| 127 | + name="test_function", |
| 128 | + file_path=Path("/src/test.c"), |
| 129 | + bodies=[FunctionBody(body="int test_function() { return 0; }", start_line=1, end_line=3)], |
| 130 | + ) |
| 131 | + ] |
| 132 | + mock_codequery.get_callers.return_value = [] |
| 133 | + mock_codequery.get_callees.return_value = [] |
| 134 | + mock_codequery.get_types.return_value = [ |
| 135 | + TypeDefinition( |
| 136 | + name="test_struct", |
| 137 | + type=TypeDefinitionType.STRUCT, |
| 138 | + definition="struct test_struct { int x; };", |
| 139 | + definition_line=5, |
| 140 | + file_path=Path("/src/test.h"), |
| 141 | + ) |
| 142 | + ] |
| 143 | + mock_codequery.get_type_calls.return_value = [] |
| 144 | + |
116 | 145 | with ( |
117 | 146 | patch.dict(os.environ, {"BUTTERCUP_LITELLM_HOSTNAME": "http://test-host", "BUTTERCUP_LITELLM_KEY": "test-key"}), |
118 | 147 | patch("buttercup.common.llm.create_default_llm", return_value=mock_cheap_llm), |
119 | 148 | patch("buttercup.common.llm.create_llm", return_value=mock_cheap_llm), |
120 | 149 | patch("langgraph.prebuilt.chat_agent_executor._get_prompt_runnable", return_value=mock_agent_llm), |
| 150 | + patch("buttercup.patcher.utils.get_codequery", return_value=mock_codequery), |
| 151 | + patch("buttercup.program_model.rest_client.CodeQueryPersistentRest", return_value=mock_codequery), |
| 152 | + patch("buttercup.patcher.agents.tools.CodeQueryPersistentRest", return_value=mock_codequery), |
121 | 153 | ): |
122 | 154 | import buttercup.patcher.agents.context_retriever |
123 | 155 |
|
@@ -548,10 +580,28 @@ def test_retrieve_context_basic( |
548 | 580 | assert "const ebitmap_t *e1" in snippet.code |
549 | 581 |
|
550 | 582 |
|
| 583 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 584 | +@patch("buttercup.patcher.utils.get_codequery") |
551 | 585 | def test_missing_arg_tool_call( |
552 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 586 | + mock_get_codequery_utils, |
| 587 | + mock_get_codequery_tools, |
| 588 | + mock_agent: ContextRetrieverAgent, |
| 589 | + mock_agent_llm: MagicMock, |
| 590 | + mock_runnable_config, |
553 | 591 | ) -> None: |
554 | 592 | """Test basic context retrieval functionality.""" |
| 593 | + # Setup mock codequery |
| 594 | + mock_codequery = MagicMock() |
| 595 | + mock_codequery.get_functions.return_value = [ |
| 596 | + Function( |
| 597 | + name="main", |
| 598 | + file_path=Path("/src/test.c"), |
| 599 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 600 | + ) |
| 601 | + ] |
| 602 | + mock_get_codequery_utils.return_value = mock_codequery |
| 603 | + mock_get_codequery_tools.return_value = mock_codequery |
| 604 | + |
555 | 605 | # Create a test state with a simple request |
556 | 606 | state = ContextRetrieverState( |
557 | 607 | code_snippet_requests=[ |
@@ -666,10 +716,28 @@ def test_recursion_limit(mock_agent: ContextRetrieverAgent, mock_agent_llm: Magi |
666 | 716 | ) |
667 | 717 |
|
668 | 718 |
|
| 719 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 720 | +@patch("buttercup.patcher.utils.get_codequery") |
669 | 721 | def test_recursion_limit_tmp_code_snippets( |
670 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 722 | + mock_get_codequery_utils, |
| 723 | + mock_get_codequery_tools, |
| 724 | + mock_agent: ContextRetrieverAgent, |
| 725 | + mock_agent_llm: MagicMock, |
| 726 | + mock_runnable_config, |
671 | 727 | ) -> None: |
672 | 728 | """Test hitting the context request limit but getting some partial results.""" |
| 729 | + # Setup mock codequery |
| 730 | + mock_codequery = MagicMock() |
| 731 | + mock_codequery.get_functions.return_value = [ |
| 732 | + Function( |
| 733 | + name="main", |
| 734 | + file_path=Path("/src/test.c"), |
| 735 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 736 | + ) |
| 737 | + ] |
| 738 | + mock_get_codequery_utils.return_value = mock_codequery |
| 739 | + mock_get_codequery_tools.return_value = mock_codequery |
| 740 | + |
673 | 741 | state = ContextRetrieverState( |
674 | 742 | code_snippet_requests=[ |
675 | 743 | CodeSnippetRequest(request="Find function main"), |
@@ -734,10 +802,28 @@ def test_recursion_limit_tmp_code_snippets( |
734 | 802 | assert code_snippet.code == "int main() { int a = foo(); return a; }" |
735 | 803 |
|
736 | 804 |
|
| 805 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 806 | +@patch("buttercup.patcher.utils.get_codequery") |
737 | 807 | def test_dupped_code_snippets( |
738 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 808 | + mock_get_codequery_utils, |
| 809 | + mock_get_codequery_tools, |
| 810 | + mock_agent: ContextRetrieverAgent, |
| 811 | + mock_agent_llm: MagicMock, |
| 812 | + mock_runnable_config, |
739 | 813 | ) -> None: |
740 | 814 | """Test that we don't return duplicate code snippets.""" |
| 815 | + # Setup mock codequery |
| 816 | + mock_codequery = MagicMock() |
| 817 | + mock_codequery.get_functions.return_value = [ |
| 818 | + Function( |
| 819 | + name="main", |
| 820 | + file_path=Path("/src/test.c"), |
| 821 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 822 | + ) |
| 823 | + ] |
| 824 | + mock_get_codequery_utils.return_value = mock_codequery |
| 825 | + mock_get_codequery_tools.return_value = mock_codequery |
| 826 | + |
741 | 827 | state = ContextRetrieverState( |
742 | 828 | code_snippet_requests=[ |
743 | 829 | CodeSnippetRequest(request="Find function main"), |
@@ -842,10 +928,31 @@ def test_dupped_code_snippets( |
842 | 928 | assert code_snippet.code == "int main() { int a = foo(); return a; }" |
843 | 929 |
|
844 | 930 |
|
| 931 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 932 | +@patch("buttercup.patcher.utils.get_codequery") |
845 | 933 | def test_get_type( |
846 | | - mock_agent: ContextRetrieverAgent, mock_cheap_llm: MagicMock, mock_agent_llm: MagicMock, mock_runnable_config |
| 934 | + mock_get_codequery_utils, |
| 935 | + mock_get_codequery_tools, |
| 936 | + mock_agent: ContextRetrieverAgent, |
| 937 | + mock_cheap_llm: MagicMock, |
| 938 | + mock_agent_llm: MagicMock, |
| 939 | + mock_runnable_config, |
847 | 940 | ) -> None: |
848 | 941 | """Test that we can get the type definition.""" |
| 942 | + # Setup mock codequery |
| 943 | + mock_codequery = MagicMock() |
| 944 | + mock_codequery.get_types.return_value = [ |
| 945 | + TypeDefinition( |
| 946 | + name="ebitmap_t", |
| 947 | + type=TypeDefinitionType.STRUCT, |
| 948 | + definition="struct ebitmap_t { int a; }", |
| 949 | + definition_line=5, |
| 950 | + file_path=Path("/src/test.h"), |
| 951 | + ) |
| 952 | + ] |
| 953 | + mock_get_codequery_utils.return_value = mock_codequery |
| 954 | + mock_get_codequery_tools.return_value = mock_codequery |
| 955 | + |
849 | 956 | state = ContextRetrieverState( |
850 | 957 | code_snippet_requests=[ |
851 | 958 | CodeSnippetRequest(request="Find type ebitmap_t"), |
@@ -893,10 +1000,37 @@ def test_get_type( |
893 | 1000 | assert code_snippet.code == "struct ebitmap_t { int a; }" |
894 | 1001 |
|
895 | 1002 |
|
| 1003 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 1004 | +@patch("buttercup.patcher.utils.get_codequery") |
896 | 1005 | def test_get_definitions_no_paths( |
897 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 1006 | + mock_get_codequery_utils, |
| 1007 | + mock_get_codequery_tools, |
| 1008 | + mock_agent: ContextRetrieverAgent, |
| 1009 | + mock_agent_llm: MagicMock, |
| 1010 | + mock_runnable_config, |
898 | 1011 | ) -> None: |
899 | 1012 | """Test that we can get the type definition even if the file path is not provided.""" |
| 1013 | + # Setup mock codequery |
| 1014 | + mock_codequery = MagicMock() |
| 1015 | + mock_codequery.get_functions.return_value = [ |
| 1016 | + Function( |
| 1017 | + name="main", |
| 1018 | + file_path=Path("/src/example_project/test.c"), |
| 1019 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 1020 | + ) |
| 1021 | + ] |
| 1022 | + mock_codequery.get_types.return_value = [ |
| 1023 | + TypeDefinition( |
| 1024 | + name="ebitmap_t", |
| 1025 | + type=TypeDefinitionType.STRUCT, |
| 1026 | + definition="struct ebitmap_t { int a; }", |
| 1027 | + definition_line=5, |
| 1028 | + file_path=Path("/src/example_project/test.h"), |
| 1029 | + ) |
| 1030 | + ] |
| 1031 | + mock_get_codequery_utils.return_value = mock_codequery |
| 1032 | + mock_get_codequery_tools.return_value = mock_codequery |
| 1033 | + |
900 | 1034 | state = ContextRetrieverState( |
901 | 1035 | code_snippet_requests=[ |
902 | 1036 | CodeSnippetRequest(request="Find type ebitmap_t"), |
@@ -1079,10 +1213,28 @@ def test_low_recursion_limit_empty( |
1079 | 1213 | ) |
1080 | 1214 |
|
1081 | 1215 |
|
| 1216 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 1217 | +@patch("buttercup.patcher.utils.get_codequery") |
1082 | 1218 | def test_low_recursion_limit_with_results( |
1083 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 1219 | + mock_get_codequery_utils, |
| 1220 | + mock_get_codequery_tools, |
| 1221 | + mock_agent: ContextRetrieverAgent, |
| 1222 | + mock_agent_llm: MagicMock, |
| 1223 | + mock_runnable_config, |
1084 | 1224 | ) -> None: |
1085 | 1225 | """Test that hitting a low recursion limit after getting some results still returns those results.""" |
| 1226 | + # Setup mock codequery |
| 1227 | + mock_codequery = MagicMock() |
| 1228 | + mock_codequery.get_functions.return_value = [ |
| 1229 | + Function( |
| 1230 | + name="main", |
| 1231 | + file_path=Path("/src/test.c"), |
| 1232 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 1233 | + ) |
| 1234 | + ] |
| 1235 | + mock_get_codequery_utils.return_value = mock_codequery |
| 1236 | + mock_get_codequery_tools.return_value = mock_codequery |
| 1237 | + |
1086 | 1238 | state = ContextRetrieverState( |
1087 | 1239 | code_snippet_requests=[ |
1088 | 1240 | CodeSnippetRequest(request="Find function main"), |
@@ -1152,10 +1304,37 @@ def test_low_recursion_limit_with_results( |
1152 | 1304 | assert code_snippet.code == "int main() { int a = foo(); return a; }" |
1153 | 1305 |
|
1154 | 1306 |
|
| 1307 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 1308 | +@patch("buttercup.patcher.utils.get_codequery") |
1155 | 1309 | def test_multiple_code_snippet_requests( |
1156 | | - mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config |
| 1310 | + mock_get_codequery_utils, |
| 1311 | + mock_get_codequery_tools, |
| 1312 | + mock_agent: ContextRetrieverAgent, |
| 1313 | + mock_agent_llm: MagicMock, |
| 1314 | + mock_runnable_config, |
1157 | 1315 | ) -> None: |
1158 | 1316 | """Test handling multiple code snippet requests in a single state.""" |
| 1317 | + # Setup mock codequery |
| 1318 | + mock_codequery = MagicMock() |
| 1319 | + mock_codequery.get_functions.return_value = [ |
| 1320 | + Function( |
| 1321 | + name="main", |
| 1322 | + file_path=Path("/src/example_project/test.c"), |
| 1323 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 1324 | + ) |
| 1325 | + ] |
| 1326 | + mock_codequery.get_types.return_value = [ |
| 1327 | + TypeDefinition( |
| 1328 | + name="ebitmap_t", |
| 1329 | + type=TypeDefinitionType.STRUCT, |
| 1330 | + definition="struct ebitmap_t { int a; }", |
| 1331 | + definition_line=5, |
| 1332 | + file_path=Path("/src/example_project/test.h"), |
| 1333 | + ) |
| 1334 | + ] |
| 1335 | + mock_get_codequery_utils.return_value = mock_codequery |
| 1336 | + mock_get_codequery_tools.return_value = mock_codequery |
| 1337 | + |
1159 | 1338 | state = ContextRetrieverState( |
1160 | 1339 | code_snippet_requests=[ |
1161 | 1340 | CodeSnippetRequest(request="Find function main"), |
@@ -1397,8 +1576,37 @@ def test_invalid_argument_types( |
1397 | 1576 | assert len(result.update["relevant_code_snippets"]) == 0 |
1398 | 1577 |
|
1399 | 1578 |
|
1400 | | -def test_llm_error_recovery(mock_agent: ContextRetrieverAgent, mock_agent_llm: MagicMock, mock_runnable_config) -> None: |
| 1579 | +@patch("buttercup.patcher.agents.tools.get_codequery") |
| 1580 | +@patch("buttercup.patcher.utils.get_codequery") |
| 1581 | +def test_llm_error_recovery( |
| 1582 | + mock_get_codequery_utils, |
| 1583 | + mock_get_codequery_tools, |
| 1584 | + mock_agent: ContextRetrieverAgent, |
| 1585 | + mock_agent_llm: MagicMock, |
| 1586 | + mock_runnable_config, |
| 1587 | +) -> None: |
1401 | 1588 | """Test that the agent recovers from LLM errors and continues processing.""" |
| 1589 | + # Setup mock codequery |
| 1590 | + mock_codequery = MagicMock() |
| 1591 | + mock_codequery.get_functions.return_value = [ |
| 1592 | + Function( |
| 1593 | + name="main", |
| 1594 | + file_path=Path("/src/test.c"), |
| 1595 | + bodies=[FunctionBody(body="int main() { int a = foo(); return a; }", start_line=1, end_line=3)], |
| 1596 | + ) |
| 1597 | + ] |
| 1598 | + mock_codequery.get_types.return_value = [ |
| 1599 | + TypeDefinition( |
| 1600 | + name="ebitmap_t", |
| 1601 | + type=TypeDefinitionType.STRUCT, |
| 1602 | + definition="struct ebitmap_t { int a; }", |
| 1603 | + definition_line=5, |
| 1604 | + file_path=Path("/src/test.h"), |
| 1605 | + ) |
| 1606 | + ] |
| 1607 | + mock_get_codequery_utils.return_value = mock_codequery |
| 1608 | + mock_get_codequery_tools.return_value = mock_codequery |
| 1609 | + |
1402 | 1610 | state = ContextRetrieverState( |
1403 | 1611 | code_snippet_requests=[ |
1404 | 1612 | CodeSnippetRequest(request="Find function main"), |
|
0 commit comments