Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/core/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def _init_tables(self) -> None:
CREATE INDEX IF NOT EXISTS idx_tool_calls_session ON tool_calls(session_id);
CREATE INDEX IF NOT EXISTS idx_tool_calls_func ON tool_calls(func_name);
CREATE INDEX IF NOT EXISTS idx_file_snapshots_audit ON file_snapshots(audit_status);

CREATE TABLE IF NOT EXISTS tool_call_summaries (
session_id INTEGER NOT NULL,
func_name TEXT NOT NULL,
kwargs_json TEXT NOT NULL,
result TEXT NOT NULL,
timestamp REAL NOT NULL,
PRIMARY KEY (session_id, func_name, kwargs_json),
FOREIGN KEY (session_id) REFERENCES sessions(id)
);

CREATE INDEX IF NOT EXISTS idx_tool_call_summaries_session ON tool_call_summaries(session_id);
"""
)
conn.commit()
Expand All @@ -120,8 +132,7 @@ def _init_tables(self) -> None:
# Ensure file_read_records index exists (for fresh databases or after migration).
# Must run after migrations since the column may only exist after Phase 4.
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path "
"ON file_read_records(session_id, file_path)"
"CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path)"
)
conn.commit()

Expand Down Expand Up @@ -368,3 +379,33 @@ def reset_instances(cls) -> None:
for instance in cls._instances.values():
instance.close()
cls._instances.clear()

# -- Tool call summaries --

def record_tool_call_summary(
self,
session_id: int,
func_name: str,
kwargs_json: str,
result: str,
) -> None:
with self._write_lock:
conn = self._get_connection()
conn.execute(
"INSERT INTO tool_call_summaries "
"(session_id, func_name, kwargs_json, result, timestamp) "
"VALUES (?, ?, ?, ?, ?) "
"ON CONFLICT(session_id, func_name, kwargs_json) DO UPDATE SET "
"result = excluded.result, "
"timestamp = excluded.timestamp",
(session_id, func_name, kwargs_json, result, time.time()),
)
conn.commit()

def get_tool_call_summaries(self, session_id: int) -> list[tuple]:
"""Get all tool call summaries for a session ordered by timestamp DESC."""
return self.fetchall(
"SELECT session_id, func_name, kwargs_json, result, timestamp "
"FROM tool_call_summaries WHERE session_id = ? ORDER BY timestamp DESC",
(session_id,),
)
20 changes: 20 additions & 0 deletions src/core/tool_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import nest_asyncio

from src.console.result_manager import _to_string
from src.utils.string_snapshot import truncate_string
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace
Expand Down Expand Up @@ -205,6 +206,7 @@ def execute(self, func_name: str, *args: Any, **kwargs: Any) -> Any:

duration_ms = (time.perf_counter() - start_time) * 1000
self._log_tool_call(func_name, kwargs, duration_ms, status)
self._record_tool_call_summary(func_name, kwargs, result)

return self._compress_result(result)
else:
Expand Down Expand Up @@ -280,3 +282,21 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu
except Exception:
pass
return f"ToolRegistry(sync_tools={len(self._tools)})"

def _record_tool_call_summary(self, func_name: str, kwargs: dict, result: Any) -> None:
session_id = getattr(self, "_current_session_id", None)
if session_id is None:
return

# Exclude write tools
if func_name in {"write", "edit", "confirm_edit"}:
return

try:
kwargs_json = self._compute_kwargs_json(kwargs)
workspace = getattr(self, "_workspace", None)
if workspace is not None:
result_str = _to_string(result)
workspace.db.record_tool_call_summary(session_id, func_name, kwargs_json, result_str)
except Exception:
pass
150 changes: 150 additions & 0 deletions tests/core/test_tool_call_summaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Test tool_call_summaries table functionality."""

import os
import sqlite3
import time

import pytest

from src.core.database_manager import DatabaseManager


@pytest.fixture
def temp_db(tmp_path):
"""Create a temporary database for testing."""
temp_workspace = tmp_path / "workspace"
os.makedirs(temp_workspace, exist_ok=True)
db = DatabaseManager(str(temp_workspace))
yield db
db.close()
DatabaseManager.reset_instances()


