Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/core/audit_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def commit(self, snapshot_id: int, approved: bool = True) -> str:
else:
# 已有文件 — mtime 校验后 write_text
rel_path = str(resolved.relative_to(self.workspace.root_path))
record = db.get_file_read_record(rel_path)
record = db.get_file_read_record(_session_id, rel_path)
if record is not None:
stored_mtime = record[2]
stored_mtime = record[3]
current_mtime = resolved.stat().st_mtime
if abs(current_mtime - stored_mtime) > 0.001:
return f"ERROR: 文件已被外部修改,审核终止: {rel_path}.请重新读取后再批准."
Expand All @@ -81,7 +81,7 @@ def commit(self, snapshot_id: int, approved: bool = True) -> str:
rel_path = str(resolved.relative_to(self.workspace.root_path))
new_meta = FileTracker.get_file_meta(resolved)
if new_meta:
db.record_file_read(rel_path, new_meta["mtime"], new_meta["size"], new_meta["checksum"])
db.record_file_read(_session_id, rel_path, new_meta["mtime"], new_meta["size"], new_meta["checksum"])

# 5. 更新审计状态
db.update_snapshot_audit(snapshot_id, "APPROVED")
Expand Down
64 changes: 52 additions & 12 deletions src/core/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,15 @@ def _init_tables(self) -> None:

CREATE TABLE IF NOT EXISTS file_read_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL UNIQUE,
session_id INTEGER NOT NULL,
file_path TEXT NOT NULL,
mtime REAL NOT NULL,
size INTEGER NOT NULL DEFAULT 0,
checksum TEXT NOT NULL DEFAULT '',
last_read_at REAL NOT NULL,
read_count INTEGER NOT NULL DEFAULT 1
read_count INTEGER NOT NULL DEFAULT 1,
FOREIGN KEY (session_id) REFERENCES sessions(id),
UNIQUE(session_id, file_path)
);

CREATE TABLE IF NOT EXISTS file_snapshots (
Expand All @@ -98,7 +101,6 @@ def _init_tables(self) -> None:

CREATE INDEX IF NOT EXISTS idx_tool_calls_session ON tool_calls(session_id);
CREATE INDEX IF NOT EXISTS idx_tool_calls_func ON tool_calls(func_name);
CREATE INDEX IF NOT EXISTS idx_file_read_records_path ON file_read_records(file_path);
CREATE INDEX IF NOT EXISTS idx_file_snapshots_audit ON file_snapshots(audit_status);
"""
)
Expand All @@ -111,6 +113,18 @@ def _init_tables(self) -> None:
if any(row[1] == "args_hash" for row in conn.execute("PRAGMA table_info(tool_calls)")):
self._migrate_args_hash_to_kwargs(conn)

# Phase 4 migration: add session_id to file_read_records
if not any(row[1] == "session_id" for row in conn.execute("PRAGMA table_info(file_read_records)")):
self._migrate_file_read_records_add_session(conn)

# Ensure file_read_records index exists (for fresh databases or after migration).
# Must run after migrations since the column may only exist after Phase 4.
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path "
"ON file_read_records(session_id, file_path)"
)
conn.commit()

@staticmethod
def _migrate_add_pending_content(conn: sqlite3.Connection) -> None:
try:
Expand All @@ -125,6 +139,30 @@ def _migrate_args_hash_to_kwargs(conn: sqlite3.Connection) -> None:
conn.execute("ALTER TABLE tool_calls RENAME COLUMN args_hash TO kwargs")
conn.commit()

@staticmethod
def _migrate_file_read_records_add_session(conn: sqlite3.Connection) -> None:
    """Rebuild ``file_read_records`` so that rows are scoped to a session.

    The existing table is dropped, so any rows it held are discarded, and a
    new table is created with a ``session_id`` column and a
    ``UNIQUE(session_id, file_path)`` constraint. The composite lookup index
    is created in the same script.

    Args:
        conn: Open SQLite connection on which the migration script is run.
    """
    rebuild_script = """
DROP TABLE IF EXISTS file_read_records;

CREATE TABLE file_read_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL,
file_path TEXT NOT NULL,
mtime REAL NOT NULL,
size INTEGER NOT NULL DEFAULT 0,
checksum TEXT NOT NULL DEFAULT '',
last_read_at REAL NOT NULL,
read_count INTEGER NOT NULL DEFAULT 1,
FOREIGN KEY (session_id) REFERENCES sessions(id),
UNIQUE(session_id, file_path)
);

CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path);
"""
    conn.executescript(rebuild_script)
    conn.commit()

def close(self) -> None:
if hasattr(self._thread_local, "connection") and self._thread_local.connection is not None:
self._thread_local.connection.close()
Expand Down Expand Up @@ -193,27 +231,28 @@ def update_tool_call_status(self, call_id: int, status: str, audit_status: str)

# -- File read records --

def record_file_read(self, session_id: int, file_path: str, mtime: float, size: int, checksum: str) -> None:
    """Insert or refresh the read record of ``file_path`` for one session.

    A first read inserts a row with ``read_count`` 1. A repeated read of the
    same ``(session_id, file_path)`` pair overwrites the stored metadata and
    increments ``read_count``.

    Args:
        session_id: Session the read belongs to.
        file_path: Path used as the record key within the session.
        mtime: Modification time observed at read time.
        size: File size observed at read time.
        checksum: Content checksum observed at read time.
    """
    upsert_sql = (
        "INSERT INTO file_read_records "
        "(session_id, file_path, mtime, size, checksum, last_read_at, read_count) "
        "VALUES (?, ?, ?, ?, ?, ?, 1) "
        # The conflict target must match the table's UNIQUE(session_id, file_path).
        "ON CONFLICT(session_id, file_path) DO UPDATE SET "
        "mtime = excluded.mtime, "
        "size = excluded.size, "
        "checksum = excluded.checksum, "
        "last_read_at = excluded.last_read_at, "
        "read_count = read_count + 1"
    )
    with self._write_lock:
        conn = self._get_connection()
        conn.execute(
            upsert_sql,
            (session_id, file_path, mtime, size, checksum, time.time()),
        )
        conn.commit()

def get_file_read_record(self, file_path: str) -> tuple | None:
def get_file_read_record(self, session_id: int, file_path: str) -> tuple | None:
return self.fetchone(
"SELECT id, file_path, mtime, size, checksum, last_read_at, read_count "
"FROM file_read_records WHERE file_path = ?",
(file_path,),
"SELECT id, session_id, file_path, mtime, size, checksum, last_read_at, read_count "
"FROM file_read_records WHERE session_id = ? AND file_path = ?",
(session_id, file_path),
)

# -- File snapshots --
Expand Down Expand Up @@ -302,6 +341,7 @@ def rename_session(self, session_id: int, name: str) -> None:
def delete_session(self, session_id: int) -> None:
    """Delete a session and every row that refers to it.

    Rows in ``tool_calls``, ``file_snapshots`` and ``file_read_records`` are
    removed first; the ``sessions`` row itself is removed last.

    Args:
        session_id: Identifier of the session to delete.
    """
    params = (session_id,)
    statements = (
        "DELETE FROM tool_calls WHERE session_id = ?",
        "DELETE FROM file_snapshots WHERE session_id = ?",
        "DELETE FROM file_read_records WHERE session_id = ?",
        "DELETE FROM sessions WHERE id = ?",
    )
    for statement in statements:
        self.execute(statement, params)

def get_tool_usage_ranking(self, session_id: int | None = None, limit: int = 10) -> list[tuple]:
Expand Down
16 changes: 12 additions & 4 deletions src/workspace/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ def _record_read_meta(self, resolved_path: Path) -> None:
try:
meta = FileTracker.get_file_meta(resolved_path)
if meta:
rel_path = str(resolved_path.relative_to(self.workspace.root_path))
self.workspace.db.record_file_read(rel_path, meta["mtime"], meta["size"], meta["checksum"])
session_id = self.workspace._current_session_id
if session_id is not None:
rel_path = str(resolved_path.relative_to(self.workspace.root_path))
self.workspace.db.record_file_read(
session_id, rel_path, meta["mtime"], meta["size"], meta["checksum"]
)
except Exception:
pass

Expand All @@ -191,12 +195,16 @@ def _validate_mtime(self, resolved_path: Path) -> str | None:
if not resolved_path.exists():
return None

session_id = self.workspace._current_session_id
if session_id is None:
return None

rel_path = str(resolved_path.relative_to(self.workspace.root_path))
record = self.workspace.db.get_file_read_record(rel_path)
record = self.workspace.db.get_file_read_record(session_id, rel_path)
if record is None:
return None

stored_mtime = record[2]
stored_mtime = record[3]
current_mtime = resolved_path.stat().st_mtime

if abs(current_mtime - stored_mtime) > 0.001:
Expand Down
10 changes: 8 additions & 2 deletions tests/core/test_audit_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def reset_singletons():
@pytest.fixture
def workspace(tmp_path: Path) -> Workspace:
    """Provide a workspace rooted at ``tmp_path`` with a new session selected."""
    instance = Workspace(str(tmp_path))
    session_id = instance.db.create_session(name="test_session")
    # Read records are keyed by session, so a current session must be set.
    instance._current_session_id = session_id
    return instance


Expand All @@ -37,6 +38,7 @@ def _create_pending_snapshot(workspace: Workspace, file_path: str, pending_conte
"diff_content",
audit_status="PENDING_AUDIT",
pending_content=pending_content,
session_id=workspace._current_session_id,
)


Expand Down Expand Up @@ -72,7 +74,9 @@ def test_approve_existing_file_with_read(self, committer: AuditCommitter, worksp
target = workspace.root_path / "test.txt"
target.write_text("original", encoding="utf-8")

workspace.db.record_file_read("test.txt", target.stat().st_mtime, target.stat().st_size, "hash")
workspace.db.record_file_read(
workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash"
)

snapshot_id = _create_pending_snapshot(workspace, "test.txt", "updated content")
result = committer.commit(snapshot_id, approved=True)
Expand All @@ -94,7 +98,9 @@ def test_approve_mtime_mismatch_fails(self, committer: AuditCommitter, workspace
target = workspace.root_path / "test.txt"
target.write_text("original", encoding="utf-8")

workspace.db.record_file_read("test.txt", target.stat().st_mtime, target.stat().st_size, "hash")
workspace.db.record_file_read(
workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash"
)

# Modify file externally
time.sleep(0.1)
Expand Down
29 changes: 16 additions & 13 deletions tests/core/test_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,30 @@ def test_update_tool_call_status(self, db: DatabaseManager):

class TestFileReadRecords:
    """Tests for the session-scoped file read record accessors."""

    def test_record_file_read(self, db: DatabaseManager):
        """A recorded read can be fetched back for the same session."""
        session_id = db.create_session()
        db.record_file_read(session_id, "src/main.py", 1234567890.5, 1024, "abc123hash")

        row = db.get_file_read_record(session_id, "src/main.py")
        assert row is not None
        # Column order: id, session_id, file_path, mtime, size, checksum, ...
        assert row[3] == 1234567890.5
        assert row[4] == 1024
        assert row[5] == "abc123hash"

    def test_record_file_read_upsert(self, db: DatabaseManager):
        """A second read in the same session updates the row and bumps read_count."""
        session_id = db.create_session()
        db.record_file_read(session_id, "src/main.py", 1000.0, 100, "hash1")
        db.record_file_read(session_id, "src/main.py", 2000.0, 200, "hash2")

        row = db.get_file_read_record(session_id, "src/main.py")
        assert row is not None
        assert row[3] == 2000.0
        assert row[4] == 200
        assert row[5] == "hash2"
        # read_count is the last selected column (index 7).
        assert row[7] == 2

    def test_get_nonexistent_read_record(self, db: DatabaseManager):
        """Looking up a path that was never recorded returns None."""
        session_id = db.create_session()
        row = db.get_file_read_record(session_id, "nonexistent.py")
        assert row is None


Expand Down
1 change: 1 addition & 0 deletions tests/workspace/tools/test_edit_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def reset_singletons():
@pytest.fixture
def workspace(tmp_path: Path) -> Workspace:
    """Provide a workspace rooted at ``tmp_path`` with a new session selected."""
    instance = Workspace(str(tmp_path))
    session_id = instance.db.create_session(name="test_session")
    # Read records are keyed by session, so a current session must be set.
    instance._current_session_id = session_id
    return instance


Expand Down
17 changes: 9 additions & 8 deletions tests/workspace/tools/test_read_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reset_singletons():
@pytest.fixture
def workspace(tmp_path: Path) -> Workspace:
    """Provide a workspace rooted at ``tmp_path`` with a new session selected."""
    instance = Workspace(str(tmp_path))
    session_id = instance.db.create_session(name="test_session")
    # Read records are keyed by session, so a current session must be set.
    instance._current_session_id = session_id
    return instance


Expand All @@ -33,9 +34,9 @@ def test_read_records_mtime(self, workspace: Workspace, tmp_path: Path):
result = tool.read("test.txt")

assert "hello" in result
record = workspace.db.get_file_read_record("test.txt")
record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt")
assert record is not None
assert record[2] == pytest.approx(file.stat().st_mtime, abs=0.01)
assert record[3] == pytest.approx(file.stat().st_mtime, abs=0.01)

def test_read_records_checksum(self, workspace: Workspace, tmp_path: Path):
file = tmp_path / "test.txt"
Expand All @@ -47,10 +48,10 @@ def test_read_records_checksum(self, workspace: Workspace, tmp_path: Path):
tool = ReadTool(workspace)
tool.read("test.txt")

record = workspace.db.get_file_read_record("test.txt")
record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt")
assert record is not None
expected_checksum = hashlib.blake2b(content.encode("utf-8")).hexdigest()
assert record[4] == expected_checksum
assert record[5] == expected_checksum

def test_read_nonexistent_file_no_db_record(self, workspace: Workspace):
from src.workspace.tools.read_tool import ReadTool
Expand All @@ -59,7 +60,7 @@ def test_read_nonexistent_file_no_db_record(self, workspace: Workspace):
result = tool.read("nonexistent.txt")

assert "error" in result.lower() or "Error" in result
record = workspace.db.get_file_read_record("nonexistent.txt")
record = workspace.db.get_file_read_record(workspace._current_session_id, "nonexistent.txt")
assert record is None

def test_read_updates_read_count(self, workspace: Workspace, tmp_path: Path):
Expand All @@ -72,9 +73,9 @@ def test_read_updates_read_count(self, workspace: Workspace, tmp_path: Path):
tool.read("test.txt")
tool.read("test.txt")

record = workspace.db.get_file_read_record("test.txt")
record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt")
assert record is not None
assert record[6] == 2
assert record[7] == 2


class TestReadLinesToolMtimeRecording:
Expand All @@ -88,5 +89,5 @@ def test_read_lines_records_mtime(self, workspace: Workspace, tmp_path: Path):
result = tool.read_lines("test.txt", 1, 2)

assert "line1" in result
record = workspace.db.get_file_read_record("test.txt")
record = workspace.db.get_file_read_record(workspace._current_session_id, "test.txt")
assert record is not None
5 changes: 3 additions & 2 deletions tests/workspace/tools/test_write_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reset_singletons():
@pytest.fixture
def workspace(tmp_path: Path) -> Workspace:
    """Provide a workspace rooted at ``tmp_path`` with a new session selected."""
    instance = Workspace(str(tmp_path))
    session_id = instance.db.create_session(name="test_session")
    # Read records are keyed by session, so a current session must be set.
    instance._current_session_id = session_id
    return instance


Expand Down Expand Up @@ -102,8 +103,8 @@ def test_write_modified_externally_fails(self, write_tool, read_tool, tmp_path:
time.sleep(0.1)

new_mtime = file.stat().st_mtime
read_record = read_tool.workspace.db.get_file_read_record("test.txt")
stored_mtime = read_record[2] if read_record else None
read_record = read_tool.workspace.db.get_file_read_record(read_tool.workspace._current_session_id, "test.txt")
stored_mtime = read_record[3] if read_record else None

if stored_mtime and abs(new_mtime - stored_mtime) < 0.001:
file.write_text("modified externally again", encoding="utf-8")
Expand Down