Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions backend/app/services/report_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,12 @@ def to_dict(self) -> Dict[str, Any]:
- 不要添加模拟中不存在的信息
- 如果某方面信息不足,如实说明

5. 【禁止捏造数据】
- ❌ 禁止捏造用户名、引用、统计数字或互动数据
- ❌ 禁止在回复中包含 <tool_result> 块 — 只有系统会提供工具结果
- ✅ 只能引用真实出现在工具结果中的实体、引用和数据
- 如果工具结果中没有相关内容,应如实说明,而非编造

═══════════════════════════════════════════════════════════════
【⚠️ 格式规范 - 极其重要!】
═══════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -1133,7 +1139,26 @@ def _get_tools_description(self) -> str:
if params_desc:
desc_parts.append(f" 参数: {params_desc}")
return "\n".join(desc_parts)


@staticmethod
def _strip_fake_tool_results(response: str) -> str:
"""Strip any <tool_result> blocks the LLM fabricated in its response.

When the LLM generates a <tool_call> block and then continues to generate
a <tool_result> block in the same response, we must strip the fake result
before appending to message history. The real tool result will be injected
separately by the system.
"""
import re
cleaned = re.sub(
r'<tool_result>.*?</tool_result>',
'',
response,
flags=re.DOTALL,
)
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
return cleaned.strip()

