Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/core/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _init_tables(self) -> None:
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL,
func_name TEXT NOT NULL,
args_hash TEXT NOT NULL DEFAULT '',
kwargs TEXT NOT NULL DEFAULT '',
timestamp REAL NOT NULL,
duration_ms REAL NOT NULL DEFAULT 0.0,
status TEXT NOT NULL DEFAULT 'success',
Expand Down Expand Up @@ -107,6 +107,10 @@ def _init_tables(self) -> None:
# Phase 2 migration: add pending_content column to file_snapshots
self._migrate_add_pending_content(conn)

# Phase 3 migration: rename args_hash to kwargs and truncate old data
if any(row[1] == "args_hash" for row in conn.execute("PRAGMA table_info(tool_calls)")):
self._migrate_args_hash_to_kwargs(conn)

@staticmethod
def _migrate_add_pending_content(conn: sqlite3.Connection) -> None:
try:
Expand All @@ -115,6 +119,12 @@ def _migrate_add_pending_content(conn: sqlite3.Connection) -> None:
except sqlite3.OperationalError:
pass # Column already exists

@staticmethod
def _migrate_args_hash_to_kwargs(conn: sqlite3.Connection) -> None:
conn.execute("DELETE FROM tool_calls")
conn.execute("ALTER TABLE tool_calls RENAME COLUMN args_hash TO kwargs")
conn.commit()

def close(self) -> None:
if hasattr(self._thread_local, "connection") and self._thread_local.connection is not None:
self._thread_local.connection.close()
Expand Down Expand Up @@ -163,15 +173,15 @@ def log_tool_call(
self,
session_id: int,
func_name: str,
args_hash: str,
kwargs: str,
duration_ms: float = 0.0,
status: str = "success",
audit_status: str = "none",
) -> int:
cursor = self.execute(
"INSERT INTO tool_calls (session_id, func_name, args_hash, timestamp, duration_ms, status, audit_status) "
"INSERT INTO tool_calls (session_id, func_name, kwargs, timestamp, duration_ms, status, audit_status) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
(session_id, func_name, args_hash, time.time(), duration_ms, status, audit_status),
(session_id, func_name, kwargs, time.time(), duration_ms, status, audit_status),
)
return cursor.lastrowid

Expand Down
16 changes: 10 additions & 6 deletions src/core/tool_registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import hashlib
import copy
import inspect
import json
import os
Expand All @@ -11,6 +11,7 @@

import nest_asyncio

from src.utils.string_snapshot import truncate_string
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace

Expand Down Expand Up @@ -239,16 +240,19 @@ def set_session_id(self, session_id: int) -> None:
self._current_session_id = session_id

@staticmethod
def _compute_args_hash(kwargs: dict) -> str:
sorted_json = json.dumps(kwargs, sort_keys=True, default=str)
return hashlib.blake2b(sorted_json.encode("utf-8")).hexdigest()
def _compute_kwargs_json(kwargs: dict) -> str:
truncated = copy.deepcopy(kwargs)
for key, value in truncated.items():
if isinstance(value, str) and len(value) > 256:
truncated[key] = truncate_string(value, max_length=256, suffix="...")
return json.dumps(truncated, sort_keys=True, default=str)

def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, status: str) -> str | None:
session_id = getattr(self, "_current_session_id", None)
if session_id is None:
return None
try:
args_hash = self._compute_args_hash(kwargs)
kwargs_json = self._compute_kwargs_json(kwargs)
workspace = getattr(self, "_workspace", None)
if workspace is not None:
# Determine audit_status based on tool category
Expand All @@ -268,7 +272,7 @@ def _log_tool_call(self, func_name: str, kwargs: dict, duration_ms: float, statu
workspace.db.log_tool_call(
session_id,
func_name,
args_hash,
kwargs_json,
duration_ms=duration_ms,
status=status,
audit_status=audit_status,
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_log_tool_call_stored_correctly(self, db: DatabaseManager):
db.log_tool_call(session_id, "write", "def456", duration_ms=120.0, status="error", audit_status="none")

row = db.fetchone(
"SELECT func_name, args_hash, duration_ms, status FROM tool_calls WHERE session_id = ?",
"SELECT func_name, kwargs, duration_ms, status FROM tool_calls WHERE session_id = ?",
(session_id,),
)
assert row is not None
Expand Down Expand Up @@ -297,6 +297,8 @@ def write_session(name):
db.create_session(name=name)
except Exception as e:
errors.append(e)
finally:
db.close() # 确保每个线程关闭自己的连接,避免 ResourceWarning

threads = [threading.Thread(target=write_session, args=(f"thread_{i}",)) for i in range(5)]
for t in threads:
Expand Down
5 changes: 5 additions & 0 deletions tests/core/test_tool_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
@pytest.fixture(autouse=True)
def isolate_tool_registry():
"""自动在每个测试前后隔离 ToolRegistry 单例"""
from src.core.database_manager import DatabaseManager

# 保存原始实例
original_instance = ToolRegistry._instance

Expand All @@ -27,6 +29,9 @@ def isolate_tool_registry():
# 测试后恢复
ToolRegistry._instance = original_instance

# 清理数据库连接,避免 ResourceWarning
DatabaseManager.reset_instances()


def test_validate_config():
"""测试配置验证 - 一次性测试所有阈值"""
Expand Down