@@ -30,11 +30,16 @@
 from pydantic import BaseModel, Field

 from camel.agents import ChatAgent
-from camel.agents.chat_agent import StreamContentAccumulator, ToolCallingRecord
+from camel.agents.chat_agent import (
+    StreamContentAccumulator,
+    ToolCallingRecord,
+    _ToolOutputHistoryEntry,
+)
 from camel.configs import ChatGPTConfig
 from camel.generators import SystemMessageGenerator
 from camel.memories import MemoryRecord
 from camel.messages import BaseMessage
+from camel.models import BaseModelBackend
 from camel.models import AnthropicModel, ModelFactory, OpenAIModel
 from camel.terminators import ResponseWordsTerminator
 from camel.toolkits import (
@@ -90,6 +95,20 @@
 )


+class DummyModel(BaseModelBackend):
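+    r"""Minimal stub backend for tests; raises if inference is attempted."""
+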
+    @property
+    def token_counter(self):
+        return self._token_counter
+
+    def _run(self, messages, response_format=None, tools=None):
+        raise NotImplementedError
+
+    async def _arun(self, messages, response_format=None, tools=None):
+        raise NotImplementedError
+
+
 @parametrize
 def test_chat_agent(model, step_call_count=3):
     model = model
@@ -145,6 +164,58 @@ def test_chat_agent(model, step_call_count=3):
         ), f"Error in round {i + 1}"


+def test_chat_agent_reset_clears_tool_output_history():
+    model = DummyModel(ModelType.GPT_4O_MINI)
+    assistant = ChatAgent(
+        system_message="You are a helpful assistant.",
+        model=model,
+        enable_snapshot_clean=True,
+    )
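+    # Seed a fake tool-output history entry so reset() has state to clear.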
+    assistant._tool_output_history = [
+        _ToolOutputHistoryEntry(
+            tool_name="tool_a",
+            tool_call_id="call_a",
+            result_text="old",
+            record_uuids=["uuid_a"],
+            record_timestamps=[1.0],
+        )
+    ]
+
+    assistant.reset()
+
+    assert assistant._tool_output_history == []
+
+
+def test_clean_snapshot_in_memory_skips_missing_records():
+    model = DummyModel(ModelType.GPT_4O_MINI)
+    assistant = ChatAgent(
+        system_message="You are a helpful assistant.",
+        model=model,
+        enable_snapshot_clean=True,
+    )
+    entry = _ToolOutputHistoryEntry(
+        tool_name="browser_get_page_snapshot",
+        tool_call_id="call_missing",
+        result_text="- button [ref=1]",
+        record_uuids=["missing_uuid"],
+        record_timestamps=[123.0],
+    )
+
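+    # The entry references a record UUID that was never stored, so cleaning
+    # must leave memory untouched while still marking the entry as cached.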
+    chat_history_block = getattr(assistant.memory, "_chat_history_block", None)
+    storage = getattr(chat_history_block, "storage", None)
+    assert storage is not None
+    records_before = storage.load()
+
+    assistant._clean_snapshot_in_memory(entry)
+
+    records_after = storage.load()
+    assert records_after == records_before
+    assert entry.cached is True
+
+
 @pytest.mark.model_backend
 def test_chat_agent_stored_messages():
     system_msg = BaseMessage(