Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/workspace/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,56 @@ def _record_read_meta(self, resolved_path: Path) -> None:
self.workspace.db.record_file_read(rel_path, meta["mtime"], meta["size"], meta["checksum"])
except Exception:
pass

def _validate_mtime(self, resolved_path: Path) -> str | None:
"""校验文件自上次读取后是否被外部修改."""
if not resolved_path.exists():
return None

rel_path = str(resolved_path.relative_to(self.workspace.root_path))
record = self.workspace.db.get_file_read_record(rel_path)
if record is None:
return None

stored_mtime = record[2]
current_mtime = resolved_path.stat().st_mtime

if abs(current_mtime - stored_mtime) > 0.001:
return (
f"ERROR: FILE_MODIFIED_EXTERNALLY - "
f'The file "{rel_path}" was modified externally since last read. '
f'Please re-read the file with the "read" tool before editing it.'
)
return None

@staticmethod
def _generate_diff(old_content: str, new_content: str, file_path: str) -> str:
import difflib

old_lines = old_content.splitlines(keepends=True)
new_lines = new_content.splitlines(keepends=True)
diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}")
return "".join(diff)

@staticmethod
def handle_tool_exceptions(func):
"""工具方法异常处理装饰器."""
from functools import wraps

from src.models.tool_error_response import ToolErrorResponse
from src.workspace.path_validator import PathNotFoundError, WorkspaceBoundaryError

@wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except PathNotFoundError as err1:
return ToolErrorResponse(self.__class__.__name__, err1).to_str()
except WorkspaceBoundaryError as err2:
return ToolErrorResponse(self.__class__.__name__, err2).to_str()
except PermissionError as err3:
return ToolErrorResponse(self.__class__.__name__, err3).to_str()
except Exception as err:
return ToolErrorResponse(self.__class__.__name__, err).to_str()

return wrapper
197 changes: 78 additions & 119 deletions src/workspace/tools/edit_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""安全的字符串替换编辑工具 — 只发布待审核更改."""

import difflib
from pathlib import Path

from src.models.tool_error_response import ToolErrorResponse
from src.workspace.path_validator import PathNotFoundError, WorkspaceBoundaryError
from src.workspace.tools.base_tool import BaseTool
from src.workspace.workspace import Workspace

Expand All @@ -21,6 +19,7 @@ def __init__(self, workspace: Workspace):
self.func = self.edit
self.params = BaseTool.extract_params(self.edit)

@BaseTool.handle_tool_exceptions
def edit(
self,
file_path: str,
Expand All @@ -45,117 +44,84 @@ def edit(
context_before: 匹配前的上下文文本(可选,用于校验)
context_after: 匹配后的上下文文本(可选,用于校验)
"""
try:
# 1. 参数校验
if not old_string:
return ToolErrorResponse(self.__class__.__name__, ValueError("old_string 不能为空")).to_str()

if max_replacements < 1:
return ToolErrorResponse(self.__class__.__name__, ValueError("max_replacements 必须 >= 1")).to_str()
if max_replacements > 100:
max_replacements = 100

# 2. 路径解析
source_file_path = Path(file_path)
resolved_path: Path = self.workspace.path_validator.resolve_path(source_file_path)

if not resolved_path.is_file():
return ToolErrorResponse(
self.__class__.__name__,
FileNotFoundError(f"文件不存在: {resolved_path}"),
).to_str()

# 3. mtime 校验
mtime_error = self._validate_mtime(resolved_path)
if mtime_error:
return mtime_error

# 4. 读取文件内容
old_content = resolved_path.read_text(encoding="utf-8")

# 5. 查找匹配
count = 0
idx = 0
while count < max_replacements:
idx = old_content.find(old_string, idx)
if idx == -1:
break
count += 1

# 上下文校验
if context_before or context_after:
ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count)
if ctx_error:
return ctx_error

idx += len(old_string)

if count == 0:
return (
f"No changes made: old_string not found in file.\nFile: {file_path}\nSearching for: '{old_string}'"
)

# 6. 执行替换(生成新内容)
new_content = old_content.replace(old_string, new_string, count)

# 7. 生成 diff
rel_path = str(resolved_path.relative_to(self.workspace.root_path))
diff_content = self._generate_diff(old_content, new_content, rel_path)

# 8. 记录快照
from src.core.file_tracker import FileTracker