def plan_outline(
self,
progress_callback: Optional[Callable] = None
Expand Down Expand Up @@ -1335,7 +1360,8 @@ def _generate_section_react(

if conflict_retries <= 2:
# 前两次:丢弃本次响应,要求 LLM 重新回复
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({
"role": "user",
"content": (
Expand Down Expand Up @@ -1375,7 +1401,8 @@ def _generate_section_react(
if has_final_answer:
# 工具调用次数不足,拒绝并要求继续调工具
if tool_calls_count < min_tool_calls:
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
unused_tools = all_tools - used_tools
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else ""
messages.append({
Expand Down Expand Up @@ -1451,9 +1478,10 @@ def _generate_section_react(
unused_tools = all_tools - used_tools
unused_hint = ""
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))
unlock_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))

messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({
"role": "user",
"content": REACT_OBSERVATION_TEMPLATE.format(
Expand Down Expand Up @@ -1857,7 +1885,8 @@ def chat(
tool_calls_made.append(call)

# 将结果添加到消息
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
messages.append({
"role": "user",
Expand Down
6 changes: 6 additions & 0 deletions backend/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
1 change: 1 addition & 0 deletions backend/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests package marker"""
244 changes: 244 additions & 0 deletions backend/tests/test_file_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
"""Unit tests for file_parser module."""

import os
import tempfile
from pathlib import Path

import pytest

# Import the module directly to avoid Flask initialization issues
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

# Direct import of the module to avoid app package initialization
import importlib.util
spec = importlib.util.spec_from_file_location(
"file_parser",
Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
)
file_parser_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(file_parser_module)

_read_text_with_fallback = file_parser_module._read_text_with_fallback
FileParser = file_parser_module.FileParser
split_text_into_chunks = file_parser_module.split_text_into_chunks


class TestReadTextWithFallback:
"""Tests for _read_text_with_fallback function."""

def test_read_utf8_file(self):
"""Should read UTF-8 encoded file correctly."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
f.write("Hello, 你好, こんにちは")
path = f.name
try:
result = _read_text_with_fallback(path)
assert result == "Hello, 你好, こんにちは"
finally:
os.unlink(path)

def test_read_gbk_file_with_fallback(self):
"""Should read GBK encoded file using UTF-8 replacement when detection fails."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = "你好世界".encode('gbk')
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
# Result may be garbled if charset detection fails, but should not raise
assert len(result) > 0
finally:
os.unlink(path)

def test_read_latin1_file(self):
"""Should read Latin-1 encoded file using UTF-8 replacement when detection fails."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = "Héllo Wörld".encode('latin-1')
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
# Result may be garbled if charset detection fails, but should not raise
assert len(result) > 0
finally:
os.unlink(path)

def test_read_file_with_replacement_chars(self):
"""Should replace invalid characters instead of failing."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = b"Hello\x00\xff\xfeWorld"
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
assert "Hello" in result
assert "World" in result
finally:
os.unlink(path)


class TestFileParser:
"""Tests for FileParser class."""

def test_supported_extensions(self):
"""Should have correct supported extensions."""
assert '.pdf' in FileParser.SUPPORTED_EXTENSIONS
assert '.md' in FileParser.SUPPORTED_EXTENSIONS
assert '.markdown' in FileParser.SUPPORTED_EXTENSIONS
assert '.txt' in FileParser.SUPPORTED_EXTENSIONS

def test_extract_text_from_nonexistent_file(self):
"""Should raise FileNotFoundError for nonexistent file."""
with pytest.raises(FileNotFoundError):
FileParser.extract_text('/nonexistent/file.txt')

def test_extract_text_from_unsupported_format(self):
"""Should raise ValueError for unsupported format."""
with tempfile.NamedTemporaryFile(suffix='.xyz', delete=False) as f:
path = f.name
try:
with pytest.raises(ValueError, match="不支持的文件格式"):
FileParser.extract_text(path)
finally:
os.unlink(path)

def test_extract_text_from_md_file(self):
"""Should extract text from markdown file."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.md', delete=False) as f:
f.write("# Title\n\nThis is content.")
path = f.name
try:
result = FileParser.extract_text(path)
assert "# Title" in result
assert "This is content." in result
finally:
os.unlink(path)

def test_extract_text_from_txt_file(self):
"""Should extract text from txt file."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
f.write("Plain text content")
path = f.name
try:
result = FileParser.extract_text(path)
assert result == "Plain text content"
finally:
os.unlink(path)

def test_extract_from_multiple_with_all_valid(self):
"""Should extract from multiple valid files."""
files = []
try:
for i in range(3):
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write(f"Content {i}")
files.append(f.name)

result = FileParser.extract_from_multiple(files)
assert "Content 0" in result
assert "Content 1" in result
assert "Content 2" in result
assert "文档 1" in result
assert "文档 2" in result
assert "文档 3" in result
finally:
for path in files:
os.unlink(path)

def test_extract_from_multiple_with_invalid_file(self):
"""Should handle invalid file gracefully in batch mode."""
files = [
'/nonexistent/path.txt',
]
try:
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write("Valid content")
files.append(f.name)

result = FileParser.extract_from_multiple(files)
assert "Valid content" in result
assert "提取失败" in result
finally:
for path in files[1:]:
if os.path.exists(path):
os.unlink(path)


class TestSplitTextIntoChunks:
"""Tests for split_text_into_chunks function."""

def test_short_text_returns_single_chunk(self):
"""Should return single chunk when text is shorter than chunk_size."""
text = "Short text"
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) == 1
assert result[0] == text

def test_empty_text_returns_empty_list(self):
"""Should return empty list for empty/whitespace text."""
assert split_text_into_chunks(" ", chunk_size=500, overlap=50) == []
assert split_text_into_chunks("", chunk_size=500, overlap=50) == []

def test_text_exactly_at_chunk_size(self):
"""Should return single chunk when text equals chunk_size."""
text = "a" * 500
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) == 1

def test_long_text_splits_into_multiple_chunks(self):
"""Should split long text into multiple chunks."""
text = "a" * 1000
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) >= 2

def test_chunks_have_overlap(self):
"""Should have overlapping content between consecutive chunks."""
text = "abcdefghij" * 100
chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)
if len(chunks) >= 2:
assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"

def test_chunks_preserve_content(self):
"""Should preserve all original content across chunks."""
text = "".join(str(i) for i in range(500))
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
combined = "".join(chunks)
assert text[:400] in combined or all(c in combined for c in text[:400])

def test_chunk_size_parameter(self):
"""Should respect chunk_size parameter."""
text = "a" * 1000
result = split_text_into_chunks(text, chunk_size=100, overlap=0)
for chunk in result:
assert len(chunk) <= 100

def test_overlap_parameter(self):
"""Should respect overlap parameter."""
text = "abcdefghij" * 100
chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)
if len(chunks) >= 2:
overlap_size = len(chunks[0]) - (len(chunks[0].rstrip()) - len(chunks[1].lstrip()))
assert overlap_size >= 5

def test_split_at_sentence_boundary(self):
"""Should try to split at sentence boundaries when possible."""
text = "第一句。第二句。第三句。" * 50
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
for chunk in chunks:
if len(chunk) > 50:
assert chunk[-1] in "。.\n"

def test_last_chunk_may_be_smaller(self):
"""Should allow last chunk to be smaller than chunk_size."""
text = "a" * 550
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert any(len(chunk) < 500 for chunk in result)

def test_whitespace_only_chunks_filtered(self):
"""Should filter out whitespace-only chunks."""
text = "content" + " " * 600 + "more content"
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
for chunk in result:
assert chunk.strip()
Loading