35 changes: 20 additions & 15 deletions camel/agents/chat_agent.py
@@ -1480,7 +1480,26 @@ def _clean_snapshot_in_memory(
if '- ' in result_text and '[ref=' in result_text:
cleaned_result = self._clean_snapshot_content(result_text)

# Update the message in memory storage
chat_history_block = getattr(
self.memory, "_chat_history_block", None
)
storage = getattr(chat_history_block, "storage", None)
if storage is None:
return

existing_records = storage.load()

# Remove records by UUID
updated_records = [
record
for record in existing_records
if record["uuid"] not in entry.record_uuids
]

# Recreate only the function result message with cleaned content.
# The assistant message with tool calls is already recorded
# separately by _record_assistant_tool_calls_from_requests and
# should not be modified here.
timestamp = (
entry.record_timestamps[0]
if entry.record_timestamps
@@ -1495,20 +1514,6 @@
result=cleaned_result,
tool_call_id=entry.tool_call_id,
)

chat_history_block = getattr(
self.memory, "_chat_history_block", None
)
storage = getattr(chat_history_block, "storage", None)
if storage is None:
return

existing_records = storage.load()
updated_records = [
record
for record in existing_records
if record["uuid"] not in entry.record_uuids
]
new_record = MemoryRecord(
message=cleaned_message,
role_at_backend=OpenAIBackendRole.FUNCTION,
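Taken together, the hunks above move the storage lookup and the UUID-based record filtering ahead of the message reconstruction, so the method bails out early when no chat-history storage is attached instead of building a replacement message it can never persist. Below is a minimal, runnable sketch of the resulting load → filter → rewrite flow; `InMemoryStorage`, `HistoryEntry`, and `replace_records` are simplified stand-ins invented for illustration, not CAMEL's actual classes.

```python
# Illustrative sketch only: InMemoryStorage and HistoryEntry are invented
# stand-ins for CAMEL's storage backend and _ToolOutputHistoryEntry.
from dataclasses import dataclass


class InMemoryStorage:
    """Toy storage with the load/clear/save surface the diff relies on."""

    def __init__(self) -> None:
        self._records: list[dict] = []

    def load(self) -> list[dict]:
        return list(self._records)

    def clear(self) -> None:
        self._records = []

    def save(self, records: list[dict]) -> None:
        self._records.extend(records)


@dataclass
class HistoryEntry:
    record_uuids: list[str]  # UUIDs of the stale function-result records
    cleaned_record: dict  # Replacement record with cleaned content


def replace_records(storage: InMemoryStorage | None, entry: HistoryEntry) -> None:
    # Guard first, as the reordered method does: no storage, nothing to do.
    if storage is None:
        return
    existing = storage.load()
    # Drop the stale records by UUID, keeping everything else untouched.
    kept = [r for r in existing if r["uuid"] not in entry.record_uuids]
    storage.clear()
    storage.save(kept + [entry.cleaned_record])


storage = InMemoryStorage()
storage.save([{"uuid": "u1", "content": "- Item 1 [ref=abc]"}])
entry = HistoryEntry(
    record_uuids=["u1"],
    cleaned_record={"uuid": "u1", "content": "- Item 1"},
)
replace_records(storage, entry)
assert storage.load() == [{"uuid": "u1", "content": "- Item 1"}]
```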
225 changes: 153 additions & 72 deletions test/agents/test_chat_agent.py
@@ -129,20 +129,20 @@ def test_chat_agent(model, step_call_count=3):
for i in range(step_call_count):
for user_msg in [user_msg_bm, user_msg_str]:
response = assistant.step(user_msg)
assert isinstance(
response.msgs, list
), f"Error in round {i + 1}"
assert isinstance(response.msgs, list), (
f"Error in round {i + 1}"
)
assert len(response.msgs) > 0, f"Error in round {i + 1}"
assert isinstance(
response.terminated, bool
), f"Error in round {i + 1}"
assert isinstance(response.terminated, bool), (
f"Error in round {i + 1}"
)
assert response.terminated is False, f"Error in round {i + 1}"
assert isinstance(
response.info, dict
), f"Error in round {i + 1}"
assert (
response.info['id'] is not None
), f"Error in round {i + 1}"
assert isinstance(response.info, dict), (
f"Error in round {i + 1}"
)
assert response.info['id'] is not None, (
f"Error in round {i + 1}"
)


@pytest.mark.model_backend
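Every remaining hunk in this file applies the same mechanical rewrite, which looks like the assert layout newer Python auto-formatters (recent Black/ruff styles) emit: instead of parenthesizing the condition and letting the message trail the closing parenthesis, the condition stays inline and the message moves into the parentheses. A self-contained before/after sketch:

```python
# Both forms are equivalent at runtime; only the wrapping differs.
terminated = False
i = 0

# Old layout: the condition is wrapped, the message trails the ")".
assert (
    terminated is False
), f"Error in round {i + 1}"

# New layout: the condition reads on one line, the message is wrapped.
assert terminated is False, (
    f"Error in round {i + 1}"
)
```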
@@ -370,9 +370,9 @@ def test_chat_agent_step_with_external_tools(step_call_count=3):
external_tool_call_requests = response.info[
"external_tool_call_requests"
]
assert (
external_tool_call_requests[0].tool_name == "math_subtract"
), f"Error in calling round {i + 1}"
assert external_tool_call_requests[0].tool_name == "math_subtract", (
f"Error in calling round {i + 1}"
)


@pytest.mark.model_backend
@@ -514,9 +514,9 @@ async def mock_arun(*args, **kwargs):
external_tool_call_requests = response.info[
"external_tool_call_requests"
]
assert (
external_tool_call_requests[0].tool_name == "math_subtract"
), f"Error in calling round {i + 1}"
assert external_tool_call_requests[0].tool_name == "math_subtract", (
f"Error in calling round {i + 1}"
)


@pytest.mark.model_backend
@@ -657,18 +657,18 @@ def test_chat_agent_multiple_return_messages(n, step_call_count=3):
)

for i in range(step_call_count):
assert (
assistant_with_sys_msg_response.msgs is not None
), f"Error in calling round {i + 1}"
assert (
len(assistant_with_sys_msg_response.msgs) == n
), f"Error in calling round {i + 1}"
assert (
assistant_without_sys_msg_response.msgs is not None
), f"Error in calling round {i + 1}"
assert (
len(assistant_without_sys_msg_response.msgs) == n
), f"Error in calling round {i + 1}"
assert assistant_with_sys_msg_response.msgs is not None, (
f"Error in calling round {i + 1}"
)
assert len(assistant_with_sys_msg_response.msgs) == n, (
f"Error in calling round {i + 1}"
)
assert assistant_without_sys_msg_response.msgs is not None, (
f"Error in calling round {i + 1}"
)
assert len(assistant_without_sys_msg_response.msgs) == n, (
f"Error in calling round {i + 1}"
)


@pytest.mark.model_backend
@@ -753,12 +753,12 @@ def test_chat_agent_stream_output(step_call_count=3):
assert len(msg.content) > 0, f"Error in calling round {i + 1}"

stream_usage = stream_assistant_response.info["usage"]
assert (
stream_usage["completion_tokens"] > 0
), f"Error in calling round {i + 1}"
assert (
stream_usage["prompt_tokens"] > 0
), f"Error in calling round {i + 1}"
assert stream_usage["completion_tokens"] > 0, (
f"Error in calling round {i + 1}"
)
assert stream_usage["prompt_tokens"] > 0, (
f"Error in calling round {i + 1}"
)
assert (
stream_usage["total_tokens"]
== stream_usage["completion_tokens"]
@@ -1039,12 +1039,12 @@ def test_tool_calling_sync(step_call_count=3):
]

assert len(tool_calls) > 0, f"Error in calling round {i + 1}"
assert str(tool_calls[0]).startswith(
"Tool Execution"
), f"Error in calling round {i + 1}"
assert (
tool_calls[0].tool_name == "math_multiply"
), f"Error in calling round {i + 1}"
assert str(tool_calls[0]).startswith("Tool Execution"), (
f"Error in calling round {i + 1}"
)
assert tool_calls[0].tool_name == "math_multiply", (
f"Error in calling round {i + 1}"
)
assert tool_calls[0].args == {
"a": 2,
"b": 8,
@@ -1165,9 +1165,9 @@ async def test_tool_calling_math_async(step_call_count=3):

tool_calls = agent_response.info['tool_calls']

assert (
tool_calls[0].tool_name == "math_multiply"
), f"Error in calling round {i + 1}"
assert tool_calls[0].tool_name == "math_multiply", (
f"Error in calling round {i + 1}"
)
assert tool_calls[0].args == {
"a": 2,
"b": 8,
@@ -1254,16 +1254,16 @@ def mock_run_tool_calling_async(*args, **kwargs):
tool_calls = agent_response.info['tool_calls']

assert tool_calls, f"Error in calling round {i + 1}"
assert str(tool_calls[0]).startswith(
"Tool Execution"
), f"Error in calling round {i + 1}"
assert str(tool_calls[0]).startswith("Tool Execution"), (
f"Error in calling round {i + 1}"
)

assert (
tool_calls[0].tool_name == "async_sleep"
), f"Error in calling round {i + 1}"
assert tool_calls[0].args == {
'second': 1
}, f"Error in calling round {i + 1}"
assert tool_calls[0].tool_name == "async_sleep", (
f"Error in calling round {i + 1}"
)
assert tool_calls[0].args == {'second': 1}, (
f"Error in calling round {i + 1}"
)
assert tool_calls[0].result == 1, f"Error in calling round {i + 1}"


@@ -1294,9 +1294,9 @@ def test_response_words_termination(step_call_count=3):

assert agent.terminated, f"Error in calling round {i + 1}"
assert agent_response.terminated, f"Error in calling round {i + 1}"
assert (
"goodbye" in agent_response.info['termination_reasons'][0]
), f"Error in calling round {i + 1}"
assert "goodbye" in agent_response.info['termination_reasons'][0], (
f"Error in calling round {i + 1}"
)


def test_chat_agent_vision(step_call_count=3):
@@ -1362,9 +1362,9 @@ def test_chat_agent_vision(step_call_count=3):

for i in range(step_call_count):
agent_response = agent.step(user_msg)
assert (
agent_response.msgs[0].content == "Yes."
), f"Error in calling round {i + 1}"
assert agent_response.msgs[0].content == "Yes.", (
f"Error in calling round {i + 1}"
)


@pytest.mark.model_backend
@@ -1534,9 +1534,9 @@ async def test_chat_agent_async_stream_with_async_generator():
# Create an async generator that wraps the chunks
# This simulates what GeminiModel does with _wrap_async_stream_with_
# thought_preservation
async def mock_async_generator() -> (
AsyncGenerator[ChatCompletionChunk, None]
):
async def mock_async_generator() -> AsyncGenerator[
ChatCompletionChunk, None
]:
for chunk in chunks:
yield chunk

@@ -1563,12 +1563,12 @@

# Verify final response contains the accumulated content
final_response = responses[-1]
assert (
final_response.msg is not None
), "Final response should have a message"
assert (
"Hello" in final_response.msg.content
), "Final content should contain 'Hello'"
assert final_response.msg is not None, (
"Final response should have a message"
)
assert "Hello" in final_response.msg.content, (
"Final content should contain 'Hello'"
)


@pytest.mark.model_backend
@@ -1718,9 +1718,9 @@ def test_add(a: int, b: int) -> int:

call_count = 0

async def mock_async_generator() -> (
AsyncGenerator[ChatCompletionChunk, None]
):
async def mock_async_generator() -> AsyncGenerator[
ChatCompletionChunk, None
]:
nonlocal call_count
if call_count == 0:
call_count += 1
@@ -1837,3 +1837,84 @@ class MathResult(BaseModel):
assert len(responses) > 1, "Should receive multiple streaming chunks"
assert responses[-1].msg.parsed.answer == 6
assert responses[-1].msg.parsed.explanation


def test_clean_snapshot_in_memory():
"""Test that snapshot content is properly cleaned in memory.

This tests the _clean_snapshot_in_memory functionality, which removes
stale snapshot markers and references from tool-output messages stored
in memory. The cleaning preserves the assistant message (the tool call
request) and only updates the function result message.
"""
from unittest.mock import MagicMock, patch

from camel.agents.chat_agent import _ToolOutputHistoryEntry
from camel.messages import FunctionCallingMessage

# Create a mock model to avoid API calls
mock_model = MagicMock()
mock_model.model_type = ModelType.DEFAULT
mock_model.model_config_dict = {}
mock_model.token_counter = None
mock_model.model_platform_name = "openai"

with patch.object(ChatAgent, '_init_model', return_value=mock_model):
agent = ChatAgent(
system_message="Test agent",
model=mock_model,
)

# Manually enable snapshot cleaning
agent._enable_snapshot_clean = True
agent._tool_output_history = []

# Create a mock memory storage
mock_storage = MagicMock()
mock_chat_history_block = MagicMock()
mock_chat_history_block.storage = mock_storage

agent.memory._chat_history_block = mock_chat_history_block

# Create a test entry with snapshot markers
test_uuid = "test-uuid-123"
test_timestamp = 1234567890.0
entry = _ToolOutputHistoryEntry(
tool_name="test_tool",
tool_call_id="call_123",
result_text="- Item 1 [ref=abc]\n- Item 2 [ref=def]\n",
record_uuids=[test_uuid],
record_timestamps=[test_timestamp],
cached=False,
)
agent._tool_output_history.append(entry)

# Mock the storage to return the existing record
mock_storage.load.return_value = [
{
"uuid": test_uuid,
"timestamp": test_timestamp,
"message": {"content": "- Item 1 [ref=abc]\n- Item 2 [ref=def]\n"},
}
]

# Call the clean function
agent._clean_snapshot_in_memory(entry)

# Verify storage was updated
assert mock_storage.clear.called
assert mock_storage.save.called

# Get the saved records
saved_records = mock_storage.save.call_args[0][0]

# Should have one record (the cleaned function result)
assert len(saved_records) == 1

# The record should be a function message with cleaned content
saved_record = saved_records[0]
assert saved_record["role_at_backend"] == OpenAIBackendRole.FUNCTION.value

# Verify entry was marked as cached
assert entry.cached is True
assert len(entry.record_uuids) == 1 # Single function result record
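

The guard in chat_agent.py (`'- ' in result_text and '[ref=' in result_text`) together with the fixture above suggests that cleaning amounts to stripping `[ref=...]` annotations from snapshot bullet lines. The real `_clean_snapshot_content` is not shown in this diff, so the following regex sketch is purely hypothetical:

```python
# Hypothetical illustration: the actual _clean_snapshot_content is not
# part of this diff. This only shows the kind of [ref=...] stripping the
# test fixture implies.
import re


def strip_ref_markers(text: str) -> str:
    # Remove " [ref=xyz]" annotations attached to snapshot items.
    return re.sub(r"\s*\[ref=[^\]]+\]", "", text)


assert (
    strip_ref_markers("- Item 1 [ref=abc]\n- Item 2 [ref=def]\n")
    == "- Item 1\n- Item 2\n"
)
```

If the repository layout matches the diff, the new test should run in isolation with `pytest test/agents/test_chat_agent.py::test_clean_snapshot_in_memory`.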