class TestToolCallSummariesTable:
def test_table_exists(self, temp_db):
cursor = temp_db._get_connection()
tables = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='tool_call_summaries'"
).fetchall()
assert len(tables) == 1
assert tables[0][0] == "tool_call_summaries"

def test_primary_key_constraints(self, temp_db):
session_id = temp_db.create_session()
kwargs_json = '{"file_path": "test.txt"}'

# Insert first record
temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1")

# Insert with same primary key (should update)
time.sleep(0.01)
temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2")

# Should only have one record
summaries = temp_db.get_tool_call_summaries(session_id)
assert len(summaries) == 1
assert summaries[0][3] == "result2" # result updated
assert summaries[0][4] > time.time() - 1 # timestamp updated

def test_different_kwargs_create_new_records(self, temp_db):
session_id = temp_db.create_session()
kwargs1 = '{"file_path": "test.txt"}'
kwargs2 = '{"file_path": "other.txt"}'

temp_db.record_tool_call_summary(session_id, "read", kwargs1, "result1")
temp_db.record_tool_call_summary(session_id, "read", kwargs2, "result2")

summaries = temp_db.get_tool_call_summaries(session_id)
assert len(summaries) == 2

def test_different_func_names_create_new_records(self, temp_db):
session_id = temp_db.create_session()
kwargs_json = '{"query": "test"}'

temp_db.record_tool_call_summary(session_id, "search", kwargs_json, "result1")
temp_db.record_tool_call_summary(session_id, "stat", kwargs_json, "result2")

summaries = temp_db.get_tool_call_summaries(session_id)
assert len(summaries) == 2

def test_get_tool_call_summaries_ordered_by_timestamp(self, temp_db):
session_id = temp_db.create_session()
kwargs_json = '{"file_path": "test.txt"}'

time.sleep(0.01)
temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1")
time.sleep(0.01)
temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2")
time.sleep(0.01)
temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result3")

summaries = temp_db.get_tool_call_summaries(session_id)
assert len(summaries) == 1

def test_get_tool_call_summaries_from_different_sessions(self, temp_db):
session_id1 = temp_db.create_session()
session_id2 = temp_db.create_session()
kwargs_json = '{"file_path": "test.txt"}'

temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, "result1")
temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, "result2")

summaries1 = temp_db.get_tool_call_summaries(session_id1)
summaries2 = temp_db.get_tool_call_summaries(session_id2)

assert len(summaries1) == 1
assert len(summaries2) == 1
assert summaries1[0][3] != summaries2[0][3]

def test_session_foreign_key_constraint(self, temp_db):
kwargs_json = '{"file_path": "test.txt"}'
# Trying to insert with non-existent session should raise FK constraint error
# because PRAGMA foreign_keys=ON is enabled
with pytest.raises(sqlite3.IntegrityError):
temp_db.record_tool_call_summary(999, "read", kwargs_json, "result1")

# No record should be inserted
summaries = temp_db.get_tool_call_summaries(999)
assert len(summaries) == 0

def test_index_exists(self, temp_db):
cursor = temp_db._get_connection()
indexes = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='index' AND name='idx_tool_call_summaries_session'"
).fetchall()
assert len(indexes) == 1


class TestToolCallSummariesIntegration:
def test_record_tool_call_summary_stores_correct_data(self, temp_db):
session_id = temp_db.create_session()
func_name = "read"
kwargs_json = '{"file_path": "test.txt"}'
result = "file content here"

temp_db.record_tool_call_summary(session_id, func_name, kwargs_json, result)

summaries = temp_db.get_tool_call_summaries(session_id)
assert len(summaries) == 1
assert summaries[0][0] == session_id
assert summaries[0][1] == func_name
assert summaries[0][2] == kwargs_json
assert summaries[0][3] == result
assert summaries[0][4] > 0 # timestamp

def test_multiple_sessions_isolated(self, temp_db):
session_id1 = temp_db.create_session()
session_id2 = temp_db.create_session()

kwargs_json = '{"file_path": "test.txt"}'
result1 = "result for session 1"
result2 = "result for session 2"

temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, result1)
temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, result2)

summaries1 = temp_db.get_tool_call_summaries(session_id1)
summaries2 = temp_db.get_tool_call_summaries(session_id2)

assert summaries1[0][3] == result1
assert summaries2[0][3] == result2