old_hash = FileTracker.compute_checksum_from_string(old_content)
new_hash = FileTracker.compute_checksum_from_string(new_content)
session_id = self.workspace._current_session_id
snapshot_id = self.workspace.db.record_file_snapshot(
rel_path,
old_hash,
new_hash,
diff_content,
audit_status="PENDING_AUDIT",
session_id=session_id,
pending_content=new_content,
)

# 9. 返回预览
return (
f"[Edit Preview]\n"
f"File: {rel_path}\n"
f"Snapshot ID: {snapshot_id}\n"
f"Replacements: {count}\n"
f"Diff:\n{diff_content}"
)

except PathNotFoundError as err1:
return ToolErrorResponse(self.__class__.__name__, err1).to_str()
except WorkspaceBoundaryError as err2:
return ToolErrorResponse(self.__class__.__name__, err2).to_str()
except PermissionError as err3:
return ToolErrorResponse(self.__class__.__name__, err3).to_str()
except Exception as err:
return ToolErrorResponse(self.__class__.__name__, err).to_str()

def _validate_mtime(self, resolved_path: Path) -> str | None:
"""校验文件自上次读取后是否被外部修改."""
if not resolved_path.exists():
return None

# 1. 参数校验
if not old_string:
return ToolErrorResponse(self.__class__.__name__, ValueError("old_string 不能为空")).to_str()

if max_replacements < 1:
return ToolErrorResponse(self.__class__.__name__, ValueError("max_replacements 必须 >= 1")).to_str()
if max_replacements > 100:
max_replacements = 100

# 2. 路径解析
source_file_path = Path(file_path)
resolved_path: Path = self.workspace.path_validator.resolve_path(source_file_path)

if not resolved_path.is_file():
return ToolErrorResponse(
self.__class__.__name__,
FileNotFoundError(f"文件不存在: {resolved_path}"),
).to_str()

# 3. mtime 校验
mtime_error = self._validate_mtime(resolved_path)
if mtime_error:
return mtime_error

# 4. 读取文件内容
old_content = resolved_path.read_text(encoding="utf-8")

# 5. 查找匹配
count = 0
idx = 0
while count < max_replacements:
idx = old_content.find(old_string, idx)
if idx == -1:
break
count += 1

# 上下文校验
if context_before or context_after:
ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count)
if ctx_error:
return ctx_error

idx += len(old_string)

if count == 0:
return f"No changes made: old_string not found in file.\nFile: {file_path}\nSearching for: '{old_string}'"

# 6. 执行替换(生成新内容)
new_content = old_content.replace(old_string, new_string, count)

# 7. 生成 diff
rel_path = str(resolved_path.relative_to(self.workspace.root_path))
record = self.workspace.db.get_file_read_record(rel_path)
if record is None:
return None

stored_mtime = record[2]
current_mtime = resolved_path.stat().st_mtime

if abs(current_mtime - stored_mtime) > 0.001:
return (
f"ERROR: FILE_MODIFIED_EXTERNALLY - "
f'The file "{rel_path}" was modified externally since last read. '
f'Please re-read the file with the "read" tool before editing it.'
)
return None
diff_content = self._generate_diff(old_content, new_content, rel_path)

# 8. 记录快照
from src.core.file_tracker import FileTracker

old_hash = FileTracker.compute_checksum_from_string(old_content)
new_hash = FileTracker.compute_checksum_from_string(new_content)
session_id = self.workspace._current_session_id
snapshot_id = self.workspace.db.record_file_snapshot(
rel_path,
old_hash,
new_hash,
diff_content,
audit_status="PENDING_AUDIT",
session_id=session_id,
pending_content=new_content,
)

# 9. 返回预览
return (
f"[Edit Preview]\n"
f"File: {rel_path}\n"
f"Snapshot ID: {snapshot_id}\n"
f"Replacements: {count}\n"
f"Diff:\n{diff_content}"
)

@staticmethod
def _check_context(
Expand Down Expand Up @@ -195,10 +161,3 @@ def _check_context(
).to_str()

return None

@staticmethod
def _generate_diff(old_content: str, new_content: str, file_path: str) -> str:
old_lines = old_content.splitlines(keepends=True)
new_lines = new_content.splitlines(keepends=True)
diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}")
return "".join(diff)
Loading