From 3eb5354de25d7ad3256ff7476518424eca7caf08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 14:40:03 +0800 Subject: [PATCH 01/70] bump version --- src/sirchmunk/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/version.py b/src/sirchmunk/version.py index 87b826d..73c89eb 100644 --- a/src/sirchmunk/version.py +++ b/src/sirchmunk/version.py @@ -1 +1 @@ -__version__ = "0.0.7+main" +__version__ = "0.0.8+main" From b72a878c2947c63cae2fcda090353071cad4b29c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 17:19:12 +0800 Subject: [PATCH 02/70] Introduce Sirchmunk Learnings (insights from pageindex and LLM wiki) --- src/sirchmunk/cli/cli.py | 250 ++++++ src/sirchmunk/learnings/README.md | 218 ++++++ src/sirchmunk/learnings/__init__.py | 29 +- src/sirchmunk/learnings/compiler.py | 840 +++++++++++++++++++++ src/sirchmunk/learnings/knowledge_base.py | 33 +- src/sirchmunk/learnings/lint.py | 213 ++++++ src/sirchmunk/learnings/tree_indexer.py | 444 +++++++++++ src/sirchmunk/llm/prompts.py | 87 +++ src/sirchmunk/schema/knowledge.py | 12 + src/sirchmunk/search.py | 113 +++ src/sirchmunk/storage/knowledge_storage.py | 69 +- 11 files changed, 2294 insertions(+), 14 deletions(-) create mode 100644 src/sirchmunk/learnings/README.md create mode 100644 src/sirchmunk/learnings/compiler.py create mode 100644 src/sirchmunk/learnings/lint.py create mode 100644 src/sirchmunk/learnings/tree_indexer.py diff --git a/src/sirchmunk/cli/cli.py b/src/sirchmunk/cli/cli.py index 8919732..9d09762 100644 --- a/src/sirchmunk/cli/cli.py +++ b/src/sirchmunk/cli/cli.py @@ -6,6 +6,7 @@ sirchmunk init - Initialize working directory + generate .env sirchmunk serve - Start the API server (backend only) sirchmunk search - Perform a search query + sirchmunk compile - Compile documents into knowledge indices sirchmunk web init - Build WebUI frontend (requires Node.js) sirchmunk web serve - Start API 
+ WebUI (single port) sirchmunk web serve --dev - Start API + Next.js dev server (dual port) @@ -1225,6 +1226,207 @@ def cmd_mcp_version(args: argparse.Namespace) -> int: return 0 +# ------------------------------------------------------------------ +# sirchmunk compile +# ------------------------------------------------------------------ + +def cmd_compile(args: argparse.Namespace) -> int: + """Compile document collections into structured knowledge indices. + + Builds PageIndex-style tree indices and LLM Wiki-style knowledge + clusters for downstream search acceleration. + + Args: + args: Command-line arguments + + Returns: + Exit code (0 for success, non-zero for failure) + """ + try: + work_path = Path( + getattr(args, "work_path", None) or str(_get_default_work_path()) + ).expanduser().resolve() + os.environ["SIRCHMUNK_WORK_PATH"] = str(work_path) + + env_file = work_path / ".env" + if env_file.exists(): + _load_env_file(env_file) + + paths = args.paths or None + if not paths: + print(" Error: --paths is required for compile.") + print(" Usage: sirchmunk compile --paths /data/docs") + return 1 + + # Status mode + if getattr(args, "status", False): + return asyncio.run(_compile_status(paths, work_path)) + + # Lint mode + if getattr(args, "lint", False): + return asyncio.run(_compile_lint( + work_path, auto_fix=getattr(args, "fix", False), + )) + + # Normal compile + incremental = not getattr(args, "full", False) + return asyncio.run(_compile_run( + paths=paths, + work_path=work_path, + incremental=incremental, + max_files=getattr(args, "max_files", None), + concurrency=getattr(args, "concurrency", 3), + shallow=getattr(args, "shallow", False), + )) + + except KeyboardInterrupt: + print("\n Compile cancelled.") + return 130 + except Exception as e: + logger.error(f"Compile failed: {e}", exc_info=True) + print(f" Compile error: {e}") + return 1 + + +async def _compile_run( + paths: list, + work_path: Path, + incremental: bool = True, + max_files: Optional[int] = 
None, + concurrency: int = 3, + shallow: bool = False, +) -> int: + """Execute compile using AgenticSearch.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm_api_key = os.getenv("LLM_API_KEY", "") + if not llm_api_key: + print(" LLM_API_KEY is not set.") + print(" Configure it in ~/.sirchmunk/.env or set the environment variable.") + return 1 + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=llm_api_key, + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + + print("=" * 60) + print(" Sirchmunk Knowledge Compile") + print("=" * 60) + print() + print(f" Paths: {', '.join(paths)}") + print(f" Incremental: {incremental}") + if shallow: + print(" Mode: shallow (tree indexing skipped)") + if max_files: + print(f" Max files: {max_files} (importance sampling)") + print() + + report = await searcher.compile( + paths=paths, + incremental=incremental, + shallow=shallow, + max_files=max_files, + concurrency=concurrency, + ) + + print() + print("=" * 60) + print(" Compile Report") + print("=" * 60) + print() + print(f" Total files: {report.get('total_files', 0)}") + print(f" Files added: {report.get('files_added', 0)}") + print(f" Files modified: {report.get('files_modified', 0)}") + print(f" Files skipped: {report.get('files_skipped', 0)}") + if report.get("files_sampled"): + print(f" Files sampled: {report['files_sampled']}") + print(f" Trees built: {report.get('trees_built', 0)}") + print(f" Clusters created: {report.get('clusters_created', 0)}") + print(f" Clusters merged: {report.get('clusters_merged', 0)}") + print(f" Cross-refs: {report.get('cross_refs_built', 0)}") + print(f" Elapsed: {report.get('elapsed_seconds', 0):.1f}s") + if report.get("errors"): + print(f" Errors: {len(report['errors'])}") + for err in report["errors"][:5]: + print(f" - {err}") + print() + + return 0 + + +async def 
_compile_status(paths: list, work_path: Path) -> int: + """Show compile status.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + status = await searcher.compile_status(paths=paths) + + print("=" * 60) + print(" Compile Status") + print("=" * 60) + print() + print(f" Compiled files: {status.get('total_compiled_files', 0)}") + print(f" Tree indices: {status.get('total_trees', 0)}") + print(f" Clusters: {status.get('total_clusters', 0)}") + print(f" Last compile: {status.get('last_compile_at', 'Never')}") + print() + + return 0 + + +async def _compile_lint(work_path: Path, auto_fix: bool = False) -> int: + """Run knowledge lint checks.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + report = await searcher.compile_lint(auto_fix=auto_fix) + + print("=" * 60) + print(" Knowledge Lint Report") + print("=" * 60) + print() + print(f" Clusters checked: {report.get('total_clusters_checked', 0)}") + print(f" Trees checked: {report.get('total_trees_checked', 0)}") + print(f" Errors: {report.get('errors', 0)}") + print(f" Warnings: {report.get('warnings', 0)}") + if auto_fix: + print(f" Auto-fixes: {report.get('auto_fixes_applied', 0)}") + print() + + issues = report.get("issues", []) + if issues: + for issue in issues[:20]: + severity = issue.get("severity", "info").upper() + msg = issue.get("message", "") + cid = issue.get("cluster_id", "") + fixed = " [FIXED]" if issue.get("auto_fixed") else 
"" + print(f" [{severity}] {msg} {f'(cluster={cid})' if cid else ''}{fixed}") + if len(issues) > 20: + print(f" ... and {len(issues) - 20} more") + print() + + return 0 + + # ------------------------------------------------------------------ # sirchmunk upload # ------------------------------------------------------------------ @@ -1435,6 +1637,54 @@ def create_parser() -> argparse.ArgumentParser: ) search_parser.set_defaults(func=cmd_search) + # === compile command === + compile_parser = subparsers.add_parser( + "compile", + help="Compile document collections into knowledge indices", + description=( + "Compile documents into structured knowledge indices (tree + clusters). " + "Optional step after 'sirchmunk init'." + ), + ) + compile_parser.add_argument( + "--paths", nargs="+", required=True, + help="Directories or files to compile", + ) + compile_parser.add_argument( + "--full", action="store_true", default=False, + help="Force full recompile (ignore incremental cache)", + ) + compile_parser.add_argument( + "--max-files", type=int, default=None, + help="Max files to process (triggers importance sampling for large sets)", + ) + compile_parser.add_argument( + "--concurrency", type=int, default=3, + help="Max parallel file compilations (default: 3)", + ) + compile_parser.add_argument( + "--shallow", action="store_true", default=False, + help="Skip tree indexing — use direct LLM summarisation only (faster)", + ) + compile_parser.add_argument( + "--status", action="store_true", default=False, + help="Show compile status instead of running compile", + ) + compile_parser.add_argument( + "--lint", action="store_true", default=False, + help="Run knowledge health checks", + ) + compile_parser.add_argument( + "--fix", action="store_true", default=False, + help="Auto-fix lint issues (use with --lint)", + ) + compile_parser.add_argument( + "--work-path", + default=None, + help="Working directory (default: ~/.sirchmunk)", + ) + compile_parser.set_defaults(func=cmd_compile) + # 
=== web command group === web_parser = subparsers.add_parser( "web", diff --git a/src/sirchmunk/learnings/README.md b/src/sirchmunk/learnings/README.md new file mode 100644 index 0000000..0fc1bbe --- /dev/null +++ b/src/sirchmunk/learnings/README.md @@ -0,0 +1,218 @@ +# Sirchmunk Learnings Module + +The `sirchmunk/learnings` module implements **knowledge compilation and continuous learning** capabilities. It houses the core logic for transforming raw document collections into structured, searchable knowledge networks. + +## Architecture Overview + +``` +learnings/ +├── __init__.py # Public API exports +├── knowledge_base.py # Runtime knowledge builder (search-time) +├── evidence_processor.py # Monte Carlo evidence sampling +├── compiler.py # Offline knowledge compiler (compile-time) +├── tree_indexer.py # PageIndex-style document tree indexer +├── lint.py # Knowledge network health checks +└── README.md # This file +``` + +### Design Philosophy + +The module fuses insights from three frameworks: + +1. **PageIndex** (VectifyAI) — Hierarchical tree indexing replaces brute-force vector search with LLM reasoning-based navigation. The key insight: *similarity ≠ relevance*. + +2. **LLM Wiki** (Karpathy) — Documents are not merely "indexed" but "compiled" into an interlinked knowledge network that compounds over time. Knowledge clusters grow richer with each compile cycle. + +3. **NotebookLM** (Google) — Strict source grounding ensures every claim traces back to original evidence. The `EvidenceUnit` system provides full provenance. + +### Compile vs. 
Search + +| Aspect | Compile (offline) | Search (runtime) | +|--------|-------------------|-------------------| +| **When** | `sirchmunk compile` | `sirchmunk search` | +| **Speed** | Minutes (batch) | Seconds (interactive) | +| **Purpose** | Build indices + knowledge | Answer queries | +| **Module** | `compiler.py` (uses `tree_indexer.py`) | `knowledge_base.py`, `evidence_processor.py` | +| **Required** | Optional | Always available | + +Compile products are automatically leveraged by search when present, but search functions independently without them. + +--- + +## Components + +### DocumentTreeIndexer (`tree_indexer.py`) + +Builds hierarchical JSON tree indices for structured long documents. + +**Key concepts:** +- Only triggers for documents ≥ 50KB in eligible formats (PDF, DOCX, MD, HTML, etc.) +- LLM analyzes document structure recursively (up to 4 levels deep) +- Each node stores: title, summary, character range +- Query-time navigation: LLM selects relevant branches instead of scanning everything + +**Data structures:** +- `TreeNode` — Single node with `node_id`, `title`, `summary`, `char_range`, `children` +- `DocumentTree` — Complete tree for a document, JSON-serializable, cached by file hash + +**Usage:** +```python +indexer = DocumentTreeIndexer(llm=llm, cache_dir=cache_path) + +# Build (async, LLM-powered) +tree = await indexer.build_tree(file_path, content, max_depth=4) + +# Navigate (async, LLM-powered branch selection) +leaves = await indexer.navigate(tree, query="How does X work?") +for leaf in leaves: + relevant_text = content[leaf.char_range[0]:leaf.char_range[1]] + +# Cache check (sync) +if indexer.has_tree(file_path): + tree = indexer.load_tree(file_path) +``` + +### KnowledgeCompiler (`compiler.py`) + +Orchestrates the unified compile pipeline. + +**Four-phase pipeline:** +1. **File Discovery & Change Detection** — Scans paths, compares with manifest for incremental processing +2. 
**Per-File Compile** — Unified pipeline per file: tree-if-eligible → summary → topics → rich evidence +3. **Knowledge Aggregation** — Merges into existing clusters or creates new ones (three-tier similarity) +4. **Cross-Reference Building** — Creates `WeakSemanticEdge` links between related clusters + +**Unified single-file pipeline:** +For each file, the compiler runs a single pipeline instead of separate "tree" and "wiki" modes: +- If the file is ≥ 50KB and in an eligible format, a tree is built first. The root node's summary is synthesized from children's section summaries via LLM, and `EvidenceUnit` snippets + `tree_path` are populated directly from tree leaves. +- If the file is small or `shallow=True`, a direct LLM summary is generated instead. +- In both cases, topics are extracted and a `KnowledgeCluster` is created/merged. + +**Three-tier similarity strategy:** +| Similarity | Action | +|-----------|--------| +| ≥ 0.75 | Merge into existing cluster, re-compute embedding | +| 0.50 – 0.74 | Create new cluster + build `embed_sim` weak edges | +| < 0.50 | Create standalone cluster | + +**Importance probability sampling** (`ImportanceSampler`): +For large datasets, select a representative subset using weighted random sampling: +- File size (log-scaled): larger files contain more information +- Novelty: uncompiled files get 4× weight over already-compiled ones +- Extension diversity: structured formats (PDF, DOCX) get 1.5× boost + +**Key data structures:** +- `CompileManifest` — Tracks file hashes and cluster associations for incremental compile +- `FileManifestEntry` — Per-file state (hash, compile timestamp, tree flag, cluster IDs) +- `CompileReport` — Statistics from a compile run +- `CompileStatus` — Quick status snapshot + +### KnowledgeLint (`lint.py`) + +Health checks for the knowledge network (inspired by LLM Wiki's Lint operation). 
+ +**Checks performed:** +- **Empty clusters** — Clusters with minimal or no content +- **Stale evidence** — Evidence pointing to files that no longer exist +- **Orphan clusters** — Clusters with no evidence and no queries +- **Isolated clusters** — Clusters with no cross-references +- **Orphan trees** — Tree cache files without matching manifest entries +- **Stale manifest** — Manifest entries pointing to deleted files + +**Auto-fix capabilities:** +- Deprecate clusters where all evidence sources are gone +- Remove orphan tree cache files + +### KnowledgeBase (`knowledge_base.py`) + +Runtime knowledge builder used during search operations. + +**Tree-aware evidence extraction:** +When a tree index exists for a file, `_extract_evidence_for_file()` navigates to relevant sections first, then runs Monte Carlo sampling within those narrowed regions. This dramatically improves precision for large documents. + +### MonteCarloEvidenceSampling (`evidence_processor.py`) + +Statistical sampling for finding relevant regions in documents. Used both at compile-time and search-time. 
+ +--- + +## CLI Interface + +```bash +# Compile documents (optional, after sirchmunk init) +sirchmunk compile --paths /data/docs /data/reports + +# Incremental compile (default, skips unchanged files) +sirchmunk compile --paths /data/docs + +# Full recompile +sirchmunk compile --paths /data/docs --full + +# Importance sampling for large datasets +sirchmunk compile --paths /data/docs --max-files 100 + +# Shallow mode: skip tree indexing, use direct LLM summarisation +sirchmunk compile --paths /data/docs --shallow + +# Check compile status +sirchmunk compile --paths /data/docs --status + +# Run health checks +sirchmunk compile --paths /data/docs --lint +sirchmunk compile --paths /data/docs --lint --fix +``` + +## Python SDK + +```python +from sirchmunk.search import AgenticSearch + +searcher = AgenticSearch(work_path="~/.sirchmunk") + +# Compile +report = await searcher.compile( + paths=["/data/docs"], + incremental=True, + shallow=False, # set True to skip tree indexing + max_files=100, # importance sampling + concurrency=3, +) + +# Status +status = await searcher.compile_status(paths=["/data/docs"]) + +# Lint +lint_report = await searcher.compile_lint(auto_fix=True) + +# Search (automatically uses compile products when available) +result = await searcher.search("query", paths=["/data/docs"]) +``` + +--- + +## Cache Layout + +``` +{work_path}/.cache/ +├── compile/ +│ ├── manifest.json # Compile manifest (incremental state) +│ └── trees/ +│ ├── {file_hash_1}.json # Tree index for document 1 +│ └── {file_hash_2}.json # Tree index for document 2 +└── knowledge/ + └── knowledge_clusters.parquet # Persistent cluster storage (DuckDB + Parquet) +``` + +## Schema Extensions + +The compile feature extends existing schemas: + +- **`EvidenceUnit`** — Added `tree_path` (node IDs from tree navigation) and `page_range` (character offsets) +- **`KnowledgeCluster`** — Added `merge_count` (tracks compile-time merge frequency for lifecycle promotion: ≥ 3 merges → `STABLE`) + +## 
Design Principles + +- **SOLID compliance**: Each class has a single responsibility; dependencies are injected via constructor +- **Optional by design**: Compile never breaks existing search functionality +- **Incremental**: Only processes changed files; manifest tracks state across runs +- **Production-ready**: Bounded concurrency, error isolation per file, graceful schema migration diff --git a/src/sirchmunk/learnings/__init__.py b/src/sirchmunk/learnings/__init__.py index 0829846..bc14211 100644 --- a/src/sirchmunk/learnings/__init__.py +++ b/src/sirchmunk/learnings/__init__.py @@ -1 +1,28 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. \ No newline at end of file +# Copyright (c) ModelScope Contributors. All rights reserved. + +from sirchmunk.learnings.compiler import ( + CompileManifest, + CompileReport, + CompileStatus, + ImportanceSampler, + KnowledgeCompiler, +) +from sirchmunk.learnings.lint import KnowledgeLint, LintReport +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, + TreeNode, +) + +__all__ = [ + "CompileManifest", + "CompileReport", + "CompileStatus", + "DocumentTree", + "DocumentTreeIndexer", + "ImportanceSampler", + "KnowledgeCompiler", + "KnowledgeLint", + "LintReport", + "TreeNode", +] \ No newline at end of file diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py new file mode 100644 index 0000000..3c2b0da --- /dev/null +++ b/src/sirchmunk/learnings/compiler.py @@ -0,0 +1,840 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge compiler — orchestrates offline compile of document collections. + +Fuses PageIndex (tree indexing) and LLM Wiki (knowledge compilation network) +into a single compile pipeline that produces structured tree indices and +knowledge clusters for downstream search acceleration. 
+""" + +import asyncio +import json +import math +import os +import random +import hashlib +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, +) +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.schema.knowledge import ( + AbstractionLevel, + EvidenceUnit, + KnowledgeCluster, + Lifecycle, + WeakSemanticEdge, +) +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.file_utils import fast_extract, get_fast_hash + +# Concurrency cap for LLM-heavy file processing +_DEFAULT_CONCURRENCY = 3 + +# Similarity threshold for merging into existing clusters during compile +_MERGE_SIMILARITY_THRESHOLD = 0.75 + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class FileManifestEntry: + """State of a single file in the compile manifest.""" + + file_hash: str + compiled_at: str + has_tree: bool + cluster_ids: List[str] + size_bytes: int + + def to_dict(self) -> Dict[str, Any]: + return { + "file_hash": self.file_hash, + "compiled_at": self.compiled_at, + "has_tree": self.has_tree, + "cluster_ids": self.cluster_ids, + "size_bytes": self.size_bytes, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": + return cls( + file_hash=data["file_hash"], + compiled_at=data["compiled_at"], + has_tree=data.get("has_tree", False), + cluster_ids=data.get("cluster_ids", []), + size_bytes=data.get("size_bytes", 0), + ) + + +@dataclass +class CompileManifest: + """Tracks compiled file states for incremental processing.""" + + version: str = "1.0" + last_compile_at: Optional[str] = None + files: Dict[str, 
FileManifestEntry] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps({ + "version": self.version, + "last_compile_at": self.last_compile_at, + "files": {k: v.to_dict() for k, v in self.files.items()}, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "CompileManifest": + data = json.loads(json_str) + files = { + k: FileManifestEntry.from_dict(v) + for k, v in data.get("files", {}).items() + } + return cls( + version=data.get("version", "1.0"), + last_compile_at=data.get("last_compile_at"), + files=files, + ) + + +@dataclass +class FileEntry: + """Discovered file pending compilation.""" + + path: str + size_bytes: int + file_hash: str + + +@dataclass +class ChangeSet: + """Delta between discovered files and the manifest.""" + + added: List[FileEntry] = field(default_factory=list) + modified: List[FileEntry] = field(default_factory=list) + deleted: List[str] = field(default_factory=list) + unchanged: List[str] = field(default_factory=list) + + +@dataclass +class FileCompileResult: + """Result of compiling a single file.""" + + path: str + tree: Optional[DocumentTree] = None + summary: str = "" + topics: List[str] = field(default_factory=list) + evidence: Optional[EvidenceUnit] = None + cluster_ids: List[str] = field(default_factory=list) + error: Optional[str] = None + + +@dataclass +class CompileReport: + """Summary report of a compile run.""" + + total_files: int = 0 + files_added: int = 0 + files_modified: int = 0 + files_skipped: int = 0 + files_deleted: int = 0 + files_sampled: int = 0 + trees_built: int = 0 + clusters_created: int = 0 + clusters_merged: int = 0 + cross_refs_built: int = 0 + errors: List[str] = field(default_factory=list) + elapsed_seconds: float = 0.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "total_files": self.total_files, + "files_added": self.files_added, + "files_modified": self.files_modified, + "files_skipped": self.files_skipped, + "files_deleted": 
self.files_deleted, + "files_sampled": self.files_sampled, + "trees_built": self.trees_built, + "clusters_created": self.clusters_created, + "clusters_merged": self.clusters_merged, + "cross_refs_built": self.cross_refs_built, + "errors": self.errors, + "elapsed_seconds": round(self.elapsed_seconds, 2), + } + + +@dataclass +class CompileStatus: + """Status snapshot of the compile state.""" + + total_compiled_files: int = 0 + total_clusters: int = 0 + total_trees: int = 0 + last_compile_at: Optional[str] = None + manifest_path: str = "" + + +# --------------------------------------------------------------------------- +# Importance probability sampler +# --------------------------------------------------------------------------- + +class ImportanceSampler: + """Select a representative subset of files using importance-based probability. + + Sampling strategy for large datasets: + - Larger files get higher probability (they contain more information). + - Uncompiled (new) files are prioritised over previously compiled ones. + - Files with rare extensions get a mild boost (diversity signal). + - The final probability is proportional to a composite importance score. 
+ """ + + def __init__(self, max_files: int, seed: Optional[int] = None): + self._max_files = max_files + self._rng = random.Random(seed) + + def sample(self, files: List[FileEntry], manifest: CompileManifest) -> List[FileEntry]: + """Return up to *max_files* entries sampled by importance.""" + if len(files) <= self._max_files: + return files + + scores = [self._score(f, manifest) for f in files] + total = sum(scores) or 1.0 + probs = [s / total for s in scores] + + selected_indices = set() + attempts = 0 + while len(selected_indices) < self._max_files and attempts < len(files) * 3: + idx = self._weighted_choice(probs) + selected_indices.add(idx) + attempts += 1 + + return [files[i] for i in sorted(selected_indices)] + + def _score(self, entry: FileEntry, manifest: CompileManifest) -> float: + """Compute composite importance score.""" + # Size factor: log-scaled, bounded + size_score = math.log2(max(entry.size_bytes, 1024)) / 20.0 + + # Novelty factor: new files are more important + novelty = 2.0 if entry.path not in manifest.files else 0.5 + + # Extension diversity: rare extensions get a mild boost + ext = Path(entry.path).suffix.lower() + diversity = 1.5 if ext in {".pdf", ".docx", ".doc", ".tex"} else 1.0 + + return size_score * novelty * diversity + + def _weighted_choice(self, probs: List[float]) -> int: + r = self._rng.random() + cumulative = 0.0 + for i, p in enumerate(probs): + cumulative += p + if r <= cumulative: + return i + return len(probs) - 1 + + +# --------------------------------------------------------------------------- +# Compiler +# --------------------------------------------------------------------------- + +class KnowledgeCompiler: + """Orchestrate compile pipeline: file discovery -> tree indexing -> knowledge aggregation.""" + + # File extensions eligible for compilation + _ELIGIBLE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", ".html", ".htm", + ".rst", ".tex", ".txt", ".pptx", ".xlsx", + } + + def __init__( + self, + llm: 
OpenAIChat, + embedding_client: Optional[Any], + knowledge_storage: KnowledgeStorage, + tree_indexer: DocumentTreeIndexer, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._embedding = embedding_client + self._storage = knowledge_storage + self._tree_indexer = tree_indexer + self._work_path = Path(work_path).expanduser().resolve() + self._log = create_logger(log_callback=log_callback) + + self._compile_dir = self._work_path / ".cache" / "compile" + self._compile_dir.mkdir(parents=True, exist_ok=True) + self._manifest_path = self._compile_dir / "manifest.json" + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def compile( + self, + paths: List[str], + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = _DEFAULT_CONCURRENCY, + ) -> CompileReport: + """Execute the unified knowledge compile pipeline. + + Args: + paths: Directories or files to compile. + incremental: Skip unchanged files. + shallow: Skip tree building even for eligible files — use direct + LLM summarisation only (faster, lower quality). + max_files: Cap on files to process (triggers importance sampling). + concurrency: Max parallel file compilations. 
+ """ + import time + t0 = time.monotonic() + report = CompileReport() + + # Phase 1: discover and diff + await self._log.info("[Compile] Phase 1: File discovery & change detection") + manifest = self._load_manifest() + discovered = await self._discover_files(paths) + report.total_files = len(discovered) + await self._log.info(f"[Compile] Discovered {len(discovered)} eligible files") + + if incremental: + changes = self._detect_changes(discovered, manifest) + to_compile = changes.added + changes.modified + report.files_skipped = len(changes.unchanged) + report.files_deleted = len(changes.deleted) + for deleted_path in changes.deleted: + manifest.files.pop(deleted_path, None) + else: + to_compile = discovered + report.files_skipped = 0 + + report.files_added = len([f for f in to_compile if f.path not in manifest.files]) + report.files_modified = len(to_compile) - report.files_added + + # Phase 1.5: importance sampling for large datasets + if max_files and len(to_compile) > max_files: + await self._log.info( + f"[Compile] Applying importance sampling: {len(to_compile)} -> {max_files} files" + ) + sampler = ImportanceSampler(max_files=max_files) + to_compile = sampler.sample(to_compile, manifest) + report.files_sampled = len(to_compile) + + if not to_compile: + await self._log.info("[Compile] No files to compile (all up-to-date)") + report.elapsed_seconds = time.monotonic() - t0 + return report + + await self._log.info( + f"[Compile] Phase 2: Processing {len(to_compile)} files " + f"(concurrency={concurrency})" + ) + + # Phase 2: compile files with bounded concurrency + semaphore = asyncio.Semaphore(concurrency) + results: List[FileCompileResult] = [] + + async def _bounded(entry: FileEntry) -> FileCompileResult: + async with semaphore: + return await self._compile_single_file(entry, shallow=shallow) + + tasks = [_bounded(f) for f in to_compile] + for coro in asyncio.as_completed(tasks): + result = await coro + results.append(result) + if result.error: + 
report.errors.append(f"{result.path}: {result.error}") + else: + if result.tree: + report.trees_built += 1 + # Update manifest + manifest.files[result.path] = FileManifestEntry( + file_hash=get_fast_hash(result.path) or "", + compiled_at=datetime.now(timezone.utc).isoformat(), + has_tree=result.tree is not None, + cluster_ids=result.cluster_ids, + size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, + ) + + # Phase 3: aggregate results into knowledge network + await self._log.info("[Compile] Phase 3: Knowledge aggregation") + for r in results: + if r.error or not r.summary: + continue + created, merged = await self._aggregate_to_knowledge_network(r) + report.clusters_created += created + report.clusters_merged += merged + + # Phase 4: cross-references + await self._log.info("[Compile] Phase 4: Building cross-references") + report.cross_refs_built = await self._build_cross_references(results) + + # Phase 5: persist manifest + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + self._storage.force_sync() + + report.elapsed_seconds = time.monotonic() - t0 + await self._log.info( + f"[Compile] Done in {report.elapsed_seconds:.1f}s — " + f"trees={report.trees_built}, created={report.clusters_created}, " + f"merged={report.clusters_merged}, errors={len(report.errors)}" + ) + return report + + async def get_status(self, paths: List[str]) -> CompileStatus: + """Return current compile status for the given paths.""" + manifest = self._load_manifest() + path_set = {str(Path(p).resolve()) for p in paths} + + compiled_count = 0 + tree_count = 0 + cluster_ids: Set[str] = set() + for fp, entry in manifest.files.items(): + for p in path_set: + if fp.startswith(p): + compiled_count += 1 + if entry.has_tree: + tree_count += 1 + cluster_ids.update(entry.cluster_ids) + break + + return CompileStatus( + total_compiled_files=compiled_count, + total_clusters=len(cluster_ids), + total_trees=tree_count, + 
last_compile_at=manifest.last_compile_at, + manifest_path=str(self._manifest_path), + ) + + # ------------------------------------------------------------------ # + # File discovery and change detection # + # ------------------------------------------------------------------ # + + async def _discover_files(self, paths: List[str]) -> List[FileEntry]: + """Walk paths and return all compilation-eligible files.""" + entries: List[FileEntry] = [] + seen: Set[str] = set() + + for base in paths: + base_path = Path(base).expanduser().resolve() + if base_path.is_file(): + candidates = [base_path] + elif base_path.is_dir(): + candidates = sorted(base_path.rglob("*")) + else: + continue + + for fp in candidates: + if not fp.is_file(): + continue + if fp.suffix.lower() not in self._ELIGIBLE_EXTENSIONS: + continue + abs_path = str(fp.resolve()) + if abs_path in seen: + continue + seen.add(abs_path) + fh = get_fast_hash(abs_path) + if fh is None: + continue + entries.append(FileEntry( + path=abs_path, + size_bytes=fp.stat().st_size, + file_hash=fh, + )) + + return entries + + def _detect_changes( + self, discovered: List[FileEntry], manifest: CompileManifest, + ) -> ChangeSet: + """Compare discovered files against the manifest for incremental compile.""" + changes = ChangeSet() + current_paths = {f.path for f in discovered} + + for entry in discovered: + prev = manifest.files.get(entry.path) + if prev is None: + changes.added.append(entry) + elif prev.file_hash != entry.file_hash: + changes.modified.append(entry) + else: + changes.unchanged.append(entry.path) + + for old_path in manifest.files: + if old_path not in current_paths: + changes.deleted.append(old_path) + + return changes + + # ------------------------------------------------------------------ # + # Single-file compilation # + # ------------------------------------------------------------------ # + + async def _compile_single_file( + self, + entry: FileEntry, + *, + shallow: bool = False, + ) -> FileCompileResult: + 
"""Unified compile pipeline: tree-if-eligible -> summary -> topics -> evidence. + + When *shallow* is True (or file is ineligible for tree indexing), + the pipeline skips tree building and summarises via a direct LLM call. + """ + result = FileCompileResult(path=entry.path) + try: + await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") + + extraction = await fast_extract(file_path=entry.path) + content = extraction.content + if not content or len(content.strip()) < 100: + result.error = "Insufficient text content" + return result + + use_tree = ( + not shallow + and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) + ) + + if use_tree: + result.tree = await self._tree_indexer.build_tree( + entry.path, content, + ) + + result.summary = await self._extract_summary( + entry.path, content, result.tree, + ) + result.topics = await self._extract_topics(result.summary) + result.evidence = self._build_evidence(entry, content, result) + + except Exception as exc: + result.error = str(exc) + await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") + + return result + + async def _extract_summary( + self, + file_path: str, + content: str, + tree: Optional[DocumentTree] = None, + ) -> str: + """Generate a document-level summary. + + When a tree is available its root already contains an LLM-synthesized + summary (produced by ``_synthesize_root_summary`` during tree build), + so we reuse it directly — no redundant LLM call. 
+ """ + if tree and tree.root and tree.root.summary: + return tree.root.summary + + preview = content[:16000] if len(content) > 16000 else content + from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY + prompt = COMPILE_DOC_SUMMARY.format( + file_name=Path(file_path).name, + document_content=preview, + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + def _build_evidence( + self, + entry: FileEntry, + content: str, + result: FileCompileResult, + ) -> EvidenceUnit: + """Build an EvidenceUnit, populating snippets/tree_path from tree leaves.""" + from sirchmunk.schema.metadata import FileInfo + + snippets: List[str] = [] + tree_path: Optional[List[str]] = None + + if result.tree and result.tree.root: + leaves = result.tree.root.all_leaves() + tree_path = [leaf.node_id for leaf in leaves] + for leaf in leaves: + start, end = leaf.char_range + snippet = content[start:end][:500] + if snippet.strip(): + snippets.append(snippet) + + return EvidenceUnit( + doc_id=FileInfo.get_cache_key(entry.path), + file_or_url=Path(entry.path), + summary=result.summary, + is_found=True, + snippets=snippets, + tree_path=tree_path, + extracted_at=datetime.now(timezone.utc), + ) + + async def _extract_topics(self, summary: str) -> List[str]: + """Extract key topics/entities from a document summary.""" + from sirchmunk.llm.prompts import COMPILE_TOPIC_EXTRACTION + prompt = COMPILE_TOPIC_EXTRACTION.format(summary=summary) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + if raw.startswith("["): + parsed = json.loads(raw) + if isinstance(parsed, list): + return [str(t) for t in parsed if t] + return [t.strip() for t in raw.split(",") if t.strip()] + except (json.JSONDecodeError, TypeError): + return [] + + # ------------------------------------------------------------------ # + # Knowledge aggregation (LLM Wiki Ingest) # + # 
------------------------------------------------------------------ # + + async def _aggregate_to_knowledge_network( + self, result: FileCompileResult, + ) -> Tuple[int, int]: + """Aggregate a file's compile result into the knowledge network. + + Three-tier similarity strategy (per design doc): + - similarity >= 0.80 → merge into existing cluster + - 0.50 <= sim < 0.80 → create new cluster + weak edge to similar + - similarity < 0.50 → create standalone cluster + + Returns: + (clusters_created, clusters_merged) + """ + created, merged = 0, 0 + if not result.summary: + return created, merged + + embedding = self._encode_text(result.summary) + + # Search for similar existing clusters across a wider range + best_match: Optional[Dict[str, Any]] = None + if embedding is not None: + similar = await self._storage.search_similar_clusters( + query_embedding=embedding, + top_k=3, + similarity_threshold=0.50, + ) + if similar: + best_match = similar[0] + + if best_match and best_match["similarity"] >= 0.80: + # Tier 1: merge into existing cluster + cluster = await self._storage.get(best_match["id"]) + if cluster: + await self._merge_into_cluster(cluster, result) + # Re-compute embedding for merged content + await self._update_cluster_embedding(cluster) + result.cluster_ids.append(cluster.id) + merged += 1 + return created, merged + + # Create a new cluster (Tier 2 or Tier 3) + cluster = await self._create_cluster(result) + if cluster: + result.cluster_ids.append(cluster.id) + await self._store_cluster_embedding(cluster, embedding, result.summary) + created += 1 + + # Tier 2: build weak edges to moderately similar clusters + if best_match and best_match["similarity"] >= 0.50: + for s in (similar or []): + if s["similarity"] >= 0.50: + target = await self._storage.get(s["id"]) + if target: + self._add_edge(cluster, target.id, "embed_sim", s["similarity"]) + self._add_edge(target, cluster.id, "embed_sim", s["similarity"]) + await self._storage.update(target) + await 
self._storage.update(cluster) + + return created, merged + + def _encode_text(self, text: str) -> Optional[Any]: + """Encode text to embedding vector, returns None on failure.""" + if not self._embedding: + return None + try: + return self._embedding.encode(text) + except Exception: + return None + + async def _store_cluster_embedding( + self, cluster: KnowledgeCluster, embedding: Optional[Any], text: str, + ) -> None: + """Store embedding for a cluster if available.""" + if embedding is None or not self._embedding: + return + text_hash = hashlib.md5(text.encode()).hexdigest() + vec = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + await self._storage.store_embedding( + cluster.id, vec, + self._embedding.model_id or "default", + text_hash, + ) + + async def _update_cluster_embedding(self, cluster: KnowledgeCluster) -> None: + """Re-compute and store embedding after content merge.""" + content_text = str(cluster.content)[:2000] if cluster.content else "" + if not content_text: + return + embedding = self._encode_text(content_text) + await self._store_cluster_embedding(cluster, embedding, content_text) + + async def _merge_into_cluster( + self, + cluster: KnowledgeCluster, + result: FileCompileResult, + ) -> None: + """Merge a file compile result into an existing cluster.""" + # Append evidence + if result.evidence: + existing_doc_ids = {e.doc_id for e in cluster.evidences} + if result.evidence.doc_id not in existing_doc_ids: + cluster.evidences.append(result.evidence) + + # Enrich content via LLM merge + from sirchmunk.llm.prompts import COMPILE_MERGE_KNOWLEDGE + prompt = COMPILE_MERGE_KNOWLEDGE.format( + existing_content=str(cluster.content)[:3000], + new_summary=result.summary[:3000], + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + cluster.content = resp.content.strip() + + # Update metadata + cluster.search_results = list(set( + (cluster.search_results or []) + [result.path] + )) + merge_count = 
getattr(cluster, "merge_count", 0) or 0 + cluster.merge_count = merge_count + 1 + + # Lifecycle promotion + if cluster.merge_count >= 3 and cluster.lifecycle == Lifecycle.EMERGING: + cluster.lifecycle = Lifecycle.STABLE + + await self._storage.update(cluster) + + async def _create_cluster( + self, result: FileCompileResult, + ) -> Optional[KnowledgeCluster]: + """Create a new KnowledgeCluster from a file compile result.""" + cluster_text = result.summary + cluster_id = f"C{hashlib.sha256(cluster_text.encode('utf-8')).hexdigest()[:10]}" + + name = Path(result.path).stem[:60] + if result.topics: + name = result.topics[0][:60] + + cluster = KnowledgeCluster( + id=cluster_id, + name=name, + description=[result.summary[:500]], + content=result.summary, + evidences=[result.evidence] if result.evidence else [], + patterns=result.topics[:5], + lifecycle=Lifecycle.EMERGING, + confidence=0.5, + abstraction_level=AbstractionLevel.TECHNIQUE, + hotness=0.3, + search_results=[result.path], + ) + + ok = await self._storage.insert(cluster) + return cluster if ok else None + + # ------------------------------------------------------------------ # + # Cross-references # + # ------------------------------------------------------------------ # + + async def _build_cross_references( + self, results: List[FileCompileResult], + ) -> int: + """Build co-occurrence edges between clusters that share source files. + + Two clusters are co-occurring when the same source file contributed + evidence to both (e.g., different sections compiled into different + clusters). Includes historical data from the manifest. 
+ """ + # Build a complete map: cluster_id -> set of source file paths + cluster_to_files: Dict[str, Set[str]] = {} + + # From current compile results + for r in results: + for cid in r.cluster_ids: + cluster_to_files.setdefault(cid, set()).add(r.path) + + # From manifest (historical data) + manifest = self._load_manifest() + for fp, entry in manifest.files.items(): + for cid in entry.cluster_ids: + cluster_to_files.setdefault(cid, set()).add(fp) + + # Find cluster pairs that share at least one source file + cluster_ids = list(cluster_to_files.keys()) + edges_created = 0 + pairs_seen: Set[Tuple[str, str]] = set() + + for i in range(len(cluster_ids)): + for j in range(i + 1, len(cluster_ids)): + cid_a, cid_b = cluster_ids[i], cluster_ids[j] + shared = cluster_to_files[cid_a] & cluster_to_files[cid_b] + if not shared: + continue + + pair_key = (min(cid_a, cid_b), max(cid_a, cid_b)) + if pair_key in pairs_seen: + continue + pairs_seen.add(pair_key) + + weight = min(len(shared) * 0.25, 1.0) + c_a = await self._storage.get(cid_a) + c_b = await self._storage.get(cid_b) + if c_a and c_b: + self._add_edge(c_a, cid_b, "co_occur", weight) + self._add_edge(c_b, cid_a, "co_occur", weight) + await self._storage.update(c_a) + await self._storage.update(c_b) + edges_created += 1 + + return edges_created + + @staticmethod + def _add_edge( + cluster: KnowledgeCluster, target_id: str, source: str, weight: float, + ) -> None: + """Add or update a WeakSemanticEdge on a cluster.""" + for edge in cluster.related_clusters: + if edge.target_cluster_id == target_id and edge.source == source: + edge.weight = max(edge.weight, weight) + return + cluster.related_clusters.append( + WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) + ) + + # ------------------------------------------------------------------ # + # Manifest I/O # + # ------------------------------------------------------------------ # + + def _load_manifest(self) -> CompileManifest: + if 
self._manifest_path.exists(): + try: + return CompileManifest.from_json( + self._manifest_path.read_text(encoding="utf-8") + ) + except Exception: + pass + return CompileManifest() + + def _save_manifest(self, manifest: CompileManifest) -> None: + self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") diff --git a/src/sirchmunk/learnings/knowledge_base.py b/src/sirchmunk/learnings/knowledge_base.py index 387b368..bd2946c 100644 --- a/src/sirchmunk/learnings/knowledge_base.py +++ b/src/sirchmunk/learnings/knowledge_base.py @@ -120,11 +120,14 @@ async def _extract_evidence_for_file( confidence_threshold: float, top_k_snippets: int, verbose: bool, + tree_indexer=None, ) -> Optional[EvidenceUnit]: - """Extract evidence from a single file via Monte Carlo sampling. + """Extract evidence from a single file. - Performs text extraction followed by LLM-driven region-of-interest - identification. Designed to run concurrently for multiple files. + When a tree index exists for the file, uses LLM-driven tree navigation + to locate relevant sections precisely, then runs Monte Carlo sampling + within those narrowed regions. Falls back to full-document Monte + Carlo sampling otherwise. Args: file_path_or_url: Absolute path or URL to the document. @@ -133,6 +136,7 @@ async def _extract_evidence_for_file( confidence_threshold: Minimum confidence for evidence acceptance. top_k_snippets: Maximum evidence snippets per document. verbose: Whether to enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-based navigation. Returns: EvidenceUnit on success, None on extraction failure. 
@@ -141,6 +145,28 @@ async def _extract_evidence_for_file( extraction_result = await fast_extract(file_path=file_path_or_url) doc_content: str = extraction_result.content + tree_path_ids = None + + # Try tree-based navigation for focused extraction + if tree_indexer is not None: + tree = tree_indexer.load_tree(file_path_or_url) + if tree is not None: + await self._log.info( + f"[KnowledgeBase] Using tree index for {Path(file_path_or_url).name}" + ) + leaves = await tree_indexer.navigate(tree, query) + if leaves: + # Narrow doc_content to matched regions + tree_path_ids = [n.node_id for n in leaves] + segments = [] + for node in leaves: + start, end = node.char_range + segment = doc_content[start:end] + if segment.strip(): + segments.append(segment) + if segments: + doc_content = "\n\n---\n\n".join(segments) + sampler = MonteCarloEvidenceSampling( llm=self.llm, doc_content=doc_content, @@ -162,6 +188,7 @@ async def _extract_evidence_for_file( snippets=roi_result.snippets, extracted_at=datetime.now(), conflict_group=[], + tree_path=tree_path_ids, ) self.llm_usages.extend(sampler.llm_usages) return evidence_unit diff --git a/src/sirchmunk/learnings/lint.py b/src/sirchmunk/learnings/lint.py new file mode 100644 index 0000000..e5baa6f --- /dev/null +++ b/src/sirchmunk/learnings/lint.py @@ -0,0 +1,213 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge lint — health checks and auto-fixes for the knowledge network. + +Inspired by LLM Wiki's Lint operation: validates cluster integrity, +detects stale evidence, and cleans orphaned tree indices. 
+""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from sirchmunk.schema.knowledge import KnowledgeCluster, Lifecycle +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger + + +@dataclass +class LintIssue: + """A single lint finding.""" + + severity: str # "error", "warning", "info" + category: str # "stale_evidence", "orphan_tree", "empty_cluster", etc. + message: str + cluster_id: Optional[str] = None + file_path: Optional[str] = None + auto_fixed: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "severity": self.severity, + "category": self.category, + "message": self.message, + "cluster_id": self.cluster_id, + "file_path": self.file_path, + "auto_fixed": self.auto_fixed, + } + + +@dataclass +class LintReport: + """Summary of a lint run.""" + + total_clusters_checked: int = 0 + total_trees_checked: int = 0 + issues: List[LintIssue] = field(default_factory=list) + auto_fixes_applied: int = 0 + + @property + def errors(self) -> int: + return sum(1 for i in self.issues if i.severity == "error") + + @property + def warnings(self) -> int: + return sum(1 for i in self.issues if i.severity == "warning") + + def to_dict(self) -> Dict[str, Any]: + return { + "total_clusters_checked": self.total_clusters_checked, + "total_trees_checked": self.total_trees_checked, + "errors": self.errors, + "warnings": self.warnings, + "auto_fixes_applied": self.auto_fixes_applied, + "issues": [i.to_dict() for i in self.issues], + } + + +class KnowledgeLint: + """Validate the health of the knowledge network and apply auto-fixes.""" + + def __init__( + self, + knowledge_storage: KnowledgeStorage, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._storage = knowledge_storage + self._work_path = Path(work_path).expanduser().resolve() + self._tree_dir = self._work_path / ".cache" / 
"compile" / "trees" + self._manifest_path = self._work_path / ".cache" / "compile" / "manifest.json" + self._log = create_logger(log_callback=log_callback) + + async def run(self, *, auto_fix: bool = False) -> LintReport: + """Execute all lint checks and optionally apply auto-fixes.""" + report = LintReport() + + await self._log.info("[Lint] Starting knowledge health check") + + # Check clusters + await self._check_clusters(report, auto_fix=auto_fix) + + # Check orphaned tree caches + await self._check_orphan_trees(report, auto_fix=auto_fix) + + # Check manifest consistency + await self._check_manifest(report) + + await self._log.info( + f"[Lint] Done — clusters={report.total_clusters_checked}, " + f"trees={report.total_trees_checked}, " + f"errors={report.errors}, warnings={report.warnings}, " + f"fixes={report.auto_fixes_applied}" + ) + return report + + async def _check_clusters(self, report: LintReport, auto_fix: bool) -> None: + """Validate each knowledge cluster.""" + all_clusters = await self._storage.find("", limit=10000) + report.total_clusters_checked = len(all_clusters) + + for cluster in all_clusters: + # Check: empty content + if not cluster.content or ( + isinstance(cluster.content, str) and len(cluster.content.strip()) < 10 + ): + report.issues.append(LintIssue( + severity="warning", + category="empty_cluster", + message=f"Cluster has empty or minimal content", + cluster_id=cluster.id, + )) + + # Check: stale evidence (source files no longer exist) + stale_count = 0 + for ev in cluster.evidences: + fp = str(ev.file_or_url) + if fp.startswith("/") and not Path(fp).exists(): + stale_count += 1 + + if stale_count > 0: + report.issues.append(LintIssue( + severity="warning", + category="stale_evidence", + message=f"{stale_count} evidence source(s) no longer exist", + cluster_id=cluster.id, + )) + + if auto_fix and stale_count == len(cluster.evidences): + cluster.lifecycle = Lifecycle.DEPRECATED + await self._storage.update(cluster) + 
report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + # Check: no queries and no evidences (orphan cluster) + if not cluster.evidences and not cluster.queries: + report.issues.append(LintIssue( + severity="info", + category="orphan_cluster", + message="Cluster has no evidence and no queries", + cluster_id=cluster.id, + )) + + # Check: isolated cluster (no WeakSemanticEdge connections) + if not cluster.related_clusters and cluster.evidences: + report.issues.append(LintIssue( + severity="info", + category="isolated_cluster", + message="Cluster has no cross-references to other clusters", + cluster_id=cluster.id, + )) + + async def _check_orphan_trees(self, report: LintReport, auto_fix: bool) -> None: + """Find tree cache files whose source documents no longer exist.""" + if not self._tree_dir.exists(): + return + + manifest = self._load_manifest() + # Build set of valid file hashes from the manifest + valid_hashes: Set[str] = set() + for entry_data in manifest.get("files", {}).values(): + fh = entry_data.get("file_hash", "") + if fh: + valid_hashes.add(fh) + + tree_files = list(self._tree_dir.glob("*.json")) + report.total_trees_checked = len(tree_files) + + for tf in tree_files: + tree_hash = tf.stem + if tree_hash not in valid_hashes: + report.issues.append(LintIssue( + severity="info", + category="orphan_tree", + message=f"Tree cache has no matching manifest entry", + file_path=str(tf), + )) + if auto_fix: + tf.unlink(missing_ok=True) + report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + async def _check_manifest(self, report: LintReport) -> None: + """Validate manifest references.""" + manifest = self._load_manifest() + files = manifest.get("files", {}) + + for fp, entry_data in files.items(): + if not Path(fp).exists(): + report.issues.append(LintIssue( + severity="warning", + category="stale_manifest", + message=f"Manifest references non-existent file", + file_path=fp, + )) + + def _load_manifest(self) -> Dict[str, Any]: + 
if self._manifest_path.exists(): + try: + return json.loads(self._manifest_path.read_text(encoding="utf-8")) + except Exception: + pass + return {} diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py new file mode 100644 index 0000000..53ebf0b --- /dev/null +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -0,0 +1,444 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Document tree indexer — PageIndex-inspired hierarchical structure analysis. + +Builds a JSON tree index for structured long documents (PDF, DOCX, MD, HTML) +so that downstream search can navigate via LLM reasoning instead of brute-force +Monte Carlo sampling. +""" + +import json +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.file_utils import get_fast_hash + +# File-size threshold: skip tree indexing for small files +_TREE_MIN_CHARS = 50_000 # 50 K characters + +# Extensions eligible for tree indexing +_TREE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", + ".html", ".htm", ".rst", ".tex", ".txt", +} + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class TreeNode: + """Single node in the document tree.""" + + node_id: str + title: str + summary: str + char_range: Tuple[int, int] # [start, end) in the extracted text + level: int = 0 + page_range: Optional[Tuple[int, int]] = None + children: List["TreeNode"] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "node_id": self.node_id, + "title": self.title, + "summary": self.summary, + "char_range": list(self.char_range), + "level": self.level, 
+ "page_range": list(self.page_range) if self.page_range else None, + "children": [c.to_dict() for c in self.children], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": + children = [cls.from_dict(c) for c in data.get("children", [])] + pr = data.get("page_range") + return cls( + node_id=data["node_id"], + title=data["title"], + summary=data["summary"], + char_range=tuple(data["char_range"]), + level=data.get("level", 0), + page_range=tuple(pr) if pr else None, + children=children, + ) + + @property + def leaf(self) -> bool: + return len(self.children) == 0 + + def all_leaves(self) -> List["TreeNode"]: + """Return all leaf nodes under this subtree.""" + if self.leaf: + return [self] + leaves: List["TreeNode"] = [] + for c in self.children: + leaves.extend(c.all_leaves()) + return leaves + + +@dataclass +class DocumentTree: + """Complete tree index for a single document.""" + + file_path: str + file_hash: str + created_at: str + total_chars: int + total_pages: Optional[int] = None + root: Optional[TreeNode] = None + + def to_json(self) -> str: + return json.dumps({ + "file_path": self.file_path, + "file_hash": self.file_hash, + "created_at": self.created_at, + "total_chars": self.total_chars, + "total_pages": self.total_pages, + "root": self.root.to_dict() if self.root else None, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "DocumentTree": + data = json.loads(json_str) + root = TreeNode.from_dict(data["root"]) if data.get("root") else None + return cls( + file_path=data["file_path"], + file_hash=data["file_hash"], + created_at=data["created_at"], + total_chars=data["total_chars"], + total_pages=data.get("total_pages"), + root=root, + ) + + +# --------------------------------------------------------------------------- +# Indexer +# --------------------------------------------------------------------------- + +class DocumentTreeIndexer: + """Build and cache PageIndex-style hierarchical tree 
indices for documents.""" + + def __init__( + self, + llm: OpenAIChat, + cache_dir: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._cache_dir = Path(cache_dir) + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._log = create_logger(log_callback=log_callback) + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def build_tree( + self, + file_path: str, + content: str, + *, + max_depth: int = 4, + force_rebuild: bool = False, + total_pages: Optional[int] = None, + ) -> Optional[DocumentTree]: + """Build a tree index for a document. + + Returns None when the document is too small or unstructured. + """ + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + + if not force_rebuild: + cached = self._load_cache(file_hash) + if cached is not None: + await self._log.info(f"[TreeIndexer] Cache hit for {Path(file_path).name}") + return cached + + if len(content) < _TREE_MIN_CHARS: + return None + + ext = Path(file_path).suffix.lower() + if ext not in _TREE_EXTENSIONS: + return None + + await self._log.info( + f"[TreeIndexer] Building tree for {Path(file_path).name} " + f"({len(content)} chars, depth={max_depth})" + ) + + root = await self._build_node(content, level=0, max_depth=max_depth) + if root is None: + return None + + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree: {self._count_nodes(root)} nodes, " + f"depth={self._max_node_depth(root)}" + ) + return tree + + async def navigate( + self, + tree: DocumentTree, + query: str, + *, + max_results: int = 3, + ) -> List[TreeNode]: + """Reasoning-based tree navigation: LLM selects the most relevant branches. 
+ + Returns up to *max_results* leaf nodes with their char_range for + precise evidence extraction. + """ + if tree.root is None: + return [] + + candidates = tree.root.children if tree.root.children else [tree.root] + if not candidates: + return [tree.root] + + selected = await self._select_children(candidates, query) + if not selected: + return [] + + result_leaves: List[TreeNode] = [] + for node in selected: + if node.leaf: + result_leaves.append(node) + else: + deeper = await self._select_children(node.children, query) + for d in (deeper or node.children[:1]): + result_leaves.extend(d.all_leaves()[:max_results]) + + # Deduplicate and cap + seen_ids = set() + unique: List[TreeNode] = [] + for n in result_leaves: + if n.node_id not in seen_ids: + seen_ids.add(n.node_id) + unique.append(n) + return unique[:max_results] + + def load_tree(self, file_path: str) -> Optional[DocumentTree]: + """Load a cached tree index for the given file (sync).""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + return self._load_cache(file_hash) + + def has_tree(self, file_path: str) -> bool: + """Check whether a cached tree index exists for the file.""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return False + return self._cache_path(file_hash).exists() + + # ------------------------------------------------------------------ # + # Internals # + # ------------------------------------------------------------------ # + + async def _build_node( + self, text: str, level: int, max_depth: int, + offset: int = 0, + ) -> Optional[TreeNode]: + """Recursively build tree nodes via LLM structure analysis.""" + from sirchmunk.llm.prompts import COMPILE_TREE_STRUCTURE + + preview = text[:12000] if len(text) > 12000 else text + prompt = COMPILE_TREE_STRUCTURE.format( + document_content=preview, + max_sections=8, + ) + + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sections = self._parse_sections(resp.content, text) + + if 
not sections: + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=text[:300], + char_range=(offset, offset + len(text)), + level=level, + ) + + children: List[TreeNode] = [] + for i, sec in enumerate(sections): + child = TreeNode( + node_id=f"N{sec['start'] + offset:06d}", + title=sec["title"], + summary=sec["summary"], + char_range=(sec["start"] + offset, sec["end"] + offset), + level=level + 1, + ) + section_text = text[sec["start"]:sec["end"]] + if level + 1 < max_depth and len(section_text) > _TREE_MIN_CHARS: + deeper = await self._build_node( + section_text, level + 1, max_depth, offset=sec["start"] + offset, + ) + if deeper and deeper.children: + child.children = deeper.children + children.append(child) + + root_summary = await self._synthesize_root_summary(children) + + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=root_summary, + char_range=(offset, offset + len(text)), + level=level, + children=children, + ) + + async def _synthesize_root_summary(self, children: List[TreeNode]) -> str: + """Synthesize a document-level summary from children's section summaries.""" + if not children: + return "" + from sirchmunk.llm.prompts import COMPILE_SYNTHESIZE_SUMMARY + sections_text = "\n".join( + f"- {c.title}: {c.summary}" for c in children + ) + prompt = COMPILE_SYNTHESIZE_SUMMARY.format(sections=sections_text) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + def _parse_sections( + self, llm_output: str, full_text: str, + ) -> List[Dict[str, Any]]: + """Parse LLM section output into [{title, summary, start, end}, ...].""" + # Try JSON array first + try: + raw = llm_output + # Strip markdown fences + raw = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE).strip() + m = re.search(r"\[.*\]", raw, re.DOTALL) + if m: + items = json.loads(m.group()) + return self._resolve_positions(items, full_text) + 
except (json.JSONDecodeError, TypeError): + pass + return [] + + @staticmethod + def _resolve_positions( + items: List[Dict[str, Any]], full_text: str, + ) -> List[Dict[str, Any]]: + """Resolve section start/end character offsets from marker text.""" + resolved: List[Dict[str, Any]] = [] + prev_end = 0 + text_lower = full_text.lower() + for item in items: + title = item.get("title", "") + summary = item.get("summary", "") + marker = item.get("start_marker", title) + + pos = text_lower.find(marker.lower(), prev_end) if marker else -1 + start = pos if pos >= 0 else prev_end + + end_marker = item.get("end_marker", "") + if end_marker: + epos = text_lower.find(end_marker.lower(), start + 1) + end = epos if epos > start else min(start + 50000, len(full_text)) + else: + end = min(start + 50000, len(full_text)) + + resolved.append({ + "title": title, + "summary": summary, + "start": start, + "end": end, + }) + prev_end = end + + # Fix gaps: each section ends where the next begins + for i in range(len(resolved) - 1): + resolved[i]["end"] = resolved[i + 1]["start"] + if resolved: + resolved[-1]["end"] = len(full_text) + + return [s for s in resolved if s["end"] > s["start"]] + + async def _select_children( + self, nodes: List[TreeNode], query: str, + ) -> List[TreeNode]: + """LLM-driven branch selection: pick the most relevant children.""" + if len(nodes) <= 2: + return nodes + + listing = "\n".join( + f"[{i}] {n.title}: {n.summary[:150]}" + for i, n in enumerate(nodes) + ) + prompt = ( + f"Given the query: \"{query}\"\n\n" + f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + return [nodes[i] for i in indices if 0 <= i < len(nodes)] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return nodes[:2] + + # ------------------------------------------------------------------ # + # Cache I/O # + # ------------------------------------------------------------------ # + + def _cache_path(self, file_hash: str) -> Path: + return self._cache_dir / f"{file_hash}.json" + + def _save_cache(self, file_hash: str, tree: DocumentTree) -> None: + path = self._cache_path(file_hash) + path.write_text(tree.to_json(), encoding="utf-8") + + def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: + path = self._cache_path(file_hash) + if not path.exists(): + return None + try: + return DocumentTree.from_json(path.read_text(encoding="utf-8")) + except Exception: + return None + + # ------------------------------------------------------------------ # + # Helpers # + # ------------------------------------------------------------------ # + + @staticmethod + def _count_nodes(node: TreeNode) -> int: + return 1 + sum(DocumentTreeIndexer._count_nodes(c) for c in node.children) + + @staticmethod + def _max_node_depth(node: TreeNode) -> int: + if not node.children: + return node.level + return max(DocumentTreeIndexer._max_node_depth(c) for c in node.children) + + @staticmethod + def should_build_tree(file_path: str, content_length: int) -> bool: + """Determine whether a file is eligible for tree indexing.""" + ext = Path(file_path).suffix.lower() + return ext in _TREE_EXTENSIONS and content_length >= _TREE_MIN_CHARS diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 1a07e64..b3ded32 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -423,3 +423,90 @@ def 
generate_keyword_extraction_prompt(num_levels: int = 3) -> str: true/false true/false """ + + +# --------------------------------------------------------------------------- +# Knowledge Compile prompts +# --------------------------------------------------------------------------- + +COMPILE_TREE_STRUCTURE = """Analyze the following document and identify its natural hierarchical structure (chapters, sections, subsections). + +### Document Content (may be truncated) +{document_content} + +### Output Requirements +Return a JSON array of top-level sections. Each section object must have: +- "title": Section heading or descriptive title +- "summary": 1-2 sentence summary of the section content +- "start_marker": A short text string (5-15 words) that appears verbatim at the start of this section in the document +- "end_marker": A short text string that appears at the start of the NEXT section (empty for the last section) + +Maximum {max_sections} sections. Identify only the most significant structural boundaries. + +### Output Format +Return ONLY a JSON array, no extra text: +[ + {{"title": "...", "summary": "...", "start_marker": "...", "end_marker": "..."}}, + ... +] +""" + + +COMPILE_SYNTHESIZE_SUMMARY = """Synthesize a comprehensive document summary from the following section summaries. + +### Section Summaries +{sections} + +### Output +Provide a unified, coherent summary in 3-8 sentences that captures the document's overall topic, key arguments, and conclusions. Do not simply list the sections — weave them into a natural narrative. +Write in the same language as the section summaries.""" + + +COMPILE_DOC_SUMMARY = """Summarize the following document concisely, capturing the key topics, arguments, conclusions, and important details. + +### File: {file_name} + +### Document Content (may be truncated) +{document_content} + +### Output +Provide a comprehensive summary in 3-8 sentences. Focus on: +1. What is this document about (main topic/purpose) +2. 
Key findings, arguments, or conclusions +3. Important details, data points, or methodologies + +Write the summary in the same language as the document content.""" + + +COMPILE_TOPIC_EXTRACTION = """Extract the 3-5 most important topics, concepts, or entities from the following summary. + +### Summary +{summary} + +### Output +Return ONLY a JSON array of topic strings, no extra text: +["topic1", "topic2", "topic3"] + +Rules: +- Each topic should be 1-4 words +- Prefer specific, domain-relevant terms over generic ones +- Use the same language as the summary""" + + +COMPILE_MERGE_KNOWLEDGE = """You are merging new information into an existing knowledge cluster. + +### Existing Knowledge +{existing_content} + +### New Information +{new_summary} + +### Task +Produce an updated, unified summary that: +1. Preserves all important information from the existing knowledge +2. Integrates the new information, avoiding redundancy +3. Highlights any contradictions or complementary perspectives +4. Maintains a coherent, well-structured narrative + +### Output +Return ONLY the merged summary text (no extra tags or metadata). Keep the same language as the inputs.""" diff --git a/src/sirchmunk/schema/knowledge.py b/src/sirchmunk/schema/knowledge.py index 336963d..2a6e149 100644 --- a/src/sirchmunk/schema/knowledge.py +++ b/src/sirchmunk/schema/knowledge.py @@ -57,6 +57,12 @@ class EvidenceUnit: # IDs of conflict group if this evidence contradicts others conflict_group: Optional[List[str]] = None + # Tree-index node path from root to the matched node (e.g. ["N000000", "N001234"]) + tree_path: Optional[List[str]] = None + + # Character range within the document for precise evidence location + page_range: Optional[List[int]] = None + def to_dict(self) -> Dict[str, Any]: """ Serialize EvidenceUnit to a dictionary. 
@@ -69,6 +75,8 @@ def to_dict(self) -> Dict[str, Any]: "snippets": self.snippets, "extracted_at": self.extracted_at.isoformat(), "conflict_group": self.conflict_group, + "tree_path": self.tree_path, + "page_range": self.page_range, } @@ -234,6 +242,9 @@ class KnowledgeCluster: # Used for semantic similarity matching and cluster reuse queries: List[str] = None + # Number of times this cluster has been merged with new evidence during compile + merge_count: int = 0 + def __post_init__(self): if self.related_clusters is None: self.related_clusters = [] @@ -391,5 +402,6 @@ def to_dict(self) -> Dict[str, Any]: "related_clusters": [rc.to_dict() for rc in self.related_clusters], "search_results": self.search_results, "queries": self.queries, + "merge_count": self.merge_count, } diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index aee1c16..52f9650 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -919,6 +919,119 @@ def _ensure_tool_registry( self._tool_registry_key = cache_key return registry + # ------------------------------------------------------------------ + # Knowledge compile entry point + # ------------------------------------------------------------------ + + async def compile( + self, + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = 3, + ) -> Dict[str, Any]: + """Compile document collections into structured knowledge indices. + + Optional offline pre-processing step that builds tree indices and + knowledge clusters. Products are automatically leveraged by + subsequent search() calls. + + Args: + paths: Directories or files to compile. Falls back to self.paths. + incremental: Skip unchanged files (default True). + shallow: Skip tree building — use direct LLM summarisation only. + max_files: Cap on files — triggers importance sampling for large sets. + concurrency: Max parallel file compilations. 
async def compile_status(
    self,
    paths: Optional[Union[str, Path, List[str], List[Path]]] = None,
) -> Dict[str, Any]:
    """Return current compile status for the given paths.

    Args:
        paths: Directories or files to inspect; passed through
            ``self._resolve_paths`` before use.

    Returns:
        A dict with the keys ``total_compiled_files``, ``total_clusters``,
        ``total_trees``, ``last_compile_at`` and ``manifest_path``, copied
        from the status object returned by ``KnowledgeCompiler.get_status``.
    """
    from sirchmunk.learnings.compiler import KnowledgeCompiler
    from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer

    targets = self._resolve_paths(paths)

    indexer = DocumentTreeIndexer(
        llm=self.llm,
        cache_dir=self.work_path / ".cache" / "compile" / "trees",
    )
    status = await KnowledgeCompiler(
        llm=self.llm,
        embedding_client=self.embedding_client,
        knowledge_storage=self.knowledge_storage,
        tree_indexer=indexer,
        work_path=self.work_path,
    ).get_status(targets)

    reported_fields = (
        "total_compiled_files",
        "total_clusters",
        "total_trees",
        "last_compile_at",
        "manifest_path",
    )
    return {name: getattr(status, name) for name in reported_fields}
+ """Run knowledge health checks and optionally auto-fix issues.""" + from sirchmunk.learnings.lint import KnowledgeLint + + linter = KnowledgeLint( + knowledge_storage=self.knowledge_storage, + work_path=self.work_path, + log_callback=getattr(self._logger, 'log_callback', None), + ) + + report = await linter.run(auto_fix=auto_fix) + return report.to_dict() + # ------------------------------------------------------------------ # Unified search entry point # ------------------------------------------------------------------ diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index 0a99168..e62c1cf 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -107,6 +107,10 @@ def _load_from_parquet(self): variable-length ``FLOAT[]`` from Parquet's list encoding, breaking ``list_cosine_similarity`` which requires matching fixed-size types. + Handles schema evolution gracefully: if the parquet file has fewer + columns than the current schema (e.g., missing ``merge_count``), + missing columns are filled with defaults instead of failing. + Also records the file's modification time so that ``_check_and_reload()`` can detect external changes later. 
""" @@ -117,11 +121,38 @@ def _load_from_parquet(self): self.db.drop_table(self.table_name, if_exists=True) # Create table with explicit schema (preserves FLOAT[384]) self._create_table() - # Insert data from parquet — DuckDB casts to the declared types - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT * FROM read_parquet('{self.parquet_file}')" - ) + # Detect parquet columns to handle schema evolution + try: + pq_cols = self.db.fetch_all( + f"SELECT column_name FROM parquet_schema('{self.parquet_file}')" + ) + pq_col_names = {row[0] for row in pq_cols} + except Exception: + pq_col_names = None + + if pq_col_names is not None: + # Build column-by-column SELECT with defaults for missing cols + schema_cols = list(self._get_schema_columns()) + select_parts = [] + for col_name in schema_cols: + if col_name in pq_col_names: + select_parts.append(col_name) + elif col_name == "merge_count": + select_parts.append("0 AS merge_count") + else: + select_parts.append(f"NULL AS {col_name}") + select_clause = ", ".join(select_parts) + self.db.execute( + f"INSERT INTO {self.table_name} " + f"SELECT {select_clause} FROM read_parquet('{self.parquet_file}')" + ) + else: + # Fallback: try direct SELECT * (works when schemas match) + self.db.execute( + f"INSERT INTO {self.table_name} " + f"SELECT * FROM read_parquet('{self.parquet_file}')" + ) + count = self.db.get_table_count(self.table_name) # Record mtime for stale-detection self._parquet_loaded_mtime = pq.stat().st_mtime @@ -138,6 +169,18 @@ def _load_from_parquet(self): self._create_table() self._parquet_loaded_mtime = 0.0 + def _get_schema_columns(self) -> List[str]: + """Return the ordered list of column names in the canonical schema.""" + return [ + "id", "name", "description", "content", "scripts", "resources", + "evidences", "patterns", "constraints", "confidence", + "abstraction_level", "landmark_potential", "hotness", "lifecycle", + "create_time", "last_modified", "version", "related_clusters", + 
"search_results", "queries", "merge_count", + "embedding_vector", "embedding_model", "embedding_timestamp", + "embedding_text_hash", + ] + def _check_and_reload(self): """Check if the parquet file was modified externally and reload if so. @@ -190,6 +233,7 @@ def _create_table(self): "related_clusters": "VARCHAR", # JSON array "search_results": "VARCHAR", # JSON array "queries": "VARCHAR", # JSON array of historical queries + "merge_count": "INTEGER", # compile merge counter "embedding_vector": "FLOAT[384]", # 384-dim embedding vector "embedding_model": "VARCHAR", # Model identifier "embedding_timestamp": "TIMESTAMP", # Embedding computation time @@ -338,21 +382,22 @@ def _cluster_to_row(self, cluster: KnowledgeCluster) -> Dict[str, Any]: "related_clusters": json.dumps([rc.to_dict() for rc in cluster.related_clusters]), "search_results": json.dumps(cluster.search_results) if cluster.search_results else None, "queries": json.dumps(cluster.queries) if cluster.queries else None, + "merge_count": cluster.merge_count or 0, } def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: """ Convert database row to KnowledgeCluster. - Expected row structure (24 columns): + Expected row structure (25 columns): id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, - embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash + merge_count, embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash """ - if len(row) != 24: + if len(row) != 25: raise ValueError( - f"Expected 24 columns in knowledge_clusters row, got {len(row)}. " + f"Expected 25 columns in knowledge_clusters row, got {len(row)}. " f"Please ensure the table schema is up to date." 
) @@ -361,6 +406,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, + merge_count, _embedding_vector, _embedding_model, _embedding_timestamp, _embedding_text_hash ) = row @@ -400,7 +446,9 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: is_found=ev_dict["is_found"], snippets=ev_dict["snippets"], extracted_at=extracted_at_parsed or datetime.now(), - conflict_group=ev_dict.get("conflict_group") + conflict_group=ev_dict.get("conflict_group"), + tree_path=ev_dict.get("tree_path"), + page_range=ev_dict.get("page_range"), )) # Parse constraints @@ -463,6 +511,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: related_clusters=related_clusters_parsed, search_results=search_results_parsed, queries=queries_parsed, + merge_count=merge_count or 0, ) # ------------------------------------------------------------------ # From c4f4b166df34f5e4258eefce9fae687ab1dae82d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 20:07:35 +0800 Subject: [PATCH 03/70] improve compile infer --- src/sirchmunk/learnings/README.md | 30 ++ src/sirchmunk/learnings/knowledge_base.py | 4 + src/sirchmunk/search.py | 556 ++++++++++++++++++++-- 3 files changed, 563 insertions(+), 27 deletions(-) diff --git a/src/sirchmunk/learnings/README.md b/src/sirchmunk/learnings/README.md index 0fc1bbe..92bc22b 100644 --- a/src/sirchmunk/learnings/README.md +++ b/src/sirchmunk/learnings/README.md @@ -37,6 +37,36 @@ The module fuses insights from three frameworks: Compile products are automatically leveraged by search when present, but search functions independently without them. 
+### How Search Consumes Compile Products + +``` +Compile products Search consumption path +───────────────── ────────────────────────────────────────────── +KnowledgeCluster ─┬─ FAST + DEEP Phase 0: embedding similarity + .content │ reuse (instant short-circuit, no LLM cost) + .embedding │ → enriched with evidence snippets + .evidences[].file_or_url │ + ├─ DEEP Phase 1: _probe_knowledge_cache() + │ fuzzy text search → file path discovery + │ +WeakSemanticEdge ├─ DEEP Phase 1: one-hop graph expansion + .related_clusters │ follows edges to gather neighbour files + │ +DocumentTree (.json) └─ DEEP Phase 3: tree-navigated evidence + via tree_indexer _build_cluster() → knowledge_base.build() + → _extract_evidence_for_file(tree_indexer) + → narrows doc to relevant sections before + Monte Carlo sampling +``` + +| Compile product | FAST | DEEP | +|-----------------|------|------| +| Cluster embedding reuse | Yes | Yes | +| Evidence snippets in reused content | Yes | Yes | +| Fuzzy cluster → file path hints | — | Yes | +| Graph edge expansion (neighbours) | — | Yes | +| Tree-navigated evidence extraction | — | Yes | + --- ## Components diff --git a/src/sirchmunk/learnings/knowledge_base.py b/src/sirchmunk/learnings/knowledge_base.py index bd2946c..7296f71 100644 --- a/src/sirchmunk/learnings/knowledge_base.py +++ b/src/sirchmunk/learnings/knowledge_base.py @@ -208,6 +208,7 @@ async def build( top_k_snippets: Optional[int] = 5, confidence_threshold: Optional[float] = 8.0, verbose: bool = True, + tree_indexer=None, ) -> Union[KnowledgeCluster, None]: """Build a knowledge cluster from retrieved information and metadata. @@ -223,6 +224,8 @@ async def build( top_k_snippets: Max evidence snippets per file. confidence_threshold: Min confidence for evidence acceptance. verbose: Enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-navigated + evidence extraction (uses compiled tree indices when available). 
@dataclass
class KnowledgeProbeResult:
    """Rich result from knowledge cache probing (P3).

    Replaces the flat ``List[str]`` that ``_probe_knowledge_cache`` used to return.

    Every field has an empty default, so ``KnowledgeProbeResult()`` is the
    "nothing found" value; positional and keyword construction with all three
    fields keeps working unchanged.
    """

    # File paths gathered from the evidences of matched clusters.
    file_paths: List[str] = field(default_factory=list)
    # Additional search keywords gathered from matched clusters.
    extra_keywords: List[str] = field(default_factory=list)
    # Free-text context appended to the prompt context by the caller.
    background_context: str = ""
+ + Called when ``_try_reuse_cluster`` misses (similarity < hard threshold). + Uses a softer threshold to find clusters that are *related* but not + close enough for full reuse. Returns patterns, file paths, and a + background context summary that downstream phases can exploit. + """ + if not self.embedding_client or not self.embedding_client.is_ready(): + return None + + try: + query_embedding = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=query_embedding, + top_k=5, + similarity_threshold=_SOFT_SIM_THRESHOLD, + search_paths=paths, + ) + if not similar: + return None + + patterns: List[str] = [] + file_paths: List[str] = [] + context_parts: List[str] = [] + cluster_ids: List[str] = [] + seen_paths: set = set() + + for match in similar: + cid = match["id"] + cluster_ids.append(cid) + c = await self.knowledge_storage.get(cid) + if not c: + continue + for p in getattr(c, "patterns", []) or []: + if p and p not in patterns: + patterns.append(p) + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + if not patterns and not file_paths: + return None + + await self._logger.info( + f"[SoftReuse] {len(similar)} soft hits: " + f"{len(patterns)} patterns, {len(file_paths)} files" + ) + return SoftClusterHit( + patterns=patterns[:10], + file_paths=file_paths[:10], + context_summary="\n\n".join(context_parts[:3]), + cluster_ids=cluster_ids, + ) + except Exception: + return None + def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: """ Add query to cluster's queries list with FIFO strategy. 
@@ -478,6 +582,36 @@ def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: # Remove oldest queries (from the beginning) cluster.queries = cluster.queries[-self.max_queries_per_cluster:] + @staticmethod + def _enrich_reused_content(cluster: KnowledgeCluster) -> str: + """Build the answer text from a reused cluster. + + When the cluster carries compiled evidence with non-empty snippets + (populated during ``sirchmunk compile``), appends them as supporting + excerpts so the user sees both the summary and the underlying source + material. + """ + content = cluster.content + if isinstance(content, list): + content = "\n".join(content) + content = str(content or "") + + evidence_parts: List[str] = [] + for ev in getattr(cluster, "evidences", []): + snippets = getattr(ev, "snippets", None) + if not snippets: + continue + source = str(getattr(ev, "file_or_url", "unknown")) + for snip in snippets: + text = snip if isinstance(snip, str) else snip.get("snippet", "") + if text and text.strip(): + evidence_parts.append(f"[{Path(source).name}] {text.strip()}") + + if evidence_parts: + content += "\n\n---\nSupporting evidence:\n" + "\n\n".join(evidence_parts[:5]) + + return content + async def _save_cluster_with_embedding(self, cluster: KnowledgeCluster) -> None: """Save knowledge cluster to persistent storage, compute embedding, and flush to parquet. 
@@ -1256,17 +1390,17 @@ async def _search_deep( # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) - return str(content), reused, context + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — extract hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") # ============================================================== - # Phase 1: Parallel probing — all four paths fire concurrently + # Phase 1: Parallel probing — five paths fire concurrently # ============================================================== - await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache") + await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache + tree_index") context.increment_loop() phase1_results = await asyncio.gather( @@ -1274,24 +1408,53 @@ async def _search_deep( self._probe_dir_scan(paths, enable_dir_scan), self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), + self._probe_tree_index(query), return_exceptions=True, ) kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None - knowledge_hits = phase1_results[2] if not isinstance(phase1_results[2], Exception) else [] + knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" + tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] - for i, label in enumerate(["keywords", 
"dir_scan", "knowledge", "spec_cache"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") + # Backwards compat: knowledge_probe may be a plain list from old code paths + if isinstance(knowledge_probe, list): + knowledge_probe = KnowledgeProbeResult(file_paths=knowledge_probe, extra_keywords=[], background_context="") + query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) + # P2: inject soft-hit patterns into keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in initial_keywords: + initial_keywords.append(p) + if p not in query_keywords: + query_keywords[p] = 0.6 + + # P3: inject extra keywords from structured knowledge probe + for kw in knowledge_probe.extra_keywords: + if kw not in initial_keywords: + initial_keywords.append(kw) + if kw not in query_keywords: + query_keywords[kw] = 0.5 + + # P2 + P3: append background context for Phase 4 LLM prompt + if soft_hit and soft_hit.context_summary: + spec_context = f"{spec_context}\n\n{soft_hit.context_summary}" if spec_context else soft_hit.context_summary + if knowledge_probe.background_context: + spec_context = f"{spec_context}\n\n{knowledge_probe.background_context}" if spec_context else knowledge_probe.background_context + await self._logger.info( f"[Phase 1] Results: keywords={len(initial_keywords)}, " f"dir_scan={'OK' if scan_result else 'N/A'}, " - f"knowledge_hits={len(knowledge_hits)}, " + f"knowledge_files={len(knowledge_probe.file_paths)}, " + f"tree_hits={len(tree_hits)}, " + f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1336,12 +1499,16 @@ async def _search_deep( # ============================================================== # Phase 3: Merge file paths + build KnowledgeCluster + # P1 tree hits get highest priority; P2 soft-hit 
files next # ============================================================== context.increment_loop() + extra_knowledge_files = knowledge_probe.file_paths + if soft_hit: + extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=keyword_files, + keyword_files=list(tree_hits) + keyword_files, dir_scan_files=dir_scan_files, - knowledge_hits=knowledge_hits, + knowledge_hits=extra_knowledge_files, ) await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") @@ -1352,6 +1519,19 @@ async def _search_deep( query_keywords=query_keywords, top_k_files=top_k_files, ) + # ============================================================== + # Phase 3.5: Graph context enrichment (P5) + # Append related knowledge from graph neighbours to cluster content + # so the answer-generation LLM has richer context. + # ============================================================== + graph_ctx = "" + if cluster: + graph_ctx = await self._gather_graph_context(cluster) + if graph_ctx and cluster.content: + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{graph_ctx}" + # ============================================================== # Phase 4: Generate answer — cluster summary or ReAct refinement # ============================================================== @@ -1383,9 +1563,11 @@ async def _search_deep( answer, should_save = await self._summarise_cluster_fallback(query) else: await self._logger.info("[Phase 4] Evidence insufficient, launching ReAct refinement") + # P5: enrich ReAct context with graph knowledge + react_spec = f"{spec_context}\n\n{graph_ctx}" if graph_ctx else spec_context react_answer, context = await self._react_refinement( query=query, paths=paths, - initial_keywords=initial_keywords, spec_context=spec_context, + initial_keywords=initial_keywords, spec_context=react_spec, enable_dir_scan=enable_dir_scan, 
max_loops=max_loops, max_token_budget=max_token_budget, max_depth=max_depth, include=include, exclude=exclude, @@ -1751,11 +1933,11 @@ async def _search_fast( # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) await self._logger.success("[FAST] Reused cached knowledge cluster") - return str(content), reused, context + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — structured hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) # ============================================================== # Step 1: LLM query analysis only (dir scan deferred until needed) @@ -1833,6 +2015,38 @@ async def _search_fast( msg = f"Could not extract search terms from query: '{query}'" return msg, None, context + # ============================================================== + # Step 1.5: Compile-aware enrichment (P2 + P4, zero LLM calls) + # ============================================================== + all_kw_set = set(primary + fallback) + + # P2: inject soft-hit patterns as fallback keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in all_kw_set: + fallback.append(p) + all_kw_set.add(p) + keyword_idfs.setdefault(p, 0.6) + + # P4: compile hints from manifest + tree cache + compile_hints = await self._probe_compile_hints(primary + fallback) + for kw in compile_hints.extra_keywords: + if kw not in all_kw_set: + fallback.append(kw) + all_kw_set.add(kw) + keyword_idfs.setdefault(kw, 0.5) + + compile_hint_files: List[str] = [] + if soft_hit: + compile_hint_files.extend(soft_hit.file_paths) + compile_hint_files.extend(compile_hints.file_paths) + + if compile_hint_files: + await self._logger.info( + f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files, " + f"{len(compile_hints.extra_keywords)} extra 
keywords" + ) + await self._logger.info( f"[FAST:Step1] Primary: {primary}, Fallback: {fallback}" ) @@ -1870,6 +2084,17 @@ async def _search_fast( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) + # --- Fallback: compile-hint files when rga misses (P2+P4) --- + if not best_files and compile_hint_files: + used_level = "compile_hint" + await self._logger.info( + f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- if not best_files and enable_dir_scan: scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) @@ -2592,32 +2817,251 @@ async def _probe_dir_scan( async def _probe_knowledge_cache( self, query: str, - ) -> List[str]: - """Search knowledge cache for related clusters, return known file paths. + ) -> KnowledgeProbeResult: + """Structured knowledge probe: embedding search with graph expansion. - Returns: - List of file paths from previously cached clusters. + Uses embedding similarity (threshold 0.50) when available, falling back + to SQL LIKE. Extracts file paths, topic keywords, and background + context from matched clusters and their graph neighbours. 
""" + empty = KnowledgeProbeResult([], [], "") try: - clusters = await self.knowledge_storage.find(query, limit=3) + clusters: List[KnowledgeCluster] = [] + + # Prefer embedding search for semantic quality + if self.embedding_client and self.embedding_client.is_ready(): + try: + qe = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=qe, top_k=5, similarity_threshold=0.50, + ) + for m in (similar or []): + c = await self.knowledge_storage.get(m["id"]) + if c: + clusters.append(c) + except Exception: + pass + + # Fallback to SQL LIKE when embedding unavailable or empty if not clusters: - return [] + clusters = await self.knowledge_storage.find(query, limit=3) + if not clusters: + return empty + + seen_paths: set = set() file_paths: List[str] = [] - for c in clusters: + extra_keywords: List[str] = [] + context_parts: List[str] = [] + seen_kw: set = set() + + def _collect_cluster(c: KnowledgeCluster) -> None: for ev in getattr(c, "evidences", []): fp = str(getattr(ev, "file_or_url", "")) - if fp and Path(fp).exists(): + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) file_paths.append(fp) + for p in getattr(c, "patterns", []) or []: + if p and p.lower() not in seen_kw: + seen_kw.add(p.lower()) + extra_keywords.append(p) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + for c in clusters: + _collect_cluster(c) + + # One-hop graph expansion via WeakSemanticEdge + neighbour_ids: set = set() + for c in clusters: + for edge in getattr(c, "related_clusters", []): + tid = getattr(edge, "target_cluster_id", None) + if tid and tid not in neighbour_ids: + neighbour_ids.add(tid) + + for nid in list(neighbour_ids)[:6]: + try: + neighbour = await self.knowledge_storage.get(nid) + if neighbour: + _collect_cluster(neighbour) + except Exception: + pass if file_paths: await 
self._logger.info( - f"[Probe:Knowledge] Found {len(file_paths)} files from cached clusters" + f"[Probe:Knowledge] {len(file_paths)} files, " + f"{len(extra_keywords)} keywords from " + f"{len(clusters)} clusters + {len(neighbour_ids)} neighbours" + ) + + return KnowledgeProbeResult( + file_paths=file_paths, + extra_keywords=extra_keywords[:15], + background_context="\n\n".join(context_parts[:3]), + ) + except Exception: + return empty + + async def _probe_tree_index(self, query: str) -> List[str]: + """LLM-driven file discovery via compiled tree root summaries (PageIndex). + + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant 1-3 documents. For + selected trees, optionally drills one level deeper into children. + + Returns file paths of the most relevant documents. + """ + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return [] + + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + + trees: List[DocumentTree] = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:50]: + try: + t = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + if t.root and t.file_path: + trees.append(t) + except Exception: + continue + + if not trees: + return [] + + # If few trees, return all without LLM + if len(trees) <= 2: + return [t.file_path for t in trees if Path(t.file_path).exists()] + + # LLM-driven selection among tree roots + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-3 most relevant documents (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) + + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass + + if not selected_indices: + selected_indices = list(range(min(2, len(trees)))) + + result_paths: List[str] = [] + for idx in selected_indices: + fp = trees[idx].file_path + if Path(fp).exists(): + result_paths.append(fp) + + if result_paths: + await self._logger.info( + f"[Probe:TreeIndex] LLM selected {len(result_paths)} documents " + f"from {len(trees)} tree indices" ) - return file_paths + return result_paths except Exception: return [] + async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: + """Zero-LLM enrichment from compile manifest and tree cache. + + Scans the compile manifest for clusters whose patterns overlap with + the query keywords, and scans cached tree root summaries for keyword + matches. No LLM calls — only local JSON reads and in-memory DB lookups. 
+ """ + empty = CompileHints([], []) + if not keywords: + return empty + + kw_lower = {k.lower() for k in keywords} + file_paths: List[str] = [] + extra_keywords: List[str] = [] + seen_paths: set = set() + seen_kw: set = set(kw_lower) + + # --- Cluster pattern matching via manifest --- + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + cluster_ids: set = set() + for entry in manifest.files.values(): + cluster_ids.update(entry.cluster_ids) + + for cid in list(cluster_ids)[:50]: + try: + c = await self.knowledge_storage.get(cid) + except Exception: + continue + if not c: + continue + cluster_patterns = [ + p.lower() for p in (getattr(c, "patterns", []) or []) if p + ] + if kw_lower & set(cluster_patterns): + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + for p in cluster_patterns: + if p not in seen_kw: + seen_kw.add(p) + extra_keywords.append(p) + except Exception: + pass + + # --- Tree root summary scanning (keyword substring match) --- + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tree_file in sorted(tree_cache.glob("*.json"))[:100]: + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + except Exception: + continue + if not tree.root or not tree.file_path: + continue + summary_lower = (tree.root.summary or "").lower() + if any(kw in summary_lower for kw in kw_lower): + fp = tree.file_path + if fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + except Exception: + pass + + return CompileHints( + file_paths=file_paths[:15], + 
extra_keywords=extra_keywords[:10], + ) + @staticmethod async def _async_noop(default=None): """No-op coroutine used as placeholder in gather().""" @@ -2744,6 +3188,20 @@ def _merge_file_paths( return merged + def _get_tree_indexer(self): + """Lazily construct a DocumentTreeIndexer for search-time tree navigation.""" + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return None + _cb = getattr(self._logger, 'log_callback', None) + return DocumentTreeIndexer( + llm=self.llm, + cache_dir=tree_cache, + log_callback=_cb, + ) + async def _build_cluster( self, query: str, @@ -2755,7 +3213,9 @@ async def _build_cluster( """Build a KnowledgeCluster via knowledge_base.build(). Constructs the Request wrapper and delegates to the knowledge - base for parallel Monte Carlo evidence sampling. + base for parallel Monte Carlo evidence sampling. When compiled + tree indices exist, passes a ``tree_indexer`` so that evidence + extraction can navigate to relevant sections before sampling. """ try: request = Request( @@ -2775,6 +3235,7 @@ async def _build_cluster( top_k_files=top_k_files, top_k_snippets=top_k_snippets, verbose=self.verbose, + tree_indexer=self._get_tree_indexer(), ) self.llm_usages.extend(self.knowledge_base.llm_usages) self.knowledge_base.llm_usages.clear() @@ -2789,6 +3250,47 @@ async def _build_cluster( await self._logger.warning(f"[Phase 3] knowledge_base.build() failed: {exc}") return None + async def _gather_graph_context(self, cluster: KnowledgeCluster) -> str: + """Enrich answer context with knowledge from graph neighbours. + + Traverses the cluster's ``related_clusters`` edges (sorted by weight), + fetches the top neighbours, and returns a joined summary string that + can be appended to the cluster content before answer generation. 
+ """ + edges = sorted( + getattr(cluster, "related_clusters", []) or [], + key=lambda e: getattr(e, "weight", 0), + reverse=True, + ) + if not edges: + return "" + + parts: List[str] = [] + for edge in edges[:3]: + tid = getattr(edge, "target_cluster_id", None) + if not tid: + continue + try: + neighbour = await self.knowledge_storage.get(tid) + except Exception: + continue + if not neighbour: + continue + content = neighbour.content + if isinstance(content, list): + content = "\n".join(content) + name = getattr(neighbour, "name", "") or "" + snippet = str(content or "")[:300] + if snippet: + parts.append(f"- {name}: {snippet}") + + if not parts: + return "" + await self._logger.info( + f"[Phase 3.5] Graph context: {len(parts)} neighbour summaries" + ) + return "Related knowledge:\n" + "\n".join(parts) + # ------------------------------------------------------------------ # Phase 4: Answer generation # ------------------------------------------------------------------ From 645847766859fb846c7b7ca899d529a07b1e904c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 21:00:33 +0800 Subject: [PATCH 04/70] improve search pipeline for compile mode --- src/sirchmunk/learnings/compiler.py | 59 ++++- src/sirchmunk/llm/prompts.py | 25 +++ src/sirchmunk/search.py | 336 +++++++++++++++++++++++----- 3 files changed, 366 insertions(+), 54 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 3c2b0da..4ccd5da 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -380,11 +380,14 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: await self._log.info("[Compile] Phase 4: Building cross-references") report.cross_refs_built = await self._build_cross_references(results) - # Phase 5: persist manifest + # Phase 5: persist manifest + document catalog manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) 
self._storage.force_sync() + # Generate document catalog for search-time routing + self._build_document_catalog(manifest) + report.elapsed_seconds = time.monotonic() - t0 await self._log.info( f"[Compile] Done in {report.elapsed_seconds:.1f}s — " @@ -838,3 +841,57 @@ def _load_manifest(self) -> CompileManifest: def _save_manifest(self, manifest: CompileManifest) -> None: self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") + + # ------------------------------------------------------------------ # + # Document catalog for search-time routing # + # ------------------------------------------------------------------ # + + def _build_document_catalog(self, manifest: CompileManifest) -> None: + """Generate a lightweight catalog mapping files to their tree root summaries. + + The catalog is consumed by FAST search to fuse query analysis with + LLM-driven document routing in a single prompt. Each entry carries + the filename and a truncated root summary (≤250 chars). + """ + tree_cache = self._compile_dir / "trees" + entries: List[Dict[str, str]] = [] + + for file_path, entry in manifest.files.items(): + summary = "" + if entry.has_tree and tree_cache.exists(): + tree_file = tree_cache / f"{entry.file_hash}.json" + if tree_file.exists(): + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8"), + ) + if tree.root and tree.root.summary: + summary = tree.root.summary[:250] + except Exception: + pass + + if not summary: + # Fallback: use first cluster's description + for cid in entry.cluster_ids[:1]: + try: + import asyncio + loop = asyncio.get_event_loop() + if loop.is_running(): + break + c = loop.run_until_complete(self._storage.get(cid)) + if c and c.description: + summary = str(c.description[0])[:250] + except Exception: + break + + entries.append({ + "path": file_path, + "name": Path(file_path).name, + "summary": summary, + }) + + catalog_path = self._compile_dir / "document_catalog.json" + catalog_path.write_text( + 
json.dumps(entries, ensure_ascii=False, indent=2), + encoding="utf-8", + ) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index b3ded32..7b07b1c 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -389,6 +389,31 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +FAST_QUERY_ANALYSIS_WITH_CATALOG = """Classify the user query, extract search terms, AND select the most relevant document(s) from the compiled index. + +### User Query +{user_input} + +### Compiled Document Index +{document_listing} + +### Output +Return JSON only, no extra text: +{{"type": "search", "primary": ["compound phrase"], "fallback": ["term1", "term2"], "idf": {{"compound phrase": 8.0, "term1": 2.5}}, "primary_alt": [], "fallback_alt": [], "file_hints": [], "intent": "...", "selected_docs": [0, 2], "doc_confidence": "high"}} + +Rules: +- **type**: "search" if the query requires retrieving information from files or documents; "chat" if it is a greeting, small talk, or conversational message — set primary/fallback to empty arrays, put a brief reply in "response". "summary" if the user wants to summarize entire documents. +- **primary**: 1 compound phrase (2-3 words) most likely to appear **verbatim** in the target document. +- **fallback**: 1-3 single-word atomic terms. Tried only if primary misses. +- **primary_alt / fallback_alt**: Cross-lingual equivalents (Chinese↔English). Only the most critical 1-2 terms. +- **file_hints**: filename fragments or glob patterns ONLY if clearly implied; empty array otherwise. +- **intent**: one sentence describing the query intent. +- **idf**: IDF weight (1.0-10.0) for EVERY keyword. Higher for rare terms. +- **selected_docs**: Index numbers (from the Compiled Document Index above) of the 1-3 most relevant documents for this query. Consider BOTH the filename and the summary. Choose documents whose content is most likely to answer the query. 
+- **doc_confidence**: "high" if you are very confident the selected documents contain the answer; "medium" if likely but uncertain; "low" if guessing. +""" + + ROI_RESULT_SUMMARY = """ ### Task Analyze the provided {text_content} and generate a concise summary in the form of a Markdown Briefing. diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 63d3ba8..9d192da 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -20,6 +20,7 @@ KEYWORD_QUERY_PLACEHOLDER, generate_keyword_extraction_prompt, FAST_QUERY_ANALYSIS, + FAST_QUERY_ANALYSIS_WITH_CATALOG, ROI_RESULT_SUMMARY, SEARCH_RESULT_SUMMARY, DOC_SUMMARY, @@ -1940,9 +1941,26 @@ async def _search_fast( soft_hit = await self._try_soft_reuse(query, paths) # ============================================================== - # Step 1: LLM query analysis only (dir scan deferred until needed) + # Step 1: Fused LLM query analysis + document routing + # When a compiled document catalog exists, the LLM sees all + # document summaries and selects the most relevant ones in the + # same call that extracts keywords (zero extra LLM cost). 
# ============================================================== - prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + catalog = self._load_document_catalog() + catalog_routed_files: List[str] = [] + catalog_confidence: str = "low" + + if catalog: + listing = "\n".join( + f"[{i}] {e['name']}: {e['summary'][:200]}" + for i, e in enumerate(catalog) + ) + prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( + user_input=query, document_listing=listing, + ) + else: + prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + resp = await self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, @@ -1957,6 +1975,21 @@ async def _search_fast( query_type = analysis.get("type", "search") file_hints = analysis.get("file_hints", []) + # Extract catalog-routed files from the fused response + if catalog: + selected_indices = analysis.get("selected_docs", []) + catalog_confidence = analysis.get("doc_confidence", "low") + for idx in selected_indices: + if isinstance(idx, int) and 0 <= idx < len(catalog): + fp = catalog[idx]["path"] + if Path(fp).exists(): + catalog_routed_files.append(fp) + if catalog_routed_files: + await self._logger.info( + f"[FAST:Step1] Catalog routing ({catalog_confidence}): " + f"{[Path(p).name for p in catalog_routed_files]}" + ) + if query_type == "chat": chat_reply = analysis.get("response", "") if chat_reply: @@ -2017,6 +2050,7 @@ async def _search_fast( # ============================================================== # Step 1.5: Compile-aware enrichment (P2 + P4, zero LLM calls) + # Catalog-routed files from the fused Step 1 are merged here. 
# ============================================================== all_kw_set = set(primary + fallback) @@ -2037,13 +2071,26 @@ async def _search_fast( keyword_idfs.setdefault(kw, 0.5) compile_hint_files: List[str] = [] + # Catalog-routed files get highest priority + seen_hint_paths: set = set() + for fp in catalog_routed_files: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if soft_hit: - compile_hint_files.extend(soft_hit.file_paths) - compile_hint_files.extend(compile_hints.file_paths) + for fp in soft_hit.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + for fp in compile_hints.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if compile_hint_files: await self._logger.info( - f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files, " + f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files " + f"(catalog={len(catalog_routed_files)}, soft={len(soft_hit.file_paths) if soft_hit else 0}), " f"{len(compile_hints.extra_keywords)} extra keywords" ) @@ -2053,7 +2100,9 @@ async def _search_fast( # ============================================================== # Step 2: rga cascade — primary first, fallback only if needed - # Dir scan runs only when enabled, for fallback when rga misses. + # When catalog routing has high confidence, catalog-routed files + # are used directly (skipping rga) to avoid noise from unrelated + # files. Otherwise rga runs first and catalog acts as fallback. 
# ============================================================== context.add_search(query) include_patterns = list(include or []) @@ -2070,7 +2119,19 @@ async def _search_fast( used_level = "primary" evidence = "" - if primary: + # High-confidence catalog routing: skip rga, use catalog directly + if catalog_routed_files and catalog_confidence == "high": + used_level = "catalog_route" + await self._logger.info( + f"[FAST:Step2] High-confidence catalog routing → " + f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] + + if not best_files and primary: best_files = await self._fast_find_best_file( primary, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) @@ -2084,7 +2145,7 @@ async def _search_fast( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) - # --- Fallback: compile-hint files when rga misses (P2+P4) --- + # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- if not best_files and compile_hint_files: used_level = "compile_hint" await self._logger.info( @@ -2128,52 +2189,59 @@ async def _search_fast( ) # ============================================================== - # Step 3: Context sampling around grep hits (no LLM) - # Multi-file evidence aggregation + # Step 2.5 + Step 3: Tree navigation (1 LLM call) runs in + # parallel with rga evidence sampling (0 LLM). The merged + # result is higher quality than either alone. 
# ============================================================== - evidence_parts = [] - total_evidence_chars = 0 - for bf in best_files: - if total_evidence_chars >= self._FAST_MAX_EVIDENCE_CHARS: - break - - file_path = bf["path"] - fname = Path(file_path).name - ext = Path(file_path).suffix.lower() - - # Small file short-circuit: read full content instead of grep sampling - ev = None - if ext in self._FAST_TEXT_EXTENSIONS: - try: - file_size = Path(file_path).stat().st_size - if file_size < self._FAST_SMALL_FILE_THRESHOLD: - full_text = Path(file_path).read_text(errors="replace") - if len(full_text) < self._FAST_SMALL_FILE_THRESHOLD: - ev = f"[{fname}]\n{full_text}" - await self._logger.info( - f"[FAST] Small file short-circuit: reading full content of {fname} " - f"({len(full_text)} chars)" - ) - except Exception: - pass # Fall through to normal evidence extraction - - # Normal path: grep-based evidence sampling - if ev is None: - ev = await self._fast_sample_evidence(file_path, bf.get("matches", [])) - if ev: - remaining = self._FAST_MAX_EVIDENCE_CHARS - total_evidence_chars - chunk = ev[:remaining] - evidence_parts.append(chunk) - total_evidence_chars += len(chunk) - context.mark_file_read(file_path) - - evidence = "\n\n---\n\n".join(evidence_parts) + async def _rga_evidence() -> str: + """Collect rga-based evidence from best_files (zero LLM).""" + parts: List[str] = [] + chars = 0 + for bf in best_files: + if chars >= self._FAST_MAX_EVIDENCE_CHARS: + break + fp = bf["path"] + fn = Path(fp).name + ext = Path(fp).suffix.lower() + ev = None + if ext in self._FAST_TEXT_EXTENSIONS: + try: + sz = Path(fp).stat().st_size + if sz < self._FAST_SMALL_FILE_THRESHOLD: + full = Path(fp).read_text(errors="replace") + if len(full) < self._FAST_SMALL_FILE_THRESHOLD: + ev = f"[{fn}]\n{full}" + except Exception: + pass + if ev is None: + ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) + if ev: + remaining = self._FAST_MAX_EVIDENCE_CHARS - chars + 
parts.append(ev[:remaining]) + chars += len(parts[-1]) + context.mark_file_read(fp) + return "\n\n---\n\n".join(parts) + + # Launch tree navigation for the primary file alongside rga + tree_nav_target = best_files[0]["path"] + rga_task = _rga_evidence() + tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + + rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) + + # Merge: tree evidence first (highest quality), then rga + evidence_parts_final: List[str] = [] + if tree_ev: + evidence_parts_final.append(tree_ev) + if rga_ev: + evidence_parts_final.append(rga_ev) + evidence = "\n\n---\n\n".join(evidence_parts_final) if not evidence or len(evidence.strip()) < 20: if llm_fallback: await self._logger.info( - "[FAST:Step3] No usable evidence, llm_fallback=True \u2192 LLM summary" + "[FAST:Step3] No usable evidence, llm_fallback=True → LLM summary" ) evidence = self._LLM_FALLBACK_EVIDENCE else: @@ -2181,7 +2249,8 @@ async def _search_fast( return _NO_RESULTS_MESSAGE, None, context await self._logger.info( - f"[FAST:Step3] Evidence: {len(evidence)} chars from {Path(file_path).name}" + f"[FAST:Step3] Evidence: {len(evidence)} chars " + f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'})" ) keywords_used = primary if used_level == "primary" else fallback @@ -2206,21 +2275,52 @@ async def _search_fast( answer, should_save, should_answer = self._parse_summary_response( answer_resp.content or "" ) + + # ============================================================== + # Step 5: Self-correction retry (conditional, ≤1 extra LLM call) + # When the answer gate rejects the first attempt, try alternative + # evidence sources before giving up. 
+ # ============================================================== + if not should_answer: + retry_evidence = await self._fast_self_correct( + query, best_files, catalog_routed_files, context, + ) + if retry_evidence: + await self._logger.info( + f"[FAST:Step5] Retrying with {len(retry_evidence)} chars of alternative evidence" + ) + retry_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, text_content=retry_evidence, + ) + retry_resp = await self.llm.achat( + messages=[{"role": "user", "content": retry_prompt}], + stream=True, + ) + self.llm_usages.append(retry_resp.usage) + if retry_resp.usage and isinstance(retry_resp.usage, dict): + context.add_llm_tokens( + retry_resp.usage.get("total_tokens", 0), usage=retry_resp.usage, + ) + answer, should_save, should_answer = self._parse_summary_response( + retry_resp.content or "" + ) + if not should_answer: if llm_fallback: await self._logger.info( - "[FAST:Step4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" + "[FAST:Step5] Retry also rejected, llm_fallback=True → LLM fallback" ) answer, should_save = await self._summarise_fast_fallback(query, context) else: await self._logger.warning( - "[FAST:Step4] Summary gate rejected evidence and llm_fallback=False " + "[FAST:Step5] Evidence rejected after retry, llm_fallback=False " "→ returning no results" ) return _NO_RESULTS_MESSAGE, None, context + if not should_save: await self._logger.info("[FAST] Quality gate: low-quality answer, skipping cluster save") - await self._logger.success("[FAST] Search complete (2 LLM calls, no persist)") + await self._logger.success("[FAST] Search complete (no persist)") return answer, None, context cluster = self._build_fast_cluster( @@ -2234,7 +2334,7 @@ async def _search_fast( f"[FAST] Failed to save cluster with embedding: {exc}" ) - await self._logger.success("[FAST] Search complete (2 LLM calls)") + await self._logger.success("[FAST] Search complete") return answer, cluster, context # ---- FAST helpers ---- @@ 
-2634,6 +2734,136 @@ async def _fast_read_file_head( pass return "" + def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: + """Load the compiled document catalog for fused query+route prompt. + + Returns None when compile has not been run or catalog is missing. + """ + catalog_path = self.work_path / ".cache" / "compile" / "document_catalog.json" + if not catalog_path.exists(): + return None + try: + entries = json.loads(catalog_path.read_text(encoding="utf-8")) + if isinstance(entries, list) and entries: + return entries + except Exception: + pass + return None + + async def _navigate_tree_for_evidence( + self, file_path: str, query: str, + ) -> Optional[str]: + """LLM-driven tree navigation: select relevant sections and read leaf content. + + Uses 1 LLM call to drill into the compiled tree index for + *file_path*, returning concatenated leaf content as evidence. + Returns None when no tree cache is available. + """ + indexer = self._get_tree_indexer() + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + + try: + leaves = await indexer.navigate(tree, query, max_results=3) + except Exception: + return None + + if not leaves: + return None + + fname = Path(file_path).name + # Read leaf content from the original document via char_range + parts: List[str] = [] + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in leaves: + start, end = leaf.char_range + if full_text and end > start: + segment = full_text[start:end] + else: + segment = leaf.summary or "" + if segment.strip(): + header = f"[{fname} → {leaf.title}]" + parts.append(f"{header}\n{segment[:3000]}") + + if not parts: + return None + + evidence = "\n\n".join(parts) + await self._logger.info( + f"[FAST:TreeNav] Extracted {len(parts)} sections, " + f"{len(evidence)} 
chars from {fname}" + ) + return evidence + + async def _fast_self_correct( + self, + query: str, + best_files: Optional[List[Dict[str, Any]]], + catalog_routed_files: List[str], + context: SearchContext, + ) -> Optional[str]: + """Attempt to gather alternative evidence when the first answer is rejected. + + Three strategies tried in order: + A) Tree-navigate a 2nd catalog-routed file not yet tried. + B) Retrieve the most semantically similar compiled cluster's content. + C) Tree-navigate the 2nd-best rga file if available. + + Returns alternative evidence string, or None if all strategies fail. + """ + first_file = best_files[0]["path"] if best_files else "" + + # Strategy A: 2nd catalog-routed file via tree navigation + for fp in catalog_routed_files: + if fp == first_file: + continue + tree_ev = await self._navigate_tree_for_evidence(fp, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp) + return tree_ev + + # Strategy B: cluster content from knowledge storage + if self.embedding_client and self.knowledge_storage: + try: + qe = self.embedding_client.encode(query) + if qe is not None: + vec = qe.tolist() if hasattr(qe, "tolist") else list(qe) + hits = await self.knowledge_storage.search_similar_clusters( + query_embedding=vec, top_k=2, similarity_threshold=0.50, + ) + if hits: + parts: List[str] = [] + for h in hits[:2]: + c = await self.knowledge_storage.get(h["id"]) + if c and c.content: + parts.append(str(c.content)[:3000]) + for ev in (c.evidences or [])[:3]: + for s in (ev.snippets or [])[:2]: + parts.append(s[:500]) + if parts: + return "\n\n---\n\n".join(parts) + except Exception: + pass + + # Strategy C: 2nd rga file via tree navigation + if best_files and len(best_files) > 1: + fp2 = best_files[1]["path"] + tree_ev = await self._navigate_tree_for_evidence(fp2, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp2) + return tree_ev + + return None + @staticmethod def _parse_fast_json(text: str) -> 
Dict[str, Any]: """Extract JSON from the FAST query analysis LLM response.""" From 1f6f799fd1c3bcb357ae5cfb5d436b90c1c0f647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 19:21:25 +0800 Subject: [PATCH 05/70] fix and enhance llm wiki and tree index for FAST search --- src/sirchmunk/llm/prompts.py | 39 ++++ src/sirchmunk/search.py | 394 +++++++++++++++++++++++++++++++++-- 2 files changed, 417 insertions(+), 16 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 7b07b1c..8df111d 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -449,6 +449,45 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: true/false """ +ROI_RESULT_SUMMARY_WITH_CONTEXT = """ +### Task +Analyze the provided evidence and generate a concise summary in the form of a Markdown Briefing. +Leverage the document context below for better understanding of the source material's structure and purpose. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. +3. **Style**: Keep it professional, objective, and clear. Avoid fluff. + +### Document Context +{document_context} + +### Input Data +- **User Input**: {user_input} +- **Search Result Text**: {text_content} + +### Quality Evaluation +After generating the summary, make TWO decisions: +1) whether the query can be answered from the provided evidence; +2) whether this result is worth caching. + +Evaluate based on: +1. Does the search result contain substantial, relevant information for the user input? +2. Is the content meaningful and not just error messages or "no information found"? +3. Are there sufficient evidences and context to answer the user's query? + +- : output "true" only if the evidence is sufficient to answer the query. 
+- : output "true" only if the evidence is sufficient AND the result is worth caching. +- If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". + +### Output Format + +[Generate the Markdown Briefing here] + +true/false +true/false +""" + # --------------------------------------------------------------------------- # Knowledge Compile prompts diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 9d192da..2976f60 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union from sirchmunk.base import BaseSearch from sirchmunk.learnings.knowledge_base import KnowledgeBase @@ -121,6 +121,22 @@ class CompileHints: extra_keywords: List[str] +@dataclass +class CompileArtifacts: + """Compile artifact availability context for adaptive activation in FAST mode. + + Created once at the start of ``_search_fast()`` via + ``_detect_compile_artifacts()`` and threaded through all pipeline steps. + Each step checks the relevant field and falls back gracefully when the + artifact is absent. 
+ """ + + catalog: Optional[List[Dict[str, str]]] + catalog_map: Dict[str, Dict[str, str]] # path -> catalog entry for O(1) lookup + tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) + tree_available_paths: Set[str] # file paths that have cached tree indices + + class AgenticSearch(BaseSearch): def __init__( @@ -1893,6 +1909,32 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _FAST_MAX_EVIDENCE_CHARS = 15_000 _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling + # --- Wiki-enhanced ranking constants --- + _WIKI_BLEND_ALPHA = 0.7 + """TF-IDF weight in the hybrid score; Wiki weight = 1 - alpha.""" + _WIKI_MAX_SCORE = 10.0 + """Upper bound for the wiki relevance score.""" + _WIKI_CATALOG_KEYWORD_OVERLAP_MAX = 5.0 + """Maximum sub-score for catalog summary keyword overlap.""" + _WIKI_TREE_AVAILABILITY_BONUS = 2.0 + """Bonus for files that have a compiled tree index.""" + _WIKI_CATALOG_PRESENCE_FULL = 3.0 + """Catalog presence bonus for summaries > 100 chars.""" + _WIKI_CATALOG_PRESENCE_MEDIUM = 2.0 + """Catalog presence bonus for summaries > 30 chars.""" + _WIKI_CATALOG_PRESENCE_MINIMAL = 1.0 + """Catalog presence bonus for summaries > 0 chars.""" + _TREE_CACHE_SCAN_LIMIT = 200 + """Max tree JSON files to parse during artifact detection.""" + _CATALOG_LISTING_MAX_ENTRIES = 20 + """Max catalog entries in the enriched listing for Step 1.""" + _CATALOG_KEYWORD_MIN_LEN = 2 + """Minimum character length for a catalog keyword token.""" + _CATALOG_KEYWORD_MAX_LEN = 20 + """Maximum character length for a catalog keyword token.""" + _CATALOG_SUMMARY_TRUNCATE = 200 + """Max chars of catalog summary shown in the listing.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. 
" @@ -1928,6 +1970,15 @@ async def _search_fast( context = SearchContext() await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + # --- Adaptive compile artifact detection (one-shot, zero LLM) --- + artifacts = self._detect_compile_artifacts() + if artifacts.catalog or artifacts.tree_available_paths: + await self._logger.info( + f"[FAST:Artifacts] catalog={'yes' if artifacts.catalog else 'no'} " + f"({len(artifacts.catalog) if artifacts.catalog else 0} docs), " + f"trees={len(artifacts.tree_available_paths)}" + ) + # ============================================================== # Step 0: Cluster reuse — instant short-circuit (no LLM cost) # When reuse succeeds we return here; no persistence step runs. @@ -1946,15 +1997,12 @@ async def _search_fast( # document summaries and selects the most relevant ones in the # same call that extracts keywords (zero extra LLM cost). # ============================================================== - catalog = self._load_document_catalog() + catalog = artifacts.catalog catalog_routed_files: List[str] = [] catalog_confidence: str = "low" if catalog: - listing = "\n".join( - f"[{i}] {e['name']}: {e['summary'][:200]}" - for i, e in enumerate(catalog) - ) + listing = self._build_enriched_catalog_listing(catalog) prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( user_input=query, document_listing=listing, ) @@ -2118,6 +2166,7 @@ async def _search_fast( best_files: Optional[List[Dict[str, Any]]] = None used_level = "primary" evidence = "" + file_path: Optional[str] = None # set when best_files found # High-confidence catalog routing: skip rga, use catalog directly if catalog_routed_files and catalog_confidence == "high": @@ -2133,7 +2182,9 @@ async def _search_fast( if not best_files and primary: best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + 
**rga_kwargs, ) if not best_files and fallback: @@ -2142,7 +2193,9 @@ async def _search_fast( "[FAST:Step2] Primary miss, trying fine-grained fallback" ) best_files = await self._fast_find_best_file( - fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, ) # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- @@ -2183,9 +2236,13 @@ async def _search_fast( if best_files: file_path = best_files[0]["path"] match_objects = best_files[0].get("matches", []) + wiki_info = "" + if best_files[0].get("wiki_relevance") is not None: + wiki_info = f", wiki={best_files[0]['wiki_relevance']:.1f}" await self._logger.info( f"[FAST:Step2] Best file ({used_level}): {Path(file_path).name} " - f"({best_files[0].get('total_matches', 0)} hits, score={best_files[0].get('weighted_score', 0):.2f})" + f"({best_files[0].get('total_matches', 0)} hits, " + f"score={best_files[0].get('weighted_score', 0):.2f}{wiki_info})" ) # ============================================================== @@ -2248,20 +2305,35 @@ async def _rga_evidence() -> str: await self._logger.warning("[FAST:Step3] No usable evidence extracted") return _NO_RESULTS_MESSAGE, None, context + tree_available = file_path in artifacts.tree_available_paths if artifacts else False await self._logger.info( f"[FAST:Step3] Evidence: {len(evidence)} chars " - f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'})" + f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'}, " + f"tree_indexed={'yes' if tree_available else 'no'})" ) keywords_used = primary if used_level == "primary" else fallback # ============================================================== # Step 4: LLM answer from focused evidence (single call) + # Wiki-enhanced: inject document context when catalog available. 
# ============================================================== - answer_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=evidence, - ) + doc_context = self._build_answer_context(file_path, artifacts) if best_files else None + if doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + answer_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=evidence, + document_context=doc_context, + ) + await self._logger.info( + f"[FAST:Step4] Wiki-enhanced answer generation with catalog context" + ) + else: + answer_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=evidence, + ) answer_resp = await self.llm.achat( messages=[{"role": "user", "content": answer_prompt}], stream=True, @@ -2324,7 +2396,7 @@ async def _rga_evidence() -> str: return answer, None, context cluster = self._build_fast_cluster( - query, answer, file_path, evidence, keywords_used, + query, answer, file_path or "", evidence, keywords_used, ) self._add_query_to_cluster(cluster, query) try: @@ -2475,6 +2547,86 @@ def _prune_by_score( # Cap at top_k return result[:top_k] + @staticmethod + def _compute_wiki_relevance( + file_path: str, + query: str, + keywords: List[str], + catalog_map: Dict[str, Dict[str, str]], + tree_available_paths: Set[str], + ) -> float: + """Compute wiki-based relevance score for a candidate file (0-10 scale). + + Uses three sub-scores derived from compile artifacts: + + 1. **Catalog summary overlap** (0-``_WIKI_CATALOG_KEYWORD_OVERLAP_MAX``): + proportion of query keywords that appear in the catalog entry's + summary. When *keywords* is empty, falls back to whole-query + substring matching against the summary to avoid returning 0 for + valid queries. + 2. **Tree availability bonus** (0-``_WIKI_TREE_AVAILABILITY_BONUS``): + a file with a compiled tree index likely has rich structure. + 3. 
**Catalog presence bonus** (0-``_WIKI_CATALOG_PRESENCE_FULL``): + files important enough to be in the catalog get a baseline boost. + + All scoring is pure text matching — no LLM, no embedding. + + Args: + file_path: Absolute path of the candidate file. + query: Original user query. + keywords: Extracted search keywords from FAST Step 1. + catalog_map: ``{path: catalog_entry}`` from CompileArtifacts. + tree_available_paths: Set of file paths with cached tree indices. + + Returns: + Float in [0, 10] representing wiki-derived relevance. + """ + cls = AgenticSearch # access class constants from static method + score = 0.0 + + entry = catalog_map.get(file_path) + + # Sub-score 1: Catalog summary keyword overlap + if entry: + summary_lower = (entry.get("summary", "") + " " + entry.get("name", "")).lower() + query_lower = query.lower() + matches = 0 + total = 0 + for kw in keywords: + if kw: + total += 1 + if kw.lower() in summary_lower: + matches += 1 + # Also check whole query as a substring + if len(query_lower) >= 2 and query_lower in summary_lower: + matches += 1 + total += 1 + # When keywords list is empty but query is non-empty, fall back to + # character-level overlap so the sub-score is not silently 0. 
+ if total == 0 and query_lower: + # Simple overlap: count how many query chars appear in summary + overlap = sum(1 for ch in query_lower if ch in summary_lower) + ratio = overlap / max(len(query_lower), 1) + score += ratio * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + elif total > 0: + score += (matches / total) * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + + # Sub-score 2: Tree availability bonus + if file_path in tree_available_paths: + score += cls._WIKI_TREE_AVAILABILITY_BONUS + + # Sub-score 3: Catalog presence bonus + if entry: + summary_len = len(entry.get("summary", "")) + if summary_len > 100: + score += cls._WIKI_CATALOG_PRESENCE_FULL + elif summary_len > 30: + score += cls._WIKI_CATALOG_PRESENCE_MEDIUM + elif summary_len > 0: + score += cls._WIKI_CATALOG_PRESENCE_MINIMAL + + return min(score, cls._WIKI_MAX_SCORE) + async def _fast_find_best_file( self, keywords: List[str], @@ -2484,9 +2636,23 @@ async def _fast_find_best_file( exclude: Optional[List[str]] = None, top_k: int = 1, keyword_idfs: Optional[Dict[str, float]] = None, + query: str = "", + artifacts: Optional["CompileArtifacts"] = None, ) -> Optional[List[Dict[str, Any]]]: """Search per keyword via rga and return the top-k best-matching files - ranked by IDF-weighted log-TF scoring. + ranked by IDF-weighted log-TF scoring, optionally enhanced with + wiki-derived relevance from compile artifacts. + + Args: + keywords: Search keywords from FAST Step 1. + paths: Search paths. + max_depth: Maximum directory depth for rga. + include: Glob patterns to include. + exclude: Glob patterns to exclude. + top_k: Number of top files to return. + keyword_idfs: Pre-computed IDF values for keywords. + query: Original user query (used for wiki relevance scoring). + artifacts: Compile artifacts for adaptive wiki-enhanced ranking. Returns: List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. 
@@ -2576,6 +2742,25 @@ async def _fast_find_best_file( score += idf * (1.0 + math.log(tf)) f["weighted_score"] = score + # --- Wiki-enhanced hybrid scoring (adaptive: only when artifacts exist) --- + if artifacts and artifacts.catalog_map: + # Normalize TF-IDF scores to [0, 10] to align with Wiki score range + max_tf_idf = max((f["weighted_score"] for f in merged), default=1.0) + if max_tf_idf <= 0: + max_tf_idf = 1.0 + for f in merged: + wiki_score = self._compute_wiki_relevance( + f["path"], query, keywords, + artifacts.catalog_map, artifacts.tree_available_paths, + ) + f["wiki_relevance"] = wiki_score + # Normalize TF-IDF to [0, 10] before blending + tf_idf_norm = (f["weighted_score"] / max_tf_idf) * self._WIKI_MAX_SCORE + f["weighted_score"] = ( + self._WIKI_BLEND_ALPHA * tf_idf_norm + + (1 - self._WIKI_BLEND_ALPHA) * wiki_score + ) + merged.sort(key=lambda f: f["weighted_score"], reverse=True) pruned = self._prune_by_score(merged, top_k=top_k) @@ -2750,6 +2935,183 @@ def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: pass return None + def _detect_compile_artifacts(self) -> CompileArtifacts: + """One-shot probe of all compile artifacts for adaptive FAST activation. + + Reads the document catalog and scans the tree cache directory to + determine which compile products are available. Called once at the + start of ``_search_fast()``; the result is passed to downstream + helpers so they can enable enhanced logic only when artifacts exist. + + Cost: one JSON read (catalog) + one directory listing (tree cache). + Tree path results are cached in ``_tree_paths_cache`` so subsequent + calls within the same instance avoid re-parsing every JSON file. + Returns a ``CompileArtifacts`` with ``None``/empty fields when + compile has not been run. 
+ """ + catalog = self._load_document_catalog() + catalog_map: Dict[str, Dict[str, str]] = {} + if catalog: + for entry in catalog: + p = entry.get("path", "") + if p: + catalog_map[p] = entry + + indexer = self._get_tree_indexer() + # Use cached tree paths when available to avoid re-parsing all JSONs + tree_paths: Set[str] = getattr(self, "_tree_paths_cache", None) or set() + if indexer is not None and not tree_paths: + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + tree = DocumentTree.from_json( + tf.read_text(encoding="utf-8") + ) + if tree.file_path: + tree_paths.add(tree.file_path) + except Exception: + pass + except Exception: + pass + # Cache for future calls within this instance + self._tree_paths_cache = tree_paths + + return CompileArtifacts( + catalog=catalog, + catalog_map=catalog_map, + tree_indexer=indexer, + tree_available_paths=tree_paths, + ) + + @staticmethod + def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: + """Extract salient keywords from a catalog summary via simple heuristics. + + Uses word-length filtering, Chinese character detection, and CJK n-gram + extraction to pick the most informative tokens. For CJK-heavy text + (which does not use whitespace word boundaries), consecutive CJK + character runs are extracted as additional candidate tokens. + + No LLM or embedding involved. + + Args: + summary: Document summary text from the compiled catalog. + max_kw: Maximum number of keywords to return. + + Returns: + List of up to *max_kw* keywords. + """ + cls = AgenticSearch + if not summary: + return [] + import re as _re + + # Split on whitespace and common punctuation (incl. 
CJK punctuation) + tokens = _re.split( + r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027]+', + summary, + ) + + # For CJK text, also extract consecutive CJK character runs (2-6 chars) + # so that e.g. "停车位申请条件" yields ["停车位申请条件", "停车位", "申请条件", ...] + cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary) + # Generate sub-phrases from long CJK runs (bigrams/trigrams/4-grams) + cjk_ngrams: List[str] = [] + for run in cjk_runs: + cjk_ngrams.append(run) + if len(run) > 4: + # Extract 2-4 char sub-phrases from each run + for n in (4, 3, 2): + for i in range(len(run) - n + 1): + cjk_ngrams.append(run[i:i + n]) + + tokens = tokens + cjk_ngrams + + # Filter: keep tokens with appropriate length and not purely numeric + candidates = [ + t for t in tokens + if len(t) >= cls._CATALOG_KEYWORD_MIN_LEN + and not t.isdigit() + and len(t) <= cls._CATALOG_KEYWORD_MAX_LEN + ] + # Prefer longer tokens (more specific) + candidates.sort(key=len, reverse=True) + # Deduplicate case-insensitively + seen: Set[str] = set() + result: List[str] = [] + for c in candidates: + lower = c.lower() + if lower not in seen: + seen.add(lower) + result.append(c) + if len(result) >= max_kw: + break + return result + + def _build_enriched_catalog_listing( + self, + catalog: List[Dict[str, str]], + max_entries: Optional[int] = None, + ) -> str: + """Build an enriched catalog listing with keywords for FAST Step 1. + + Compared to the plain ``[i] name: summary[:200]`` format, this adds + extracted keywords to help the LLM make more informed document + selections. + + Args: + catalog: Entries from ``document_catalog.json``. + max_entries: Cap to prevent prompt overflow. + + Returns: + Formatted listing string for injection into the FAST query + analysis prompt. 
+ """ + lines: List[str] = [] + _max = max_entries if max_entries is not None else self._CATALOG_LISTING_MAX_ENTRIES + _trunc = self._CATALOG_SUMMARY_TRUNCATE + for i, entry in enumerate(catalog[:_max]): + name = entry.get("name", "") + summary = entry.get("summary", "") + kws = AgenticSearch._extract_catalog_keywords(summary) + kw_str = ", ".join(kws) if kws else "" + if kw_str: + lines.append(f"[{i}] {name}: {summary[:_trunc]} [Keywords: {kw_str}]") + else: + lines.append(f"[{i}] {name}: {summary[:_trunc]}") + return "\n".join(lines) + + def _build_answer_context( + self, + best_file_path: str, + artifacts: CompileArtifacts, + ) -> Optional[str]: + """Build document context from catalog for wiki-enhanced answer generation. + + Returns a short context string describing the source document, or + None when no catalog entry exists for *best_file_path*. + + Args: + best_file_path: Path of the top-ranked file from Step 2. + artifacts: Compile artifact availability context. + + Returns: + Context string or None. 
+ """ + if not artifacts.catalog_map: + return None + entry = artifacts.catalog_map.get(best_file_path) + if not entry: + return None + name = entry.get("name", Path(best_file_path).name) + summary = entry.get("summary", "") + if not summary: + return None + return f"Source Document: {name}\nDocument Overview: {summary}" + async def _navigate_tree_for_evidence( self, file_path: str, query: str, ) -> Optional[str]: From 077be35a63998958c0ab2225fe7297ffd2f2d3d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 19:40:17 +0800 Subject: [PATCH 06/70] fix _extract_catalog_keywords for llm wiki --- src/sirchmunk/search.py | 49 +++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2976f60..6880b3b 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -3005,47 +3005,64 @@ def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: List of up to *max_kw* keywords. """ cls = AgenticSearch - if not summary: + if max_kw <= 0: + return [] + summary_text = str(summary or "").strip() + if not summary_text: return [] import re as _re # Split on whitespace and common punctuation (incl. CJK punctuation) tokens = _re.split( - r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027]+', - summary, + r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027/\\|`~@#$%^&*=+<>]+', + summary_text, ) # For CJK text, also extract consecutive CJK character runs (2-6 chars) # so that e.g. "停车位申请条件" yields ["停车位申请条件", "停车位", "申请条件", ...] 
- cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary) + cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary_text) # Generate sub-phrases from long CJK runs (bigrams/trigrams/4-grams) cjk_ngrams: List[str] = [] + max_ngram_per_run = 40 for run in cjk_runs: cjk_ngrams.append(run) if len(run) > 4: # Extract 2-4 char sub-phrases from each run + added = 0 for n in (4, 3, 2): for i in range(len(run) - n + 1): cjk_ngrams.append(run[i:i + n]) + added += 1 + if added >= max_ngram_per_run: + break + if added >= max_ngram_per_run: + break tokens = tokens + cjk_ngrams # Filter: keep tokens with appropriate length and not purely numeric candidates = [ t for t in tokens - if len(t) >= cls._CATALOG_KEYWORD_MIN_LEN + if t + and len(t) >= cls._CATALOG_KEYWORD_MIN_LEN and not t.isdigit() and len(t) <= cls._CATALOG_KEYWORD_MAX_LEN + and not _re.fullmatch(r"[_\-.]+", t) ] # Prefer longer tokens (more specific) candidates.sort(key=len, reverse=True) # Deduplicate case-insensitively seen: Set[str] = set() + chosen_norms: List[str] = [] result: List[str] = [] for c in candidates: lower = c.lower() if lower not in seen: + # Avoid noisy micro-fragments when a longer token already exists. + if len(lower) <= 4 and any(lower in kept for kept in chosen_norms): + continue seen.add(lower) + chosen_norms.append(lower) result.append(c) if len(result) >= max_kw: break @@ -3070,18 +3087,32 @@ def _build_enriched_catalog_listing( Formatted listing string for injection into the FAST query analysis prompt. 
""" + if not isinstance(catalog, list) or not catalog: + return "" lines: List[str] = [] _max = max_entries if max_entries is not None else self._CATALOG_LISTING_MAX_ENTRIES + if _max <= 0: + return "" _trunc = self._CATALOG_SUMMARY_TRUNCATE for i, entry in enumerate(catalog[:_max]): - name = entry.get("name", "") - summary = entry.get("summary", "") + if not isinstance(entry, dict): + continue + name = str(entry.get("name") or entry.get("path") or "") + summary = str(entry.get("summary") or "") + # Keep one-line prompt entries to avoid accidental prompt pollution. + name = " ".join(name.split()) + summary = " ".join(summary.split()) + if not name: + name = f"doc_{i}" kws = AgenticSearch._extract_catalog_keywords(summary) kw_str = ", ".join(kws) if kws else "" + shown_summary = summary[:_trunc] + if len(summary) > _trunc: + shown_summary += "..." if kw_str: - lines.append(f"[{i}] {name}: {summary[:_trunc]} [Keywords: {kw_str}]") + lines.append(f"[{i}] {name}: {shown_summary} [Keywords: {kw_str}]") else: - lines.append(f"[{i}] {name}: {summary[:_trunc]}") + lines.append(f"[{i}] {name}: {shown_summary}") return "\n".join(lines) def _build_answer_context( From a602197eef20bfb77e75c726aad62cf1b1b88408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 20:33:31 +0800 Subject: [PATCH 07/70] add tree guided sampling --- src/sirchmunk/search.py | 191 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 188 insertions(+), 3 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 6880b3b..577c276 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1935,6 +1935,14 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CATALOG_SUMMARY_TRUNCATE = 200 """Max chars of catalog summary shown in the listing.""" + # --- Tree-guided sampling constants --- + _TREE_SAMPLE_MAX_SECTIONS = 3 + """Max tree sections to include per file in tree-guided sampling.""" + 
_TREE_SAMPLE_SECTION_MAX_CHARS = 3000 + """Max chars per tree section.""" + _TREE_SAMPLE_RGA_SUPPLEMENT = True + """Whether to append rga evidence after tree sections as supplementary context.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. " @@ -2249,10 +2257,24 @@ async def _search_fast( # Step 2.5 + Step 3: Tree navigation (1 LLM call) runs in # parallel with rga evidence sampling (0 LLM). The merged # result is higher quality than either alone. + # Tree-guided sampling is integrated into _rga_evidence() for + # secondary files; the primary file gets a dedicated parallel + # tree_task to avoid blocking rga. # ============================================================== + # Track files already receiving parallel tree navigation to + # avoid duplicate LLM calls inside _rga_evidence(). + tree_nav_done: Set[str] = set() + tree_nav_target = best_files[0]["path"] + + if artifacts and tree_nav_target in artifacts.tree_available_paths: + tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + tree_nav_done.add(tree_nav_target) + else: + tree_task = self._async_noop(None) + async def _rga_evidence() -> str: - """Collect rga-based evidence from best_files (zero LLM).""" + """Collect evidence from best_files: tree-guided when available, rga fallback.""" parts: List[str] = [] chars = 0 for bf in best_files: @@ -2262,6 +2284,8 @@ async def _rga_evidence() -> str: fn = Path(fp).name ext = Path(fp).suffix.lower() ev = None + + # 1. Small file: read entirely (existing logic) if ext in self._FAST_TEXT_EXTENSIONS: try: sz = Path(fp).stat().st_size @@ -2271,8 +2295,35 @@ async def _rga_evidence() -> str: ev = f"[{fn}]\n{full}" except Exception: pass + + # 2. 
Tree-guided sampling (adaptive, skip files handled + # by the parallel tree_task to avoid duplicate LLM) + if ( + ev is None + and artifacts + and fp in artifacts.tree_available_paths + and fp not in tree_nav_done + ): + try: + tree_ev_inner = await self._tree_guided_sample( + fp, query, + match_objects=bf.get("matches", []), + max_chars=self._FAST_MAX_EVIDENCE_CHARS - chars, + artifacts=artifacts, + ) + if tree_ev_inner: + ev = tree_ev_inner + await self._logger.info( + f"[FAST:Step3] Tree-guided sample for {fn} " + f"({len(tree_ev_inner)} chars)" + ) + except Exception: + pass + + # 3. Fallback: rga sampling (existing logic) if ev is None: ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) + if ev: remaining = self._FAST_MAX_EVIDENCE_CHARS - chars parts.append(ev[:remaining]) @@ -2281,9 +2332,7 @@ async def _rga_evidence() -> str: return "\n\n---\n\n".join(parts) # Launch tree navigation for the primary file alongside rga - tree_nav_target = best_files[0]["path"] rga_task = _rga_evidence() - tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) @@ -3143,6 +3192,142 @@ def _build_answer_context( return None return f"Source Document: {name}\nDocument Overview: {summary}" + async def _tree_guided_sample( + self, + file_path: str, + query: str, + *, + match_objects: Optional[List[Dict[str, Any]]] = None, + max_chars: int = 0, + artifacts: Optional["CompileArtifacts"] = None, + pre_navigated_leaves: Optional[List[Any]] = None, + ) -> Optional[str]: + """Tree-guided evidence sampling: use compiled tree index to locate + relevant sections, then read precise char_range content. + + Falls back to None when no tree index is available, letting callers + use their default sampling strategy (rga windows, Monte Carlo, etc.). 
+ + This method is designed to be called from both FAST and DEEP modes: + - FAST: called inside _rga_evidence() per-file loop + - DEEP: called before/alongside Monte Carlo sampling + + Args: + file_path: Absolute path to the target file. + query: User query for LLM-driven branch selection. + match_objects: Optional rga match objects for hybrid evidence. + max_chars: Character budget for this file's evidence. + Uses ``_FAST_MAX_EVIDENCE_CHARS`` when 0. + artifacts: Compile artifact context; when None, probes lazily. + pre_navigated_leaves: Pre-computed leaf nodes from a prior + ``navigate()`` call. When provided the method skips the + LLM navigation step (avoids duplicate LLM calls). + + Returns: + Formatted evidence string with tree-navigated sections, or None + when tree index is unavailable (caller should fall back). + """ + if max_chars <= 0: + max_chars = self._FAST_MAX_EVIDENCE_CHARS + + # --- Guard: tree availability --- + if artifacts is not None: + if file_path not in artifacts.tree_available_paths: + return None + else: + # Lazy probe when artifacts not provided (DEEP mode entry) + indexer = self._get_tree_indexer() + if indexer is None or not indexer.has_tree(file_path): + return None + + fname = Path(file_path).name + + # --- Obtain leaf nodes --- + leaves = pre_navigated_leaves + if leaves is None: + try: + indexer = self._get_tree_indexer() + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + leaves = await indexer.navigate( + tree, query, + max_results=self._TREE_SAMPLE_MAX_SECTIONS, + ) + except Exception: + return None + + if not leaves: + return None + + # --- Read full text once for char_range slicing --- + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + # --- Extract tree sections --- + parts: List[str] = [] + total_chars = 0 + 
for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: + start, end = leaf.char_range + if full_text and end > start: + segment = full_text[start:end] + else: + segment = leaf.summary or "" + segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] + if not segment.strip(): + continue + header = f"[{fname} \u2192 {leaf.title}]" + chunk = f"{header}\n{segment}" + if total_chars + len(chunk) > max_chars: + remaining = max_chars - total_chars + if remaining > 200: + parts.append(chunk[:remaining]) + total_chars += remaining + break + parts.append(chunk) + total_chars += len(chunk) + + # --- Optional rga supplement --- + if ( + self._TREE_SAMPLE_RGA_SUPPLEMENT + and match_objects + and total_chars < max_chars + ): + hit_lines: List[int] = [] + for m in match_objects: + ln = m.get("data", {}).get("line_number") + if isinstance(ln, int): + hit_lines.append(ln) + if hit_lines: + ext = Path(file_path).suffix.lower() + if ext in self._FAST_TEXT_EXTENSIONS: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=max_chars - total_chars, + ) + if rga_ctx: + rga_section = f"[{fname} \u2192 rga hits]\n{rga_ctx}" + parts.append(rga_section) + total_chars += len(rga_section) + + if not parts: + return None + + evidence = "\n\n".join(parts) + await self._logger.info( + f"[TreeSample] {fname}: " + f"{len(parts)} sections, {total_chars} chars " + f"(pre_nav={'yes' if pre_navigated_leaves else 'no'})" + ) + return evidence + async def _navigate_tree_for_evidence( self, file_path: str, query: str, ) -> Optional[str]: From 8233c35b8a79f8af02d3dbe68bf9880c723ea35c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 23:38:01 +0800 Subject: [PATCH 08/70] fix compile quality and large-file processing --- src/sirchmunk/learnings/compiler.py | 146 ++++++++++++++++++++---- src/sirchmunk/learnings/tree_indexer.py | 21 +++- src/sirchmunk/search.py | 12 +- 3 files changed, 152 insertions(+), 27 deletions(-) 
diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 4ccd5da..10b56a6 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -40,6 +40,16 @@ # Similarity threshold for merging into existing clusters during compile _MERGE_SIMILARITY_THRESHOLD = 0.75 +# Max chars for manifest-persisted document summary (used in Phase 2 & catalog) +_MANIFEST_SUMMARY_MAX_LEN = 250 + +# Preview window for direct LLM summarisation (no tree), ~4K tokens +_SUMMARY_PREVIEW_CHARS = 16_000 + +# Multi-section sampling for large documents without a tree index +_SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs +_SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section + # --------------------------------------------------------------------------- # Data structures @@ -54,6 +64,7 @@ class FileManifestEntry: has_tree: bool cluster_ids: List[str] size_bytes: int + summary: str = "" # 新增:存储编译期生成的文档摘要 def to_dict(self) -> Dict[str, Any]: return { @@ -62,6 +73,7 @@ def to_dict(self) -> Dict[str, Any]: "has_tree": self.has_tree, "cluster_ids": self.cluster_ids, "size_bytes": self.size_bytes, + "summary": self.summary, } @classmethod @@ -72,6 +84,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": has_tree=data.get("has_tree", False), cluster_ids=data.get("cluster_ids", []), size_bytes=data.get("size_bytes", 0), + summary=data.get("summary", ""), ) @@ -365,6 +378,7 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_tree=result.tree is not None, cluster_ids=result.cluster_ids, size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, + summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", ) # Phase 3: aggregate results into knowledge network @@ -516,8 +530,12 @@ async def _compile_single_file( entry.path, content, ) + # Enrich content with structural metadata for non-text types + metadata_prefix = 
self._extract_structured_metadata(entry.path, content) + enriched_content = metadata_prefix + content if metadata_prefix else content + result.summary = await self._extract_summary( - entry.path, content, result.tree, + entry.path, enriched_content, result.tree, ) result.topics = await self._extract_topics(result.summary) result.evidence = self._build_evidence(entry, content, result) @@ -539,11 +557,14 @@ async def _extract_summary( When a tree is available its root already contains an LLM-synthesized summary (produced by ``_synthesize_root_summary`` during tree build), so we reuse it directly — no redundant LLM call. + + For large documents without a tree, uses multi-section sampling + (beginning, middle, end) to capture the full scope of the document. """ if tree and tree.root and tree.root.summary: return tree.root.summary - preview = content[:16000] if len(content) > 16000 else content + preview = self._build_summary_preview(content) from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY prompt = COMPILE_DOC_SUMMARY.format( file_name=Path(file_path).name, @@ -552,6 +573,100 @@ async def _extract_summary( resp = await self._llm.achat([{"role": "user", "content": prompt}]) return resp.content.strip() + @staticmethod + def _build_summary_preview(content: str) -> str: + """Build a representative preview for LLM summarisation. + + For short documents (≤ _SUMMARY_PREVIEW_CHARS), returns the full + content. For large documents, samples the beginning, middle, and + end to capture the document's full scope within the token budget. + """ + if len(content) <= _SUMMARY_PREVIEW_CHARS: + return content + + section_size = _SUMMARY_SAMPLE_SECTION_CHARS + mid_start = max(section_size, (len(content) - section_size) // 2) + + head = content[:section_size] + middle = content[mid_start:mid_start + section_size] + tail = content[-section_size:] + + return ( + f"[Beginning of document]\n{head}\n\n" + f"[... content omitted ...]\n\n" + f"[Middle of document]\n{middle}\n\n" + f"[... 
content omitted ...]\n\n" + f"[End of document]\n{tail}" + ) + + @staticmethod + def _extract_structured_metadata(file_path: str, content: str) -> str: + """Extract structural metadata for non-text document types. + + For spreadsheets and presentations, prepend a structural overview + (sheet names, column headers, slide titles) so the LLM summariser + has better context than raw extracted text alone. + + Returns a metadata prefix string (may be empty for unsupported types). + """ + ext = Path(file_path).suffix.lower() + + if ext == ".xlsx": + return KnowledgeCompiler._extract_xlsx_metadata(file_path) + if ext == ".pptx": + return KnowledgeCompiler._extract_pptx_metadata(file_path) + + return "" + + @staticmethod + def _extract_xlsx_metadata(file_path: str) -> str: + """Extract structural metadata from Excel files. + + Reads sheet names, row counts, and column headers (first row) to + provide the LLM with a structural overview of the workbook. + Caps at 10 sheets and 15 columns per sheet for bounded output. + """ + try: + import openpyxl + wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) + lines: List[str] = ["[Excel Workbook Structure]"] + for sheet_name in wb.sheetnames[:10]: # Cap at 10 sheets + ws = wb[sheet_name] + # Extract column headers (first row) + headers: List[str] = [] + for cell in ws.iter_rows(min_row=1, max_row=1, values_only=True): + headers = [str(h) for h in cell if h is not None] + break + row_count = ws.max_row or 0 + header_str = ", ".join(headers[:15]) if headers else "no headers" + lines.append(f"- Sheet '{sheet_name}': {row_count} rows, columns: [{header_str}]") + wb.close() + return "\n".join(lines) + "\n\n" + except Exception: + return "" + + @staticmethod + def _extract_pptx_metadata(file_path: str) -> str: + """Extract structural metadata from PowerPoint files. + + Reads slide count and titles (from the title placeholder) to give + the LLM a table-of-contents-like overview of the presentation. 
+ Caps at 20 slides for bounded output. + """ + try: + from pptx import Presentation + prs = Presentation(file_path) + lines: List[str] = [f"[PowerPoint Structure: {len(prs.slides)} slides]"] + for i, slide in enumerate(prs.slides[:20], 1): # Cap at 20 slides + title = "" + if slide.shapes.title: + title = slide.shapes.title.text.strip() + if title: + lines.append(f"- Slide {i}: {title}") + return "\n".join(lines) + "\n\n" + except Exception: + return "" + def _build_evidence( self, entry: FileEntry, @@ -851,14 +966,19 @@ def _build_document_catalog(self, manifest: CompileManifest) -> None: The catalog is consumed by FAST search to fuse query analysis with LLM-driven document routing in a single prompt. Each entry carries - the filename and a truncated root summary (≤250 chars). + the filename and a truncated root summary (<= _MANIFEST_SUMMARY_MAX_LEN chars). + + Summary is sourced from the manifest (populated during Phase 2 compile), + with a tree-root fallback for backward compatibility. 
""" tree_cache = self._compile_dir / "trees" entries: List[Dict[str, str]] = [] for file_path, entry in manifest.files.items(): - summary = "" - if entry.has_tree and tree_cache.exists(): + summary = entry.summary # Primary: manifest-persisted summary + + # Fallback: read from tree root if manifest summary is empty + if not summary and entry.has_tree and tree_cache.exists(): tree_file = tree_cache / f"{entry.file_hash}.json" if tree_file.exists(): try: @@ -866,24 +986,10 @@ def _build_document_catalog(self, manifest: CompileManifest) -> None: tree_file.read_text(encoding="utf-8"), ) if tree.root and tree.root.summary: - summary = tree.root.summary[:250] + summary = tree.root.summary[:_MANIFEST_SUMMARY_MAX_LEN] except Exception: pass - if not summary: - # Fallback: use first cluster's description - for cid in entry.cluster_ids[:1]: - try: - import asyncio - loop = asyncio.get_event_loop() - if loop.is_running(): - break - c = loop.run_until_complete(self._storage.get(cid)) - if c and c.description: - summary = str(c.description[0])[:250] - except Exception: - break - entries.append({ "path": file_path, "name": Path(file_path).name, diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 53ebf0b..8bd2983 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -21,6 +21,11 @@ # File-size threshold: skip tree indexing for small files _TREE_MIN_CHARS = 50_000 # 50 K characters +# Adaptive preview window for LLM structure analysis +_TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) +_TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) +_TREE_PREVIEW_RATIO = 0.15 # Fraction of document to preview + # Extensions eligible for tree indexing _TREE_EXTENSIONS = { ".pdf", ".docx", ".doc", ".md", ".markdown", @@ -260,7 +265,8 @@ async def _build_node( """Recursively build tree nodes via LLM structure analysis.""" from sirchmunk.llm.prompts import 
COMPILE_TREE_STRUCTURE - preview = text[:12000] if len(text) > 12000 else text + preview_size = self._compute_preview_size(len(text)) + preview = text[:preview_size] prompt = COMPILE_TREE_STRUCTURE.format( document_content=preview, max_sections=8, @@ -427,6 +433,19 @@ def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: # Helpers # # ------------------------------------------------------------------ # + @staticmethod + def _compute_preview_size(text_len: int) -> int: + """Compute adaptive preview window size for LLM structure analysis. + + Scales with document length: at least *_TREE_PREVIEW_MIN* chars, + up to *_TREE_PREVIEW_MAX*, using *_TREE_PREVIEW_RATIO* of the + document length as the baseline. + """ + return max( + _TREE_PREVIEW_MIN, + min(int(text_len * _TREE_PREVIEW_RATIO), _TREE_PREVIEW_MAX), + ) + @staticmethod def _count_nodes(node: TreeNode) -> int: return 1 + sum(DocumentTreeIndexer._count_nodes(c) for c in node.children) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 577c276..f38028f 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1910,18 +1910,18 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- - _WIKI_BLEND_ALPHA = 0.7 + _WIKI_BLEND_ALPHA = 0.85 """TF-IDF weight in the hybrid score; Wiki weight = 1 - alpha.""" _WIKI_MAX_SCORE = 10.0 """Upper bound for the wiki relevance score.""" _WIKI_CATALOG_KEYWORD_OVERLAP_MAX = 5.0 """Maximum sub-score for catalog summary keyword overlap.""" - _WIKI_TREE_AVAILABILITY_BONUS = 2.0 - """Bonus for files that have a compiled tree index.""" - _WIKI_CATALOG_PRESENCE_FULL = 3.0 + _WIKI_TREE_AVAILABILITY_BONUS = 0.5 + """Bonus for files that have a compiled tree index (weak signal).""" + _WIKI_CATALOG_PRESENCE_FULL = 2.0 """Catalog presence bonus for summaries > 100 chars.""" - 
_WIKI_CATALOG_PRESENCE_MEDIUM = 2.0 - """Catalog presence bonus for summaries > 30 chars.""" + _WIKI_CATALOG_PRESENCE_MEDIUM = 1.5 + """Catalog presence bonus for summaries > 30 chars (must be < FULL).""" _WIKI_CATALOG_PRESENCE_MINIMAL = 1.0 """Catalog presence bonus for summaries > 0 chars.""" _TREE_CACHE_SCAN_LIMIT = 200 From 1de1c98817c8e6c9beb29045debe38df4682cab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 00:09:58 +0800 Subject: [PATCH 09/70] adopt the latest compile processing --- src/sirchmunk/search.py | 117 ++++++++++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 35 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index f38028f..5128702 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -135,6 +135,7 @@ class CompileArtifacts: catalog_map: Dict[str, Dict[str, str]] # path -> catalog entry for O(1) lookup tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) tree_available_paths: Set[str] # file paths that have cached tree indices + manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} class AgenticSearch(BaseSearch): @@ -1426,6 +1427,7 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), + self._probe_compile_hints(initial_keywords if initial_keywords else [query]), return_exceptions=True, ) @@ -1434,8 +1436,9 @@ async def _search_deep( knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] + compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) - for i, label in enumerate(["keywords", "dir_scan", "knowledge", 
"spec_cache", "tree_index"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") @@ -1471,6 +1474,7 @@ async def _search_deep( f"dir_scan={'OK' if scan_result else 'N/A'}, " f"knowledge_files={len(knowledge_probe.file_paths)}, " f"tree_hits={len(tree_hits)}, " + f"compile_hints={len(compile_hints.file_paths)}, " f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1523,7 +1527,7 @@ async def _search_deep( if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + keyword_files, + keyword_files=list(tree_hits) + compile_hints.file_paths + keyword_files, dir_scan_files=dir_scan_files, knowledge_hits=extra_knowledge_files, ) @@ -2285,22 +2289,9 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None - # 1. Small file: read entirely (existing logic) - if ext in self._FAST_TEXT_EXTENSIONS: - try: - sz = Path(fp).stat().st_size - if sz < self._FAST_SMALL_FILE_THRESHOLD: - full = Path(fp).read_text(errors="replace") - if len(full) < self._FAST_SMALL_FILE_THRESHOLD: - ev = f"[{fn}]\n{full}" - except Exception: - pass - - # 2. Tree-guided sampling (adaptive, skip files handled - # by the parallel tree_task to avoid duplicate LLM) + # 1. Tree-guided sampling FIRST for tree-indexed files if ( - ev is None - and artifacts + artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done ): @@ -2320,6 +2311,17 @@ async def _rga_evidence() -> str: except Exception: pass + # 2. 
Small file: read entirely (only if tree didn't provide evidence) + if ev is None and ext in self._FAST_TEXT_EXTENSIONS: + try: + sz = Path(fp).stat().st_size + if sz < self._FAST_SMALL_FILE_THRESHOLD: + full = Path(fp).read_text(errors="replace") + if len(full) < self._FAST_SMALL_FILE_THRESHOLD: + ev = f"[{fn}]\n{full}" + except Exception: + pass + # 3. Fallback: rga sampling (existing logic) if ev is None: ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) @@ -2641,11 +2643,15 @@ def _compute_wiki_relevance( query_lower = query.lower() matches = 0 total = 0 + summary_tokens = cls._tokenize_for_matching(summary_lower) for kw in keywords: if kw: total += 1 - if kw.lower() in summary_lower: - matches += 1 + kw_low = kw.lower() + if kw_low in summary_tokens: + matches += 1 # Full token match + elif kw_low in summary_lower: + matches += 0.5 # Substring-only match (lower confidence) # Also check whole query as a substring if len(query_lower) >= 2 and query_lower in summary_lower: matches += 1 @@ -3006,25 +3012,43 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: if p: catalog_map[p] = entry + # Load manifest for rich metadata (size, has_tree, cluster_ids) + manifest_map: Dict[str, Any] = {} + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + manifest_map = manifest.files # {file_path: FileManifestEntry} + except Exception: + pass + indexer = self._get_tree_indexer() # Use cached tree paths when available to avoid re-parsing all JSONs tree_paths: Set[str] = getattr(self, "_tree_paths_cache", None) or set() - if indexer is not None and not tree_paths: - tree_cache = self.work_path / ".cache" / "compile" / "trees" - if tree_cache.exists(): - try: - from sirchmunk.learnings.tree_indexer import DocumentTree - for tf in 
sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: - try: - tree = DocumentTree.from_json( - tf.read_text(encoding="utf-8") - ) - if tree.file_path: - tree_paths.add(tree.file_path) - except Exception: - pass - except Exception: - pass + if not tree_paths: + # Prefer manifest-based detection (fast, O(1) per file) + if manifest_map: + tree_paths = {fp for fp, entry in manifest_map.items() if entry.has_tree} + # Fallback: scan tree cache directory (legacy path) + elif indexer is not None: + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + tree = DocumentTree.from_json( + tf.read_text(encoding="utf-8") + ) + if tree.file_path: + tree_paths.add(tree.file_path) + except Exception: + pass + except Exception: + pass # Cache for future calls within this instance self._tree_paths_cache = tree_paths @@ -3033,8 +3057,31 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: catalog_map=catalog_map, tree_indexer=indexer, tree_available_paths=tree_paths, + manifest_map=manifest_map, ) + @staticmethod + def _tokenize_for_matching(text: str) -> Set[str]: + """Tokenize text into meaningful units for keyword matching. + + Splits on whitespace and CJK/Latin punctuation boundaries, then + generates 2-3 char n-grams for CJK-heavy tokens to handle + unsegmented Chinese text. Returns a set of lowercased tokens. 
+ """ + import re + tokens: Set[str] = set() + raw = re.split(r'[\s,;.!?,;。!?::、\u201c\u201d\u2018\u2019()()\[\]{}<>《》\-/]+', text.lower()) + for t in raw: + t = t.strip() + if not t: + continue + tokens.add(t) + if len(t) >= 2 and any('\u4e00' <= c <= '\u9fff' for c in t): + for n in (2, 3): + for i in range(len(t) - n + 1): + tokens.add(t[i:i + n]) + return tokens + @staticmethod def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: """Extract salient keywords from a catalog summary via simple heuristics. From 938ced1e657d58f0cc042631c97f52de5fbfe328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 16:02:45 +0800 Subject: [PATCH 10/70] refactor tree indexing with toc --- src/sirchmunk/learnings/compiler.py | 45 ++- src/sirchmunk/learnings/toc_extractor.py | 391 +++++++++++++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 154 ++++++++- src/sirchmunk/search.py | 111 +++++++ 4 files changed, 697 insertions(+), 4 deletions(-) create mode 100644 src/sirchmunk/learnings/toc_extractor.py diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 10b56a6..fac9b79 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -65,6 +65,8 @@ class FileManifestEntry: cluster_ids: List[str] size_bytes: int summary: str = "" # 新增:存储编译期生成的文档摘要 + has_explicit_toc: bool = False # Whether a native TOC was extracted from the file + tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) def to_dict(self) -> Dict[str, Any]: return { @@ -74,6 +76,8 @@ def to_dict(self) -> Dict[str, Any]: "cluster_ids": self.cluster_ids, "size_bytes": self.size_bytes, "summary": self.summary, + "has_explicit_toc": self.has_explicit_toc, + "tree_node_count": self.tree_node_count, } @classmethod @@ -85,6 +89,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": cluster_ids=data.get("cluster_ids", []), size_bytes=data.get("size_bytes", 0), 
summary=data.get("summary", ""), + has_explicit_toc=data.get("has_explicit_toc", False), + tree_node_count=data.get("tree_node_count", 0), ) @@ -147,6 +153,8 @@ class FileCompileResult: evidence: Optional[EvidenceUnit] = None cluster_ids: List[str] = field(default_factory=list) error: Optional[str] = None + has_explicit_toc: bool = False # Whether TOC was extracted from native structure + tree_node_count: int = 0 # Number of nodes in the tree index @dataclass @@ -379,6 +387,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: cluster_ids=result.cluster_ids, size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", + has_explicit_toc=result.has_explicit_toc, + tree_node_count=result.tree_node_count, ) # Phase 3: aggregate results into knowledge network @@ -525,11 +535,26 @@ async def _compile_single_file( and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) ) + # Phase 0.5: TOC extraction (zero LLM calls) + toc_entries = None + if use_tree: + from sirchmunk.learnings.toc_extractor import TOCExtractor + toc_entries = TOCExtractor.extract(entry.path, content) + if toc_entries: + await self._log.info( + f"[Compile] Extracted TOC with {len(toc_entries)} entries " + f"for {Path(entry.path).name}" + ) + if use_tree: result.tree = await self._tree_indexer.build_tree( - entry.path, content, + entry.path, content, toc_entries=toc_entries, ) + # Record TOC / tree metrics on the result for manifest persistence + result.has_explicit_toc = toc_entries is not None and len(toc_entries) > 0 + result.tree_node_count = self._count_tree_nodes(result.tree) + # Enrich content with structural metadata for non-text types metadata_prefix = self._extract_structured_metadata(entry.path, content) enriched_content = metadata_prefix + content if metadata_prefix else content @@ -940,6 +965,24 @@ def _add_edge( WeakSemanticEdge(target_cluster_id=target_id, weight=weight, 
source=source) ) + @staticmethod + def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: + """Count total nodes in a DocumentTree (recursive). + + Args: + tree: The tree to count, or None. + + Returns: + Total node count, or 0 if tree is None. + """ + if tree is None or tree.root is None: + return 0 + + def _count(node: Any) -> int: + return 1 + sum(_count(c) for c in node.children) + + return _count(tree.root) + # ------------------------------------------------------------------ # # Manifest I/O # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py new file mode 100644 index 0000000..85f3b8e --- /dev/null +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -0,0 +1,391 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +TOC (Table of Contents) extractor — pure local operations, zero LLM calls. + +Extracts hierarchical table-of-contents structures from various document +formats (PDF, Markdown, DOCX, HTML) using native format features (bookmarks, +heading styles, heading tags). The extracted TOCEntry list is consumed by +the tree indexer to accelerate tree construction. +""" + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +# Minimum number of TOC entries required to form a meaningful structure +_MIN_TOC_ENTRIES = 3 + +# Known heading-style prefixes across locales (English, Chinese, etc.) 
+_HEADING_STYLE_PREFIXES = ("Heading", "heading", "\u6807\u9898") # "标题" = Chinese + + +@dataclass +class TOCEntry: + """Single entry in an extracted table of contents.""" + + title: str + level: int # 0=root, 1=section, 2=subsection + char_start: int # Character offset in extracted text + char_end: Optional[int] = None + page_start: Optional[int] = None + page_end: Optional[int] = None + children: List["TOCEntry"] = field(default_factory=list) + + +class TOCExtractor: + """Extract TOC structure from documents using native format features. + + All methods are static — no instance state required. Each extraction + method handles one file format and returns a flat or nested list of + ``TOCEntry`` objects. The main ``extract()`` entry point dispatches + by file extension and resolves character positions against the + extracted text content. + + Design constraints: + - Pure local operations, zero LLM calls + - Exceptions handled internally; failure returns None + """ + + @staticmethod + def extract(file_path: str, content: str) -> Optional[List[TOCEntry]]: + """Main entry point: extract TOC entries from a file. + + Dispatches to format-specific extractors based on file extension, + then resolves character positions in the extracted text content. + + Args: + file_path: Absolute path to the source file. + content: Extracted text content of the file. + + Returns: + List of TOCEntry with resolved char positions, or None if + the file format is unsupported or fewer than _MIN_TOC_ENTRIES + entries are found. 
+ """ + ext = Path(file_path).suffix.lower() + + entries: Optional[List[TOCEntry]] = None + if ext == ".pdf": + entries = TOCExtractor._extract_pdf_toc(file_path) + elif ext in (".md", ".markdown"): + entries = TOCExtractor._extract_markdown_toc(content) + elif ext in (".docx",): + entries = TOCExtractor._extract_docx_toc(file_path) + elif ext in (".html", ".htm"): + entries = TOCExtractor._extract_html_toc(content) + else: + return None + + if not entries: + return None + + # Flatten nested children for total count check + total = TOCExtractor._count_entries(entries) + if total < _MIN_TOC_ENTRIES: + return None + + # Resolve character positions in extracted text + entries = TOCExtractor._resolve_char_positions(entries, content) + return entries + + @staticmethod + def _extract_pdf_toc(file_path: str) -> Optional[List[TOCEntry]]: + """Extract TOC from PDF bookmarks/outline using pypdf. + + Recursively parses the nested bookmark structure from + ``PdfReader.outline``. + + Args: + file_path: Path to the PDF file. + + Returns: + List of TOCEntry with page_start populated, or None on failure. + """ + try: + from pypdf import PdfReader + + reader = PdfReader(file_path) + outline = reader.outline + if not outline: + return None + + entries: List[TOCEntry] = [] + TOCExtractor._parse_pdf_outline(reader, outline, entries, level=1) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _parse_pdf_outline( + reader: "PdfReader", + outline_items: List, + entries: List[TOCEntry], + level: int, + ) -> None: + """Recursively parse pypdf outline items into TOCEntry list. + + Args: + reader: PdfReader instance for page number resolution. + outline_items: Nested list of outline Destination objects. + entries: Accumulator list to append entries to. + level: Current nesting level (1=top-level section). 
+ """ + for item in outline_items: + if isinstance(item, list): + # Nested list means sub-bookmarks — attach to last entry + if entries: + sub_entries: List[TOCEntry] = [] + TOCExtractor._parse_pdf_outline( + reader, item, sub_entries, level=level + 1, + ) + entries[-1].children.extend(sub_entries) + else: + TOCExtractor._parse_pdf_outline( + reader, item, entries, level=level, + ) + else: + # Single bookmark destination + try: + title = item.title if hasattr(item, "title") else str(item) + page_num = None + try: + page_num = reader.get_destination_page_number(item) + except Exception: + pass + entry = TOCEntry( + title=title.strip(), + level=level, + char_start=0, + page_start=page_num, + ) + entries.append(entry) + except Exception: + continue + + @staticmethod + def _extract_markdown_toc(content: str) -> Optional[List[TOCEntry]]: + """Extract TOC from Markdown heading syntax (# / ## / ###). + + Matches ATX-style headings: lines beginning with 1-6 '#' characters + followed by whitespace and the heading text. + + Args: + content: Markdown text content. + + Returns: + List of TOCEntry with level derived from '#' count, or None. + """ + try: + pattern = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) + matches = pattern.findall(content) + if not matches: + return None + + entries: List[TOCEntry] = [] + for hashes, title in matches: + entries.append(TOCEntry( + title=title.strip(), + level=len(hashes), + char_start=0, + )) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: + """Extract TOC from DOCX heading styles using python-docx. + + Reads paragraphs with heading style names (English ``Heading``, + Chinese ``\u6807\u9898``, etc.), extracting the heading level from the style + name suffix (e.g., ``Heading 1`` -> level 1). + + Args: + file_path: Path to the DOCX file. + + Returns: + List of TOCEntry with level from heading style, or None. 
+ """ + try: + import docx + + doc = docx.Document(file_path) + entries: List[TOCEntry] = [] + for para in doc.paragraphs: + style_name = para.style.name or "" + # Match heading styles across locales ("Heading 1", "标题 1", etc.) + matched_prefix = "" + for prefix in _HEADING_STYLE_PREFIXES: + if style_name.startswith(prefix): + matched_prefix = prefix + break + if not matched_prefix: + continue + level_str = style_name[len(matched_prefix):].strip() + try: + level = int(level_str) if level_str else 1 + except ValueError: + level = 1 + title = para.text.strip() + if title: + entries.append(TOCEntry( + title=title, + level=level, + char_start=0, + )) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _extract_html_toc(content: str) -> Optional[List[TOCEntry]]: + """Extract TOC from HTML heading tags (

<h1> through <h6>

). + + Uses regex to match heading tags and strips inner HTML tags + from the title text. + + Args: + content: HTML text content. + + Returns: + List of TOCEntry with level from tag number, or None. + """ + try: + pattern = re.compile( + r"<h([1-6])[^>]*>(.*?)</h\1>", + re.IGNORECASE | re.DOTALL, + ) + matches = pattern.findall(content) + if not matches: + return None + + entries: List[TOCEntry] = [] + for level_str, raw_title in matches: + # Strip HTML tags from title + title = re.sub(r"<[^>]+>", "", raw_title).strip() + if title: + entries.append(TOCEntry( + title=title, + level=int(level_str), + char_start=0, + )) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _resolve_char_positions( + entries: List[TOCEntry], + content: str, + ) -> List[TOCEntry]: + """Resolve character start/end positions for TOC entries in content. + + Searches for each entry's title in the content text using + case-insensitive matching, progressing forward to avoid duplicate + matches. Sets char_end to the start of the next entry (or + len(content) for the last entry). + + Also recurses into children to resolve their positions. + + Args: + entries: Flat list of TOCEntry to resolve. + content: Full extracted text to search within. + + Returns: + The same list with char_start and char_end populated. + """ + if not content or not entries: + return entries + + content_lower = content.lower() + search_from = 0 + + # Collect all entries in document order (top-level + children) + flat: List[TOCEntry] = [] + TOCExtractor._flatten_entries(entries, flat) + + # Pass 1: resolve char_start for each entry + for entry in flat: + title_lower = entry.title.lower().strip() + if not title_lower: + entry.char_start = search_from + continue + # Normalise whitespace for fuzzy matching (PDF extracts may + # insert extra spaces inside headings).
+ title_normalised = re.sub(r"\s+", " ", title_lower) + pos = content_lower.find(title_normalised, search_from) + if pos < 0: + # Retry with the original (un-normalised) title + pos = content_lower.find(title_lower, search_from) + if pos >= 0: + entry.char_start = pos + search_from = pos + len(title_lower) + else: + # Title not found after search_from; try from beginning + pos = content_lower.find(title_normalised) + if pos < 0: + pos = content_lower.find(title_lower) + if pos >= 0: + entry.char_start = pos + # Do NOT reset search_from to avoid breaking order + else: + # Last resort: place at current search frontier + entry.char_start = search_from + + # Pass 2: resolve char_end as start of next entry (or len(content)) + for i in range(len(flat) - 1): + flat[i].char_end = flat[i + 1].char_start + if flat: + flat[-1].char_end = len(content) + + return entries + + @staticmethod + def _flatten_entries( + entries: List[TOCEntry], + flat: List[TOCEntry], + ) -> None: + """Flatten nested TOCEntry tree into document-order list. + + Args: + entries: Nested entry list. + flat: Accumulator for flattened output. + """ + for entry in entries: + flat.append(entry) + if entry.children: + TOCExtractor._flatten_entries(entry.children, flat) + + @staticmethod + def _count_entries(entries: List[TOCEntry]) -> int: + """Count total entries including nested children. + + Args: + entries: Nested entry list. + + Returns: + Total number of entries in the tree. + """ + count = 0 + for entry in entries: + count += 1 + if entry.children: + count += TOCExtractor._count_entries(entry.children) + return count + @staticmethod + def _count_entries(entries: List[TOCEntry]) -> int: + """Count total entries including nested children. + + Args: + entries: Nested entry list. + + Returns: + Total number of entries in the tree. 
+ """ + count = 0 + for entry in entries: + count += 1 + if entry.children: + count += TOCExtractor._count_entries(entry.children) + return count diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 8bd2983..abf5459 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -19,7 +19,18 @@ from sirchmunk.utils.file_utils import get_fast_hash # File-size threshold: skip tree indexing for small files -_TREE_MIN_CHARS = 50_000 # 50 K characters +_TREE_MIN_CHARS = 20_000 # 20 K characters (lowered from 50K for broader coverage) + +# Adaptive depth thresholds: (min_chars, max_depth) — evaluated top-down; +# **must** be sorted by min_chars descending so the first match wins. +_TREE_ADAPTIVE_DEPTH_THRESHOLDS: tuple = ( + (100_000, 4), + (50_000, 3), + (20_000, 2), +) + +# Summary snippet length extracted from section content (chars) +_TOC_NODE_SUMMARY_MAX_CHARS = 300 # Adaptive preview window for LLM structure analysis _TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) @@ -153,9 +164,14 @@ async def build_tree( max_depth: int = 4, force_rebuild: bool = False, total_pages: Optional[int] = None, + toc_entries: Optional[List[Any]] = None, ) -> Optional[DocumentTree]: """Build a tree index for a document. + When *toc_entries* are provided (from TOCExtractor), uses the + TOC-accelerated path that skips recursive LLM analysis and builds + the tree directly from extracted headings. + Returns None when the document is too small or unstructured. 
""" file_hash = get_fast_hash(file_path) @@ -175,12 +191,34 @@ async def build_tree( if ext not in _TREE_EXTENSIONS: return None + # Use adaptive depth based on document length + effective_depth = self._compute_adaptive_depth(len(content)) + await self._log.info( f"[TreeIndexer] Building tree for {Path(file_path).name} " - f"({len(content)} chars, depth={max_depth})" + f"({len(content)} chars, depth={effective_depth})" ) - root = await self._build_node(content, level=0, max_depth=max_depth) + # TOC-accelerated path: skip recursive LLM analysis + if toc_entries: + root = await self._build_tree_from_toc(toc_entries, content) + if root is not None: + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree from TOC: {self._count_nodes(root)} nodes" + ) + return tree + + # Fallback: existing recursive LLM path (with adaptive depth) + root = await self._build_node(content, level=0, max_depth=effective_depth) if root is None: return None @@ -258,6 +296,116 @@ def has_tree(self, file_path: str) -> bool: # Internals # # ------------------------------------------------------------------ # + async def _build_tree_from_toc( + self, + toc_entries: List[Any], + content: str, + ) -> Optional[TreeNode]: + """Build tree directly from extracted TOC entries, avoiding recursive LLM. + + Each TOCEntry becomes a TreeNode with char_range from the entry positions. + Only the root summary requires an LLM call (_synthesize_root_summary). + + Args: + toc_entries: List of TOCEntry from toc_extractor. + content: Full extracted text of the document. + + Returns: + Root TreeNode, or None if no children could be created. 
+ """ + seen_ids: set = set() + children = self._toc_entries_to_nodes( + toc_entries, content, len(content), seen_ids, fallback_level=1, + ) + + if not children: + return None + + root_summary = await self._synthesize_root_summary(children) + return TreeNode( + node_id=self._unique_node_id(0, seen_ids), + title="Document", + summary=root_summary, + char_range=(0, len(content)), + level=0, + children=children, + ) + + @staticmethod + def _toc_entries_to_nodes( + entries: List[Any], + content: str, + parent_end: int, + seen_ids: set, + fallback_level: int, + ) -> List["TreeNode"]: + """Recursively convert TOCEntry trees into TreeNode trees. + + Handles arbitrary nesting depth and guards against invalid + char_start / char_end values. + """ + nodes: List[TreeNode] = [] + content_len = len(content) + for entry in entries: + start = max(0, min(entry.char_start, content_len)) + end = entry.char_end if entry.char_end and entry.char_end > start else parent_end + end = min(end, content_len) + + section_text = content[start:min(start + _TOC_NODE_SUMMARY_MAX_CHARS, end)] + nid = DocumentTreeIndexer._unique_node_id(start, seen_ids) + level = entry.level if entry.level > 0 else fallback_level + + child_nodes: List[TreeNode] = [] + if entry.children: + child_nodes = DocumentTreeIndexer._toc_entries_to_nodes( + entry.children, content, end, seen_ids, + fallback_level=level + 1, + ) + + node = TreeNode( + node_id=nid, + title=entry.title, + summary=section_text.strip(), + char_range=(start, end), + level=level, + children=child_nodes, + ) + nodes.append(node) + return nodes + + @staticmethod + def _unique_node_id(start: int, seen_ids: set) -> str: + """Generate a unique node_id based on char offset, appending a + disambiguator when collisions occur.""" + base = f"N{start:06d}" + if base not in seen_ids: + seen_ids.add(base) + return base + suffix = 1 + while f"{base}_{suffix}" in seen_ids: + suffix += 1 + nid = f"{base}_{suffix}" + seen_ids.add(nid) + return nid + + @staticmethod 
+ def _compute_adaptive_depth(content_length: int) -> int: + """Compute max tree depth based on document length. + + Longer documents get deeper trees for finer-grained navigation. + Uses _TREE_ADAPTIVE_DEPTH_THRESHOLDS for threshold-based selection. + + Args: + content_length: Character count of the document. + + Returns: + Maximum tree depth (2-4). + """ + for threshold, depth in _TREE_ADAPTIVE_DEPTH_THRESHOLDS: + if content_length >= threshold: + return depth + return 2 # minimum depth + async def _build_node( self, text: str, level: int, max_depth: int, offset: int = 0, diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 5128702..a9323fa 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -138,6 +138,38 @@ class CompileArtifacts: manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} +class _TreeNavCache: + """Per-search-session cache for tree navigation results. + + Avoids duplicate LLM navigation calls for the same file+query pair. + Created at the start of each ``_search_fast()`` invocation and reset + per search session. 
+ """ + + __slots__ = ("_store",) + + def __init__(self) -> None: + self._store: Dict[str, Optional[List[Any]]] = {} + + @staticmethod + def _key(file_path: str, query: str) -> str: + import hashlib + return hashlib.md5(f"{file_path}:{query}".encode()).hexdigest() + + def get(self, file_path: str, query: str) -> Optional[List[Any]]: + """Retrieve cached navigation leaves for a file+query pair.""" + key = self._key(file_path, query) + return self._store.get(key) + + def has(self, file_path: str, query: str) -> bool: + """Check whether a cached result exists.""" + return self._key(file_path, query) in self._store + + def put(self, file_path: str, query: str, leaves: Optional[List[Any]]) -> None: + """Store navigation leaves for a file+query pair.""" + self._store[self._key(file_path, query)] = leaves + + class AgenticSearch(BaseSearch): def __init__( @@ -1518,6 +1550,29 @@ async def _search_deep( f"dir_scan_files={len(dir_scan_files)}" ) + # --- Phase 2.5: Parallel tree pre-navigation for top tree hits --- + _pre_nav_evidence: Dict[str, str] = {} + if tree_hits: + _nav_fps = [fp for fp in tree_hits[:self._DEEP_PRE_NAV_MAX_FILES]] + if _nav_fps: + _nav_results = await asyncio.gather( + *[self._tree_guided_sample( + fp, query, max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) for fp in _nav_fps], + return_exceptions=True, + ) + for fp, nav_res in zip(_nav_fps, _nav_results): + if isinstance(nav_res, Exception): + await self._logger.warning( + f"[Phase 2.5] Tree pre-nav failed for {Path(fp).name}: {nav_res}" + ) + elif isinstance(nav_res, str) and nav_res: + _pre_nav_evidence[fp] = nav_res + if _pre_nav_evidence: + await self._logger.info( + f"[Phase 2.5] Pre-navigated {len(_pre_nav_evidence)} tree files" + ) + # ============================================================== # Phase 3: Merge file paths + build KnowledgeCluster # P1 tree hits get highest priority; P2 soft-hit files next @@ -1547,6 +1602,17 @@ async def _search_deep( # 
============================================================== graph_ctx = "" if cluster: + # Merge pre-navigated tree evidence into cluster content + if _pre_nav_evidence and cluster.content: + pre_nav_parts = [] + for fp, ev in _pre_nav_evidence.items(): + pre_nav_parts.append(f"[Tree evidence: {Path(fp).name}]\n{ev}") + if pre_nav_parts: + pre_nav_ctx = "\n\n".join(pre_nav_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" + graph_ctx = await self._gather_graph_context(cluster) if graph_ctx and cluster.content: if isinstance(cluster.content, list): @@ -1946,6 +2012,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max chars per tree section.""" _TREE_SAMPLE_RGA_SUPPLEMENT = True """Whether to append rga evidence after tree sections as supplementary context.""" + _TREE_ROOT_HINTS_MAX_FILES = 10 + """Maximum number of tree roots to include in FAST Step 1 hints.""" + _DEEP_PRE_NAV_MAX_FILES = 3 + """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" @@ -1982,6 +2052,9 @@ async def _search_fast( context = SearchContext() await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + # Reset per-session tree navigation cache + self._tree_nav_cache = _TreeNavCache() + # --- Adaptive compile artifact detection (one-shot, zero LLM) --- artifacts = self._detect_compile_artifacts() if artifacts.catalog or artifacts.tree_available_paths: @@ -2013,6 +2086,11 @@ async def _search_fast( catalog_routed_files: List[str] = [] catalog_confidence: str = "low" + # Build tree root hints for enhanced query analysis + tree_hints = "" + if artifacts and artifacts.tree_available_paths: + tree_hints = self._build_tree_root_hints(artifacts) + if catalog: listing = self._build_enriched_catalog_listing(catalog) prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( @@ 
-2021,6 +2099,10 @@ async def _search_fast( else: prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + # Append tree structure hints to the prompt when available + if tree_hints: + prompt = prompt + tree_hints + resp = await self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, @@ -3060,6 +3142,35 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: manifest_map=manifest_map, ) + def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: + """Build tree root summary hints for FAST Step 1 query analysis. + + Loads root summaries from cached trees and formats them as context + for the LLM to understand document-level structure. + + Args: + artifacts: Compile artifact context with tree metadata. + + Returns: + Formatted hint string, or empty string when no trees are available. + """ + if not artifacts.tree_available_paths: + return "" + indexer = artifacts.tree_indexer + if indexer is None: + return "" + hints: List[str] = [] + for i, fp in enumerate(sorted(artifacts.tree_available_paths)): + if i >= self._TREE_ROOT_HINTS_MAX_FILES: + break + tree = indexer.load_tree(fp) + if tree and tree.root and tree.root.summary: + name = Path(fp).name + hints.append(f"[{i}] {name}: {tree.root.summary[:150]}") + if not hints: + return "" + return "\nDocument structure hints:\n" + "\n".join(hints) + "\n" + @staticmethod def _tokenize_for_matching(text: str) -> Set[str]: """Tokenize text into meaningful units for keyword matching. 
From 29c0909b166ca1ee08a8464f4e50a4aad8ff43ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 21:18:22 +0800 Subject: [PATCH 11/70] enhance compile for excel and add embedding fallback for rga keywords retrieval --- src/sirchmunk/learnings/compiler.py | 265 +++++++++++++++++++++-- src/sirchmunk/learnings/summary_index.py | 255 ++++++++++++++++++++++ src/sirchmunk/search.py | 77 +++++++ 3 files changed, 578 insertions(+), 19 deletions(-) create mode 100644 src/sirchmunk/learnings/summary_index.py diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index fac9b79..2f8983a 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -41,7 +41,7 @@ _MERGE_SIMILARITY_THRESHOLD = 0.75 # Max chars for manifest-persisted document summary (used in Phase 2 & catalog) -_MANIFEST_SUMMARY_MAX_LEN = 250 +_MANIFEST_SUMMARY_MAX_LEN = 500 # Preview window for direct LLM summarisation (no tree), ~4K tokens _SUMMARY_PREVIEW_CHARS = 16_000 @@ -50,6 +50,13 @@ _SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs _SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section +# Excel table-level adaptive sampling constants +_XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets +_XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet +_XLSX_MAX_ROWS_PER_SHEET = 50 # Maximum sampled rows per sheet +_XLSX_MAX_SHEETS = 10 # Maximum number of sheets to process +_XLSX_MAX_COLS_DISPLAY = 20 # Maximum columns to display per sheet + # --------------------------------------------------------------------------- # Data structures @@ -67,6 +74,7 @@ class FileManifestEntry: summary: str = "" # 新增:存储编译期生成的文档摘要 has_explicit_toc: bool = False # Whether a native TOC was extracted from the file tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest 
exists def to_dict(self) -> Dict[str, Any]: return { @@ -78,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: "summary": self.summary, "has_explicit_toc": self.has_explicit_toc, "tree_node_count": self.tree_node_count, + "has_xlsx_digest": self.has_xlsx_digest, } @classmethod @@ -91,6 +100,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": summary=data.get("summary", ""), has_explicit_toc=data.get("has_explicit_toc", False), tree_node_count=data.get("tree_node_count", 0), + has_xlsx_digest=data.get("has_xlsx_digest", False), ) @@ -155,6 +165,7 @@ class FileCompileResult: error: Optional[str] = None has_explicit_toc: bool = False # Whether TOC was extracted from native structure tree_node_count: int = 0 # Number of nodes in the tree index + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists @dataclass @@ -389,6 +400,7 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", has_explicit_toc=result.has_explicit_toc, tree_node_count=result.tree_node_count, + has_xlsx_digest=result.has_xlsx_digest, ) # Phase 3: aggregate results into knowledge network @@ -412,6 +424,9 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: # Generate document catalog for search-time routing self._build_document_catalog(manifest) + # Phase: Build summary index for embedding+BM25 fallback (optional, non-blocking) + await self._build_summary_index(manifest) + report.elapsed_seconds = time.monotonic() - t0 await self._log.info( f"[Compile] Done in {report.elapsed_seconds:.1f}s — " @@ -556,8 +571,16 @@ async def _compile_single_file( result.tree_node_count = self._count_tree_nodes(result.tree) # Enrich content with structural metadata for non-text types - metadata_prefix = self._extract_structured_metadata(entry.path, content) - enriched_content = metadata_prefix + content if metadata_prefix else content + ext = Path(entry.path).suffix.lower() + 
evidence_digest = "" + + if ext in (".xlsx", ".xls"): + # Excel: use adaptive sampling for both metadata and evidence + metadata_prefix, evidence_digest = self._extract_xlsx_sampling(entry.path) + enriched_content = metadata_prefix + content if metadata_prefix else content + else: + metadata_prefix = self._extract_structured_metadata(entry.path, content) + enriched_content = metadata_prefix + content if metadata_prefix else content result.summary = await self._extract_summary( entry.path, enriched_content, result.tree, @@ -565,6 +588,19 @@ async def _compile_single_file( result.topics = await self._extract_topics(result.summary) result.evidence = self._build_evidence(entry, content, result) + # Persist Excel evidence digest for search-time consumption + if evidence_digest.strip(): + try: + digest_dir = self._compile_dir / "xlsx_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + digest_path = digest_dir / f"{file_hash}.txt" + digest_path.write_text(evidence_digest, encoding="utf-8") + result.has_xlsx_digest = True + except Exception: + pass + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -637,38 +673,159 @@ def _extract_structured_metadata(file_path: str, content: str) -> str: ext = Path(file_path).suffix.lower() if ext == ".xlsx": - return KnowledgeCompiler._extract_xlsx_metadata(file_path) + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata if ext == ".pptx": return KnowledgeCompiler._extract_pptx_metadata(file_path) return "" @staticmethod - def _extract_xlsx_metadata(file_path: str) -> str: - """Extract structural metadata from Excel files. + def _compute_xlsx_sample_rows(total_rows: int, num_sheets: int, sheet_rows: int) -> int: + """Compute adaptive sample row count per sheet. 
+ + Strategy: + - Divides _XLSX_TOTAL_ROW_BUDGET equally across sheets + - Small sheets (<=budget) are fully sampled + - Large sheets are capped at budget + - Result clamped to [_XLSX_MIN_ROWS_PER_SHEET, _XLSX_MAX_ROWS_PER_SHEET] + """ + budget_per_sheet = max(1, _XLSX_TOTAL_ROW_BUDGET // max(1, num_sheets)) + n = min(sheet_rows, budget_per_sheet) + return max(_XLSX_MIN_ROWS_PER_SHEET, min(_XLSX_MAX_ROWS_PER_SHEET, n)) - Reads sheet names, row counts, and column headers (first row) to - provide the LLM with a structural overview of the workbook. - Caps at 10 sheets and 15 columns per sheet for bounded output. + @staticmethod + def _extract_xlsx_sampling(file_path: str) -> Tuple[str, str]: + """Extract structural metadata AND sampled content from Excel workbook. + + Performs table-level intelligent sampling with adaptive row counts + based on workbook size and sheet complexity. + + Returns: + (metadata_prefix, evidence_digest) + - metadata_prefix: injected into summary generation context + - evidence_digest: structured text usable directly as search evidence """ try: import openpyxl wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) - lines: List[str] = ["[Excel Workbook Structure]"] - for sheet_name in wb.sheetnames[:10]: # Cap at 10 sheets + + sheet_names = wb.sheetnames[:_XLSX_MAX_SHEETS] + num_sheets = len(sheet_names) + + # Phase 1: Collect sheet statistics + sheet_stats: List[Dict[str, Any]] = [] + for sheet_name in sheet_names: ws = wb[sheet_name] - # Extract column headers (first row) + row_count = ws.max_row or 0 + col_count = ws.max_column or 0 + # Read headers (first row) headers: List[str] = [] - for cell in ws.iter_rows(min_row=1, max_row=1, values_only=True): - headers = [str(h) for h in cell if h is not None] + for row in ws.iter_rows(min_row=1, max_row=1, values_only=True): + headers = [str(h) for h in row if h is not None] break - row_count = ws.max_row or 0 - header_str = ", ".join(headers[:15]) if headers else "no headers" - 
lines.append(f"- Sheet '{sheet_name}': {row_count} rows, columns: [{header_str}]") + sheet_stats.append({ + "name": sheet_name, + "rows": row_count, + "cols": col_count, + "headers": headers[:_XLSX_MAX_COLS_DISPLAY], + "ws": ws, + }) + + # Phase 2: Calculate total rows for adaptive sampling + total_rows = sum(s["rows"] for s in sheet_stats) + + meta_lines: List[str] = ["[Excel Workbook Structure]"] + evidence_lines: List[str] = [] + + for stat in sheet_stats: + ws = stat["ws"] + sheet_name = stat["name"] + row_count = stat["rows"] + col_count = stat["cols"] + headers = stat["headers"] + header_str = ", ".join(headers) if headers else "no headers" + + # Metadata line + meta_lines.append( + f"- Sheet '{sheet_name}': {row_count} rows, {col_count} columns, " + f"headers: [{header_str}]" + ) + + # Adaptive sampling + sample_n = KnowledgeCompiler._compute_xlsx_sample_rows( + total_rows, num_sheets, row_count + ) + + evidence_lines.append( + f"[Sheet '{sheet_name}' ({row_count} rows, {col_count} columns)]" + ) + evidence_lines.append(f"Columns: {header_str}") + + # Sample rows + if row_count <= sample_n: + evidence_lines.append(f"(Full content - {row_count} rows)") + else: + evidence_lines.append(f"Sample rows (top {sample_n} of {row_count}):") + + # Build table header + display_headers = headers[:_XLSX_MAX_COLS_DISPLAY] + if display_headers: + evidence_lines.append("| " + " | ".join(display_headers) + " |") + evidence_lines.append("|" + "|".join(["---"] * len(display_headers)) + "|") + + # Read sample rows (skip header row) + numeric_cols: Dict[int, List[float]] = {} # col_index -> numeric values + sampled = 0 + for row in ws.iter_rows( + min_row=2, + max_row=min(row_count, sample_n + 1), + values_only=True, + ): + cells: List[str] = [] + for ci, cell_val in enumerate(row): + if ci >= _XLSX_MAX_COLS_DISPLAY: + break + str_val = str(cell_val) if cell_val is not None else "" + cells.append(str_val[:50]) # truncate long cell values + # Track numeric values for statistics + 
if isinstance(cell_val, (int, float)) and cell_val == cell_val: + numeric_cols.setdefault(ci, []).append(float(cell_val)) + if cells: + evidence_lines.append("| " + " | ".join(cells) + " |") + sampled += 1 + + # Statistics for numeric columns + stat_parts: List[str] = [] + for ci, values in numeric_cols.items(): + if len(values) >= 2 and ci < len(display_headers): + col_name = display_headers[ci] + stat_parts.append( + f"{col_name} range [{min(values):.4g}-{max(values):.4g}]" + ) + if stat_parts: + evidence_lines.append(f"Statistics: {', '.join(stat_parts[:5])}") + + evidence_lines.append("") # blank line between sheets + wb.close() - return "\n".join(lines) + "\n\n" + + metadata = "\n".join(meta_lines) + "\n\n" + evidence = "\n".join(evidence_lines) + return metadata, evidence + except Exception: - return "" + return "", "" + + @staticmethod + def _extract_xlsx_metadata(file_path: str) -> str: + """Extract structural metadata from Excel files (legacy wrapper). + + Delegates to _extract_xlsx_sampling and returns only the metadata prefix + for backward compatibility. + """ + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata @staticmethod def _extract_pptx_metadata(file_path: str) -> str: @@ -983,6 +1140,76 @@ def _count(node: Any) -> int: return _count(tree.root) + # ------------------------------------------------------------------ # + # Summary index for embedding + BM25 fallback # + # ------------------------------------------------------------------ # + + async def _build_summary_index(self, manifest: CompileManifest) -> None: + """Build summary embedding + BM25 index for fallback search. 
+ + Creates a lightweight index mapping each compiled file to: + - Its summary text + - Pre-computed embedding vector (384-dim, if EmbeddingUtil available) + - Tokenized summary with term frequencies (via TokenizerUtil) + + The index is saved to .cache/compile/summary_index.json and consumed + by search.py as a last-resort fallback when rga keyword search fails. + + Skips gracefully if dependencies (EmbeddingUtil/TokenizerUtil) are unavailable. + """ + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + from sirchmunk.learnings.summary_index import CompileSummaryIndex, SummaryIndexEntry + + entries: List[SummaryIndexEntry] = [] + summaries: List[str] = [] + + for file_path, entry in manifest.files.items(): + if entry.summary: + entries.append(SummaryIndexEntry( + file_path=file_path, + summary=entry.summary, + )) + summaries.append(entry.summary) + + if not entries: + return + + # Tokenize summaries + compute TF (always available) + tokenizer = TokenizerUtil() + for idx, entry in enumerate(entries): + tokens = tokenizer.segment(entry.summary) + entry.tokens = tokens + entry.token_freqs = {} + for t in tokens: + entry.token_freqs[t] = entry.token_freqs.get(t, 0) + 1 + + # Compute embeddings (optional — requires EmbeddingUtil) + try: + from sirchmunk.utils.embedding_util import EmbeddingUtil + embedding_util = EmbeddingUtil() + embedding_util.start_loading() + # Wait up to 60 seconds for model load + await embedding_util._ensure_model_async(timeout=60) + + if embedding_util.is_ready(): + embeddings = await embedding_util.embed(summaries) + for i, emb in enumerate(embeddings): + entries[i].embedding = emb + await self._log.info( + f"Summary index: computed embeddings for {len(entries)} entries" + ) + except Exception as emb_exc: + await self._log.warning( + f"Summary index: embedding computation skipped: {emb_exc}" + ) + + index = CompileSummaryIndex(entries) + index.save(self._compile_dir / "summary_index.json") + + except Exception as exc: + await 
self._log.warning(f"Failed to build summary index: {exc}") + # ------------------------------------------------------------------ # # Manifest I/O # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/summary_index.py b/src/sirchmunk/learnings/summary_index.py new file mode 100644 index 0000000..7ec355a --- /dev/null +++ b/src/sirchmunk/learnings/summary_index.py @@ -0,0 +1,255 @@ +"""Compile-time summary index for embedding + BM25 fallback retrieval. + +This module provides a lightweight, file-level index that combines: +- Semantic similarity via pre-computed embeddings (384-dim MiniLM) +- Lexical matching via BM25 scoring (TokenizerUtil segmentation) + +Used ONLY as a fallback when rga keyword search returns zero results. +""" + +import json +import math +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +@dataclass +class SummaryIndexEntry: + """Single file entry in the summary index.""" + file_path: str + summary: str + embedding: Optional[List[float]] = None # 384-dim, pre-normalized + tokens: Optional[List[str]] = None # TokenizerUtil.segment() output + token_freqs: Optional[Dict[str, int]] = None # pre-computed term frequencies + + +class CompileSummaryIndex: + """Pre-computed summary index for hybrid embedding + BM25 fallback search. + + This index is built at compile time and loaded at search time. + It provides a fallback retrieval mechanism when rga keyword search + returns zero results, combining semantic similarity (embedding cosine) + with lexical matching (BM25). + + The fusion algorithm uses Sigmoid Z-Score normalization: + 1. Compute raw scores from both channels + 2. Z-Score normalize each channel independently + 3. Weighted combination: alpha * z_embedding + (1-alpha) * z_bm25 + 4. 
Sigmoid activation for final score + """ + + # BM25 parameters (Okapi BM25 standard defaults) + _BM25_K1: float = 1.5 + _BM25_B: float = 0.75 + + # Fusion parameters + _DEFAULT_ALPHA: float = 0.5 # embedding weight; (1-alpha) = BM25 weight + + # Z-Score fallback for missing channel + _MISSING_CHANNEL_Z: float = -3.0 # ~0.1 percentile + + def __init__(self, entries: List[SummaryIndexEntry]) -> None: + self._entries = entries + self._num_docs = len(entries) + self._avg_doc_len = self._compute_avg_doc_len() + self._doc_freqs: Dict[str, int] = self._compute_doc_freqs() + + def _compute_avg_doc_len(self) -> float: + """Compute average document length (in tokens) across all entries.""" + lengths = [len(e.tokens or []) for e in self._entries] + return sum(lengths) / max(1, len(lengths)) + + def _compute_doc_freqs(self) -> Dict[str, int]: + """Compute document frequency for each unique token.""" + df: Dict[str, int] = {} + for entry in self._entries: + if entry.token_freqs: + for token in entry.token_freqs: + df[token] = df.get(token, 0) + 1 + return df + + @classmethod + def load(cls, index_path: Path) -> Optional["CompileSummaryIndex"]: + """Load index from JSON file. 
Returns None on failure.""" + try: + if not index_path.exists(): + return None + data = json.loads(index_path.read_text(encoding="utf-8")) + entries = [] + for item in data.get("entries", []): + entries.append(SummaryIndexEntry( + file_path=item["file_path"], + summary=item.get("summary", ""), + embedding=item.get("embedding"), + tokens=item.get("tokens"), + token_freqs=item.get("token_freqs"), + )) + if not entries: + return None + return cls(entries) + except Exception as exc: + logger.warning("Failed to load summary index from %s: %s", index_path, exc) + return None + + def save(self, index_path: Path) -> None: + """Persist index to JSON file.""" + index_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "version": 1, + "num_entries": len(self._entries), + "entries": [ + { + "file_path": e.file_path, + "summary": e.summary, + "embedding": e.embedding, + "tokens": e.tokens, + "token_freqs": e.token_freqs, + } + for e in self._entries + ], + } + index_path.write_text( + json.dumps(data, ensure_ascii=False), + encoding="utf-8", + ) + logger.info("Summary index saved: %d entries -> %s", len(self._entries), index_path) + + def search( + self, + query_embedding: Optional[List[float]], + query_tokens: List[str], + top_k: int = 5, + alpha: float = _DEFAULT_ALPHA, + ) -> List[Tuple[str, float]]: + """Hybrid search combining embedding cosine similarity and BM25. + + Uses Sigmoid Z-Score fusion: + 1. Compute raw embedding cosine sim and BM25 score per document + 2. Z-Score normalize each channel + 3. Weighted linear combination + 4. Sigmoid activation + + Args: + query_embedding: 384-dim query vector (None to use BM25 only). + query_tokens: Tokenized query from TokenizerUtil.segment(). + top_k: Maximum number of results. + alpha: Embedding weight in [0, 1]. BM25 weight = 1 - alpha. + + Returns: + List of (file_path, fusion_score) sorted descending by score. 
+ """ + if not self._entries: + return [] + + # Compute raw scores + emb_scores: List[Optional[float]] = [] + bm25_scores: List[float] = [] + + has_embedding = query_embedding is not None + + for entry in self._entries: + # Embedding channel + if has_embedding and entry.embedding: + emb_scores.append(self._cosine_similarity(query_embedding, entry.embedding)) + else: + emb_scores.append(None) + + # BM25 channel + bm25_scores.append(self._bm25_score(query_tokens, entry)) + + # Z-Score normalization + z_emb = self._z_score_normalize(emb_scores) + z_bm25 = self._z_score_normalize(bm25_scores) + + # Sigmoid fusion + results: List[Tuple[str, float]] = [] + for i, entry in enumerate(self._entries): + z_e = z_emb[i] if z_emb[i] is not None else self._MISSING_CHANNEL_Z + z_b = z_bm25[i] if z_bm25[i] is not None else self._MISSING_CHANNEL_Z + + combined = alpha * z_e + (1.0 - alpha) * z_b + score = 1.0 / (1.0 + math.exp(-combined)) + results.append((entry.file_path, score)) + + # Sort descending and return top_k + results.sort(key=lambda x: x[1], reverse=True) + return results[:top_k] + + def _bm25_score(self, query_tokens: List[str], entry: SummaryIndexEntry) -> float: + """Compute BM25 score for a single document. 
+ + Uses standard Okapi BM25 formula: + score = sum over query terms: + IDF(t) * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / avgdl)) + """ + if not query_tokens or not entry.token_freqs: + return 0.0 + + dl = len(entry.tokens or []) + score = 0.0 + + for token in query_tokens: + tf = entry.token_freqs.get(token, 0) + if tf == 0: + continue + + # IDF: log((N - df + 0.5) / (df + 0.5) + 1) + df = self._doc_freqs.get(token, 0) + idf = math.log((self._num_docs - df + 0.5) / (df + 0.5) + 1.0) + + # TF component + tf_component = (tf * (self._BM25_K1 + 1.0)) / ( + tf + self._BM25_K1 * (1.0 - self._BM25_B + self._BM25_B * dl / max(1.0, self._avg_doc_len)) + ) + + score += idf * tf_component + + return score + + @staticmethod + def _cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors. + + When embeddings are pre-normalized (L2 norm = 1), this reduces + to a simple dot product. + """ + if len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + # Clamp to [-1, 1] for numerical safety + return max(-1.0, min(1.0, dot)) + + @staticmethod + def _z_score_normalize(scores: List[Optional[float]]) -> List[Optional[float]]: + """Z-Score normalize a list of scores, preserving None entries. + + None entries remain None (handled as _MISSING_CHANNEL_Z at fusion). 
+ """ + valid = [s for s in scores if s is not None] + if len(valid) < 2: + # Not enough data points for meaningful normalization + return scores + + mean = sum(valid) / len(valid) + variance = sum((s - mean) ** 2 for s in valid) / len(valid) + std = math.sqrt(variance) if variance > 0 else 1.0 + + if std < 1e-9: + # All scores identical — return zeros + return [0.0 if s is not None else None for s in scores] + + return [(s - mean) / std if s is not None else None for s in scores] + + @property + def num_entries(self) -> int: + """Number of indexed documents.""" + return self._num_docs + + @property + def has_embeddings(self) -> bool: + """Whether any entry has a pre-computed embedding.""" + return any(e.embedding is not None for e in self._entries) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index a9323fa..9b7bf47 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -136,6 +136,7 @@ class CompileArtifacts: tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) tree_available_paths: Set[str] # file paths that have cached tree indices manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} + summary_index: Optional[Any] = None # CompileSummaryIndex (lazy-loaded) class _TreeNavCache: @@ -1998,6 +1999,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max tree JSON files to parse during artifact detection.""" _CATALOG_LISTING_MAX_ENTRIES = 20 """Max catalog entries in the enriched listing for Step 1.""" + _ENABLE_EMBEDDING_FALLBACK: bool = True + """Enable embedding + BM25 hybrid fallback when rga returns zero results.""" _CATALOG_KEYWORD_MIN_LEN = 2 """Minimum character length for a catalog keyword token.""" _CATALOG_KEYWORD_MAX_LEN = 20 @@ -2371,6 +2374,22 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None + # 0. 
Excel digest priority (pre-compiled evidence) + if artifacts and artifacts.manifest_map: + manifest_entry = artifacts.manifest_map.get(fp) + if manifest_entry and getattr(manifest_entry, 'has_xlsx_digest', False): + digest_path = ( + self.work_path / ".cache" / "compile" / "xlsx_digests" + / f"{manifest_entry.file_hash}.txt" + ) + if digest_path.exists(): + try: + digest_content = digest_path.read_text(encoding="utf-8") + if digest_content.strip(): + ev = f"[{fn} - Pre-compiled Evidence]\n{digest_content}" + except Exception: + pass + # 1. Tree-guided sampling FIRST for tree-indexed files if ( artifacts @@ -2857,6 +2876,53 @@ async def _fast_find_best_file( await self._logger.warning( f"[FAST] filename search failed: {exc}" ) + + # Layer 4: Embedding + BM25 hybrid fallback + # Triggered ONLY when layers 1-3 all return empty results + if (not all_raw + and self._ENABLE_EMBEDDING_FALLBACK + and artifacts is not None + and artifacts.summary_index is not None): + try: + query_emb = None + query_tokens: List[str] = [] + + # Compute query embedding (if embedding client available) + if (self.embedding_client + and self.embedding_client.is_ready() + and artifacts.summary_index.has_embeddings): + query_emb = (await self.embedding_client.embed([query]))[0] + + # Tokenize query for BM25 + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if query_emb is not None or query_tokens: + results = artifacts.summary_index.search( + query_embedding=query_emb, + query_tokens=query_tokens, + top_k=top_k or 3, + ) + + for file_path, score in results: + if Path(file_path).exists(): + all_raw.append({ + "path": file_path, + "matches": [], + "weighted_score": score * self._WIKI_MAX_SCORE, + }) + + if all_raw: + await self._logger.info( + f"[FAST] Embedding+BM25 fallback found {len(all_raw)} candidates" + ) + except Exception as exc: + await self._logger.warning( + f"[FAST] Embedding+BM25 fallback failed: 
{exc}" + ) + + if not all_raw: return None merged = GrepRetriever.merge_results(all_raw, limit=20) @@ -3134,12 +3200,23 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: # Cache for future calls within this instance self._tree_paths_cache = tree_paths + # Load summary index for embedding fallback (optional) + summary_index = None + summary_index_path = self.work_path / ".cache" / "compile" / "summary_index.json" + if summary_index_path.exists(): + try: + from sirchmunk.learnings.summary_index import CompileSummaryIndex + summary_index = CompileSummaryIndex.load(summary_index_path) + except Exception: + pass + return CompileArtifacts( catalog=catalog, catalog_map=catalog_map, tree_indexer=indexer, tree_available_paths=tree_paths, manifest_map=manifest_map, + summary_index=summary_index, ) def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: From d1f1fd4a0c425689f8fb8cef0cc037a9d463f3bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 21:35:05 +0800 Subject: [PATCH 12/70] fix storage --- src/sirchmunk/storage/knowledge_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index e62c1cf..0f09071 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -124,7 +124,7 @@ def _load_from_parquet(self): # Detect parquet columns to handle schema evolution try: pq_cols = self.db.fetch_all( - f"SELECT column_name FROM parquet_schema('{self.parquet_file}')" + f"SELECT name FROM parquet_schema('{self.parquet_file}')" ) pq_col_names = {row[0] for row in pq_cols} except Exception: From caf8e052f112ed17d5963b7b969ef8b18a1e14c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 15:56:01 +0800 Subject: [PATCH 13/70] add financebench --- benchmarks/financebench/README.md | 103 +++++++ benchmarks/financebench/analyze_results.py | 
272 +++++++++++++++++ benchmarks/financebench/config.py | 99 +++++++ benchmarks/financebench/data_loader.py | 108 +++++++ benchmarks/financebench/evaluate.py | 323 +++++++++++++++++++++ benchmarks/financebench/run_benchmark.py | 239 +++++++++++++++ benchmarks/financebench/runner.py | 279 ++++++++++++++++++ 7 files changed, 1423 insertions(+) create mode 100644 benchmarks/financebench/README.md create mode 100644 benchmarks/financebench/analyze_results.py create mode 100644 benchmarks/financebench/config.py create mode 100644 benchmarks/financebench/data_loader.py create mode 100644 benchmarks/financebench/evaluate.py create mode 100644 benchmarks/financebench/run_benchmark.py create mode 100644 benchmarks/financebench/runner.py diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md new file mode 100644 index 0000000..23bd67d --- /dev/null +++ b/benchmarks/financebench/README.md @@ -0,0 +1,103 @@ +# FinanceBench Benchmark + +FinanceBench evaluation pipeline for **Sirchmunk AgenticSearch**. + +## Overview + +[FinanceBench](https://arxiv.org/abs/2311.11944) is an open-book financial QA benchmark +with **150 expert-annotated questions** across **40+ US public companies** (10-K/10-Q filings). + +### Evaluation Modes + +| Mode | Description | +|------|-------------| +| `singleDoc` | Each question searches only its target PDF (standard) | +| `sharedCorpus` | All questions search the full 41-PDF corpus | + +### Metrics + +- **3-Class Scoring**: Correct / Hallucination / Refusal (per FinanceBench paper) +- **EM / F1**: Exact Match and token-level F1 with financial value normalisation +- **Evidence Recall**: Retrieved pages vs gold evidence pages + +## Quick Start + +### 1. 
Setup + +```bash +cd benchmarks/financebench + +# Copy and edit the config file +cp .env.example .env.financebench +# Edit .env.financebench — set your LLM_API_KEY at minimum + +# Download FinanceBench data +# Place financebench_open_source.jsonl in ./data/ +# Place PDF corpus (41 files) in ./data/pdfs/ +``` + +### 2. Run + +```bash +# Run full benchmark (150 questions) +python run_benchmark.py + +# Run with custom config and question limit +python run_benchmark.py --env .env.financebench --limit 20 +``` + +### 3. Analyze + +```bash +# Analyze a completed run +python analyze_results.py output/results_YYYYMMDD_HHMMSS.jsonl + +# Show more error cases +python analyze_results.py output/results_*.jsonl --max-errors 50 +``` + +## Data Format + +The dataset file `financebench_open_source.jsonl` contains one JSON object per line: + +```json +{ + "financebench_id": "financebench_id_00001", + "question": "What is the FY2018 capital expenditure amount for 3M?", + "answer": "$1,577.00", + "doc_name": "3M_2018_10K", + "company": "3M", + "question_type": "fact-based-w-numerical-answer", + "question_reasoning": "retrieve", + "evidence": [{"evidence_text": "...", "evidence_page_num": 42}] +} +``` + +## File Structure + +``` +benchmarks/financebench/ +├── .env.example # Config template (copy to .env.financebench) +├── config.py # FinanceBenchConfig dataclass +├── data_loader.py # Dataset + PDF corpus loader +├── evaluate.py # EM/F1/3-class scoring + aggregation +├── runner.py # Async batch runner (AgenticSearch) +├── run_benchmark.py # CLI entry point +├── analyze_results.py # Post-hoc analysis tool +├── data/ +│ ├── financebench_open_source.jsonl +│ └── pdfs/ # 41 SEC-filing PDFs +├── output/ # Results + metrics (auto-created) +└── logs/ # Run logs (auto-created) +``` + +## SOTA Reference + +| System | Accuracy | Coverage | +|--------|----------|----------| +| Mafin 2.5 (SOTA) | 98.7% | 100% | +| Fintool | 98.0% | 66.7% | +| Quantly | 94.0% | 100% | +| GPT-4 (zero-shot) | 29.3% | 
def load_results(path: str) -> List[Dict[str, Any]]:
    """Load a JSONL results file into a list of dicts.

    Blank lines are ignored.  A line that is not valid JSON, or that decodes
    to something other than a JSON object, is skipped with a warning on
    stderr so that one bad record does not abort the whole analysis.

    Args:
        path: Path to a ``.jsonl`` file where each line is a JSON object.

    Returns:
        The decoded result dicts, in file order.

    Raises:
        SystemExit: With status 1 (after an error message on stderr) when
            *path* does not exist.
    """
    p = Path(path)
    if not p.exists():
        print(f"ERROR: file not found — {path}", file=sys.stderr)
        sys.exit(1)

    results: List[Dict[str, Any]] = []
    with open(p, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"WARNING: skipping malformed line {lineno}: {exc}", file=sys.stderr)
                continue
            if not isinstance(record, dict):
                # Callers use ``record.get(...)``; a bare list/number/string
                # would crash them later with a far less helpful message.
                print(
                    f"WARNING: skipping line {lineno}: expected a JSON object, "
                    f"got {type(record).__name__}",
                    file=sys.stderr,
                )
                continue
            results.append(record)
    return results
def print_error_cases(results: List[Dict[str, Any]], max_show: int = 20) -> None:
    """Print a detailed listing of error cases (hallucination + refusal).

    Every result whose ``classification`` is not ``"correct"`` counts as an
    error.  Fields that are missing or ``None`` (JSON ``null``) are rendered
    as empty text instead of raising, so partially filled records produced
    by failed runs can still be inspected.

    Args:
        results: List of per-question result dicts.
        max_show: Maximum number of error cases to display; values below
            zero are treated as zero.
    """

    def _clip(value: Any, limit: int) -> str:
        # Render *value* as text, cut to *limit* chars, marking truncation.
        text = "" if value is None else str(value)
        return text[:limit] + ("..." if len(text) > limit else "")

    errors = [r for r in results if r.get("classification") != "correct"]
    if not errors:
        print("\n=== Error Cases ===\n None — perfect score!")
        return

    print(f"\n=== Error Cases ({len(errors)} total, showing up to {max_show}) ===\n")

    shown = errors[: max(0, max_show)]
    for i, r in enumerate(shown, 1):
        fb_id = r.get("financebench_id", "?")
        cls = str(r.get("classification") or "?")
        company = r.get("company", "")
        em = r.get("em", False)
        try:
            f1 = float(r.get("f1") or 0.0)
        except (TypeError, ValueError):
            f1 = 0.0

        print(f" [{i:>2}] {fb_id} [{cls.upper()}]")
        print(f" Company: {company}")
        print(f" Question: {_clip(r.get('question'), 100)}")
        print(f" Predicted: {_clip(r.get('prediction'), 80)}")
        print(f" Gold: {_clip(r.get('gold_answer'), 80)}")
        print(f" EM={em} F1={f1:.3f}")
        if r.get("error"):
            print(f" Error: {str(r['error'])[:120]}")
        print()

    hidden = len(errors) - len(shown)
    if hidden > 0:
        print(f" ... and {hidden} more error(s) not shown.\n")
def main() -> None:
    """Parse CLI arguments and generate a full analysis report.

    Loads the results file named on the command line, aggregates it with
    ``compute_metrics`` and prints, in order: the overall summary, the
    per-question-type and per-reasoning breakdowns (when present in the
    metrics), the worst-performing companies, the error cases and the SOTA
    comparison table.

    Raises:
        SystemExit: With status 1 when no results could be loaded.
    """
    parser = argparse.ArgumentParser(
        description="Analyze FinanceBench benchmark results from a JSONL file",
    )
    parser.add_argument(
        "results_file",
        help="Path to the results JSONL file produced by run_benchmark.py",
    )
    parser.add_argument(
        "--max-errors",
        type=int,
        default=20,
        help="Maximum number of error cases to display (default: 20)",
    )
    parser.add_argument(
        "--top-companies",
        type=int,
        default=10,
        help="Number of worst-performing companies to show (default: 10)",
    )
    args = parser.parse_args()

    # Load
    results = load_results(args.results_file)
    if not results:
        print("ERROR: no results loaded.", file=sys.stderr)
        sys.exit(1)

    # Compute metrics
    metrics = compute_metrics(results)

    # --- Overall summary ---
    n = metrics.get("n", 0)
    acc = metrics.get("accuracy", 0)
    hallu = metrics.get("hallucination_rate", 0)
    refuse = metrics.get("refusal_rate", 0)
    avg_em = metrics.get("avg_em", 0)
    avg_f1 = metrics.get("avg_f1", 0)
    # The aggregated metrics expose page-level recall under the key
    # "evidence_recall"; it is absent when no result carried that field.
    ev_recall = metrics.get("evidence_recall")
    avg_latency = metrics.get("avg_latency", 0)

    print(f"\n{'=' * 60}")
    print(f" FinanceBench Analysis ({n} questions)")
    print(f"{'=' * 60}")
    print(f" Accuracy: {acc:.1f}%")
    print(f" Hallucination Rate: {hallu:.1f}%")
    print(f" Refusal Rate: {refuse:.1f}%")
    print(f" Avg EM: {avg_em:.3f}")
    print(f" Avg F1: {avg_f1:.3f}")
    if ev_recall is not None:
        print(f" Evidence Recall: {ev_recall:.3f}")
    else:
        print(" Evidence Recall: N/A (page-level telemetry unavailable)")
    print(f" Avg Latency: {avg_latency:.1f}s")

    # --- Breakdowns ---
    if "by_question_type" in metrics:
        print_breakdown("Question Type", metrics["by_question_type"])

    if "by_question_reasoning" in metrics:
        print_breakdown("Question Reasoning", metrics["by_question_reasoning"])

    # --- Per-company breakdown (worst performers) ---
    print_company_breakdown(results, top_n=args.top_companies)

    # --- Error cases ---
    print_error_cases(results, max_show=args.max_errors)

    # --- SOTA comparison ---
    print_comparison_with_sota(metrics)

    print(f"\n{'=' * 60}")
    print(f" Source: {args.results_file}")
    print(f"{'=' * 60}\n")
top_k_files: int = 5 + max_token_budget: int = 128000 + enable_dir_scan: bool = True + + # Evaluation + eval_mode: str = "singleDoc" # singleDoc / sharedCorpus + enable_llm_judge: bool = True # TODO: LLM Judge not yet implemented, reserved for future use + extract_answer: bool = True + + # Concurrency + max_concurrent: int = 3 + request_delay: float = 0.5 + + @classmethod + def from_env(cls, env_path: str = ".env.financebench") -> "FinanceBenchConfig": + """Load config from .env file with ``os.environ`` fallback.""" + # Read .env file + env_vars: dict[str, str] = {} + p = Path(env_path) + if p.exists(): + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + k, v = line.split("=", 1) + env_vars[k.strip()] = v.strip() + + def _get(key: str, default: str = "") -> str: + return env_vars.get(key, os.environ.get(key, default)) + + def _bool(key: str, default: bool = False) -> bool: + v = _get(key, str(default)).lower() + return v in ("true", "1", "yes") + + def _int(key: str, default: int = 0) -> int: + try: + return int(_get(key, str(default))) + except (ValueError, TypeError): + return default + + def _float(key: str, default: float = 0.0) -> float: + try: + return float(_get(key, str(default))) + except (ValueError, TypeError): + return default + + return cls( + llm_api_key=_get("LLM_API_KEY"), + llm_base_url=_get( + "LLM_BASE_URL", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ), + llm_model=_get("LLM_MODEL_NAME", "qwen3.5-plus"), + llm_timeout=_int("LLM_TIMEOUT", 120), + data_dir=_get("FB_DATA_DIR", "./data"), + pdf_dir=_get("FB_PDF_DIR", "./data/pdfs"), + output_dir=_get("FB_OUTPUT_DIR", "./output"), + limit=_int("FB_LIMIT", 0), + seed=_int("FB_SEED", 42), + mode=_get("FB_MODE", "FAST"), + top_k_files=_int("FB_TOP_K_FILES", 5), + max_token_budget=_int("FB_MAX_TOKEN_BUDGET", 128000), + enable_dir_scan=_bool("FB_ENABLE_DIR_SCAN", True), + 
class FinanceBenchLoader:
    """Load and validate FinanceBench JSONL data.

    Expects:
      - ``data_dir/financebench_open_source.jsonl`` – the QA rows
      - ``data_dir/financebench_document_information.jsonl`` – doc metadata (optional)
      - ``pdf_dir/`` – corpus of SEC-filing PDFs named by ``doc_name``
    """

    _QUESTIONS_FILE = "financebench_open_source.jsonl"
    _DOC_INFO_FILE = "financebench_document_information.jsonl"

    def __init__(self, data_dir: str, pdf_dir: str) -> None:
        self._data_dir = Path(data_dir)
        self._pdf_dir = Path(pdf_dir)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_jsonl(path: Path) -> List[Any]:
        """Decode every non-blank line of *path* as JSON, in file order."""
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_questions(self) -> List[Dict[str, Any]]:
        """Load the open-source questions from JSONL.

        Raises:
            FileNotFoundError: If the questions file is missing.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        path = self._data_dir / self._QUESTIONS_FILE
        if not path.exists():
            raise FileNotFoundError(f"Questions file not found: {path}")
        return self._read_jsonl(path)

    def load_doc_info(self) -> Dict[str, Dict[str, Any]]:
        """Load document metadata, keyed by ``doc_name``.

        Rows without a non-empty ``doc_name`` are dropped; when a name
        occurs more than once the last row wins.  Returns an empty dict
        when the file is absent (it is optional).
        """
        path = self._data_dir / self._DOC_INFO_FILE
        if not path.exists():
            return {}
        result: Dict[str, Dict[str, Any]] = {}
        for obj in self._read_jsonl(path):
            doc_name = obj.get("doc_name", "")
            if doc_name:
                result[doc_name] = obj
        return result

    def get_pdf_path(self, doc_name: str) -> Optional[str]:
        """Resolve *doc_name* to a file path inside the PDF directory.

        Resolution order (only regular files are accepted):
          1. ``<pdf_dir>/<doc_name>.pdf``
          2. ``<pdf_dir>/<doc_name>`` (file with no extension)
          3. Case-insensitive stem match across ``pdf_dir``; a ``.pdf``
             file is preferred, and ties are broken by sorted path so the
             answer does not depend on directory listing order.

        Returns:
            The path as a string, or ``None`` when nothing matches.
        """
        if not doc_name:
            return None

        for candidate in (self._pdf_dir / f"{doc_name}.pdf", self._pdf_dir / doc_name):
            if candidate.is_file():
                return str(candidate)

        # Case-insensitive fallback
        if not self._pdf_dir.is_dir():
            return None
        wanted = doc_name.lower()
        matches = sorted(
            f
            for f in self._pdf_dir.iterdir()
            if f.is_file() and f.stem.lower() == wanted
        )
        for f in matches:
            if f.suffix.lower() == ".pdf":
                return str(f)
        return str(matches[0]) if matches else None

    def get_unique_docs(self, questions: List[Dict[str, Any]]) -> Set[str]:
        """Extract the unique set of non-empty ``doc_name`` values from *questions*."""
        return {q["doc_name"] for q in questions if q.get("doc_name")}

    def validate_corpus(
        self, questions: List[Dict[str, Any]]
    ) -> Tuple[int, List[str]]:
        """Check PDF availability for all referenced documents.

        Returns:
            A tuple of ``(found_count, missing_doc_names)``; the missing
            names are in sorted order.
        """
        missing: List[str] = []
        found = 0
        for doc in sorted(self.get_unique_docs(questions)):
            if self.get_pdf_path(doc):
                found += 1
            else:
                missing.append(doc)
        return found, missing
+ +Financial-value normalisation handles currency symbols, thousand separators, +trailing zeros, and percentage signs so that ``$1,577.00`` matches ``1577``. +""" +from __future__ import annotations + +import re +from collections import Counter, defaultdict +from typing import Any, Dict, List + +# ------------------------------------------------------------------ +# Constants +# ------------------------------------------------------------------ + +_REFUSAL_PHRASES: list[str] = [ + "i cannot", + "i can't", + "i could not", + "i couldn't", + "no results found", + "unable to", + "not able to", + "i don't know", + "i do not know", + "information is not available", + "not enough information", + "cannot determine", + "cannot be determined", + "insufficient data", + "no relevant information", + "data not found", + "unknown", +] + +_F1_CORRECT_THRESHOLD: float = 0.8 + +# Markdown / wrapper patterns compiled once +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*") +_RE_ITALIC = re.compile(r"\*(.+?)\*") +_RE_QUOTES = re.compile(r'^["\u201c\u201d\']+|["\u201c\u201d\']+$') +_RE_ANSWER_PREFIX = re.compile( + r"^(the\s+(short\s+)?answer\s+is\s*:?\s*|answer\s*:\s*|short\s+answer\s*:\s*)", + re.IGNORECASE, +) +# Financial value helpers +_RE_DOLLAR = re.compile(r"^\$\s*") +_RE_THOUSAND_SEP = re.compile(r",(\d{3})") +_RE_TRAILING_ZEROS = re.compile(r"\.0+$") + + +# ------------------------------------------------------------------ +# Normalisation +# ------------------------------------------------------------------ + + +def normalize_answer(answer: str) -> str: + """Normalise an answer string for comparison. + + Steps: + 1. Strip Markdown bold / italic. + 2. Strip surrounding quotes. + 3. Strip trailing punctuation (``.``, ``:``). + 4. Remove common LLM wrapper phrases. + 5. Financial value normalisation (currency, commas, trailing zeros). + 6. Lowercase. + """ + s = answer.strip() + if not s: + return "" + + # 1. Markdown + s = _RE_BOLD.sub(r"\1", s) + s = _RE_ITALIC.sub(r"\1", s) + + # 2. 
Quotes + s = _RE_QUOTES.sub("", s).strip() + + # 3. Trailing punctuation + s = s.rstrip(".:") + + # 4. Wrapper phrases + s = _RE_ANSWER_PREFIX.sub("", s).strip() + + # 5. Financial normalisation + s = _normalize_financial_value(s) + + # 6. Lowercase + return s.lower().strip() + + +def _normalize_financial_value(text: str) -> str: + """Normalise financial figures for robust comparison. + + - ``$1,577.00`` → ``1577`` + - ``15.3%`` → ``15.3%`` + - ``$1577`` → ``1577`` + - ``1,577`` → ``1577`` + """ + s = text.strip() + + # Detect if value looks numeric (possibly with $ / % / commas) + stripped_for_check = _RE_DOLLAR.sub("", s) + stripped_for_check = stripped_for_check.replace(",", "").rstrip("%").strip() + try: + float(stripped_for_check) + except ValueError: + return s # Not a numeric value – return as-is + + # Remove dollar sign + s = _RE_DOLLAR.sub("", s) + + # Remember and temporarily strip percentage + has_pct = s.endswith("%") + if has_pct: + s = s[:-1].strip() + + # Remove thousand-separator commas + s = s.replace(",", "") + + # Remove trailing decimal zeros: 1577.00 → 1577, 15.30 → 15.3 + if "." in s: + s = s.rstrip("0").rstrip(".") + + # Re-attach percentage + if has_pct: + s = s + "%" + + return s + + +# ------------------------------------------------------------------ +# Matching helpers +# ------------------------------------------------------------------ + + +def exact_match(prediction: str, gold: str) -> bool: + """Return ``True`` when normalised strings are identical.""" + return normalize_answer(prediction) == normalize_answer(gold) + + +def f1_score(prediction: str, gold: str) -> float: + """Compute token-level F1 between *prediction* and *gold*. + + Tokenisation is simple whitespace splitting after normalisation. + Each token is further normalised as a financial value so that + ``$1577`` matches ``1577`` at the token level. + Returns 0.0 when either side is empty. 
def evidence_recall(
    retrieved_pages: List[int],
    gold_evidence: List[Dict[str, Any]],
) -> float:
    """Compute page-level evidence recall.

    Recall is the fraction of distinct gold pages that also appear in
    *retrieved_pages*.  Page numbers are compared as given, so both sides
    must use the same indexing base (the dataset's ``evidence_page_num`` is
    described as 0-indexed — confirm against the retriever's telemetry).

    Args:
        retrieved_pages: Page numbers the system retrieved; ``None`` is
            treated as "nothing retrieved".  Values that cannot be read as
            integers are ignored.
        gold_evidence: Gold evidence entries; the page is read from each
            entry's ``evidence_page_num`` key.  Entries without that key,
            or whose value cannot be read as an integer, are ignored.

    Returns:
        A value in ``[0.0, 1.0]``; ``1.0`` when there is no usable gold
        evidence (vacuously true).
    """

    def _as_pages(values: Any) -> set:
        # Collect every value that can be read as an int; drop the rest.
        pages = set()
        for value in values:
            try:
                pages.add(int(value))
            except (TypeError, ValueError, OverflowError):
                continue
        return pages

    if not gold_evidence:
        return 1.0

    gold_pages = _as_pages(
        e["evidence_page_num"] for e in gold_evidence if "evidence_page_num" in e
    )
    if not gold_pages:
        return 1.0

    retrieved_set = _as_pages(retrieved_pages or ())
    return len(gold_pages & retrieved_set) / len(gold_pages)
+ """ + n = len(results) + if n == 0: + return {"n": 0} + + # --- Overall counts --- + correct = sum(1 for r in results if r.get("classification") == "correct") + halluc = sum(1 for r in results if r.get("classification") == "hallucination") + refusal = sum(1 for r in results if r.get("classification") == "refusal") + + em_sum = sum(1 for r in results if r.get("em")) + f1_sum = sum(r.get("f1", 0.0) for r in results) + + latencies = [r["elapsed"] for r in results if "elapsed" in r] + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + + token_counts = [ + r.get("telemetry", {}).get("total_tokens", 0) for r in results + ] + avg_tokens = sum(token_counts) / len(token_counts) if token_counts else 0 + + ev_recalls = [r["evidence_recall"] for r in results if r.get("evidence_recall") is not None] + avg_ev_recall = sum(ev_recalls) / len(ev_recalls) if ev_recalls else None + + overall = { + "n": n, + "accuracy": round(correct / n * 100, 2), + "hallucination_rate": round(halluc / n * 100, 2), + "refusal_rate": round(refusal / n * 100, 2), + "correct": correct, + "hallucination": halluc, + "refusal": refusal, + "avg_em": em_sum / n, + "avg_f1": f1_sum / n, + "avg_latency": round(avg_latency, 2), + "avg_tokens": round(avg_tokens, 1), + } + if avg_ev_recall is not None: + overall["evidence_recall"] = round(avg_ev_recall, 4) + + # --- Breakdowns --- + overall["by_question_type"] = _breakdown(results, "question_type") + overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") + + return overall + + +def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, Any]]: + """Compute per-group accuracy / hallucination / refusal breakdown.""" + groups: dict[str, list[dict]] = defaultdict(list) + for r in results: + group = r.get(key, "unknown") + groups[group].append(r) + + out: dict[str, dict] = {} + for group, items in sorted(groups.items()): + g_n = len(items) + g_correct = sum(1 for r in items if r.get("classification") == 
"correct") + g_halluc = sum( + 1 for r in items if r.get("classification") == "hallucination" + ) + g_refusal = sum(1 for r in items if r.get("classification") == "refusal") + out[group] = { + "n": g_n, + "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, + "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, + "refusal_rate": round(g_refusal / g_n * 100, 2) if g_n else 0.0, + "correct": g_correct, + "hallucination": g_halluc, + "refusal": g_refusal, + } + return out diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py new file mode 100644 index 0000000..28e99d7 --- /dev/null +++ b/benchmarks/financebench/run_benchmark.py @@ -0,0 +1,239 @@ +"""FinanceBench benchmark entry point. + +Usage: + cd benchmarks/financebench + python run_benchmark.py [--env .env.financebench] [--limit N] + +Examples: + # Run all 150 questions with default config + python run_benchmark.py + + # Run a quick sanity check with 10 questions + python run_benchmark.py --limit 10 + + # Use a custom .env file + python run_benchmark.py --env .env.custom --limit 20 +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import random +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import List + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import compute_metrics +from runner import run_batch + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +def setup_logging(output_dir: str) -> str: + """Configure logging to file + console. + + Creates a timestamped log file under ``logs/`` (relative to *output_dir*'s + parent, i.e. the benchmark root directory). + + Returns: + Absolute path to the log file. 
+ """ + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path = log_dir / f"benchmark_{ts}.log" + + root_logger = logging.getLogger("financebench") + root_logger.setLevel(logging.DEBUG) + + # File handler – DEBUG level, full detail + fh = logging.FileHandler(str(log_path), encoding="utf-8") + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter( + "%(asctime)s %(name)-28s %(levelname)-7s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + + # Console handler – INFO level, concise + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + ch.setFormatter( + logging.Formatter("%(asctime)s %(levelname)-7s %(message)s", datefmt="%H:%M:%S") + ) + + root_logger.addHandler(fh) + root_logger.addHandler(ch) + + return str(log_path.resolve()) + + +# --------------------------------------------------------------------------- +# Summary printing +# --------------------------------------------------------------------------- + + +def _print_summary( + results: List[dict], + metrics: dict, + total_time: float, + results_path: Path, + metrics_path: Path, + log_path: str, +) -> None: + """Print a human-readable run summary to stdout.""" + n = len(results) + acc = metrics.get("accuracy", 0) + hallu = metrics.get("hallucination_rate", 0) + refuse = metrics.get("refusal_rate", 0) + avg_em = metrics.get("avg_em", 0) + avg_f1 = metrics.get("avg_f1", 0) + ev_recall = metrics.get("evidence_recall") + avg_latency = metrics.get("avg_latency", 0) + + print("\n" + "=" * 60) + print(f"FinanceBench Results ({n} questions)") + print("=" * 60) + print(f" Accuracy: {acc:.1f}%") + print(f" Hallucination Rate: {hallu:.1f}%") + print(f" Refusal Rate: {refuse:.1f}%") + print(f" Avg EM: {avg_em:.3f}") + print(f" Avg F1: {avg_f1:.3f}") + if ev_recall is not None: + print(f" Evidence Recall: {ev_recall:.3f}") + else: + print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Avg 
Latency: {avg_latency:.1f}s") + print(f" Total Time: {total_time:.1f}s") + print(f"\n Results: {results_path}") + print(f" Metrics: {metrics_path}") + print(f" Log: {log_path}") + + # Breakdown by question_type + by_qt = metrics.get("by_question_type") + if by_qt: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") + print(" " + "-" * 52) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + + print("=" * 60) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Parse CLI arguments, run the benchmark, and save results.""" + parser = argparse.ArgumentParser( + description="Run FinanceBench benchmark against Sirchmunk AgenticSearch", + ) + parser.add_argument( + "--env", + default=".env.financebench", + help="Path to .env config file (default: .env.financebench)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Override FB_LIMIT — number of questions to evaluate", + ) + args = parser.parse_args() + + # 1. Load config + cfg = FinanceBenchConfig.from_env(args.env) + if args.limit is not None: + cfg.limit = args.limit + + # 2. Setup logging + log_path = setup_logging(cfg.output_dir) + logger = logging.getLogger("financebench") + + # 3. Load data + loader = FinanceBenchLoader(cfg.data_dir, cfg.pdf_dir) + questions = loader.load_questions() + logger.info("Loaded %d questions from %s", len(questions), cfg.data_dir) + + # 4. Validate corpus + found, missing = loader.validate_corpus(questions) + logger.info("PDF corpus: %d found, %d missing", found, len(missing)) + if missing: + preview = missing[:10] + suffix = "..." 
if len(missing) > 10 else "" + logger.warning("Missing PDFs: %s%s", preview, suffix) + + # 5. Apply limit / seed + if cfg.limit > 0 and cfg.limit < len(questions): + random.seed(cfg.seed) + questions = random.sample(questions, cfg.limit) + logger.info("Sampled %d questions (seed=%d)", len(questions), cfg.seed) + + # 6. Print run config + logger.info( + "Config: mode=%s, eval_mode=%s, extract_answer=%s, " + "llm_judge=%s, concurrent=%d, model=%s", + cfg.mode, + cfg.eval_mode, + cfg.extract_answer, + cfg.enable_llm_judge, + cfg.max_concurrent, + cfg.llm_model, + ) + + # 7. Run benchmark + t0 = time.time() + results = asyncio.run(run_batch(questions, cfg)) + total_time = time.time() - t0 + + # 8. Compute metrics + metrics = compute_metrics(results) + metrics["total_time_seconds"] = round(total_time, 2) + metrics["num_questions"] = len(questions) + metrics["config"] = { + "mode": cfg.mode, + "eval_mode": cfg.eval_mode, + "model": cfg.llm_model, + "top_k_files": cfg.top_k_files, + "extract_answer": cfg.extract_answer, + } + + # 9. Save results (JSONL) + metrics (JSON) + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + results_path = out_dir / f"results_{ts}.jsonl" + metrics_path = out_dir / f"metrics_{ts}.json" + + with open(results_path, "w", encoding="utf-8") as f: + for r in results: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + with open(metrics_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=2, ensure_ascii=False) + + logger.info("Results saved to %s", results_path) + logger.info("Metrics saved to %s", metrics_path) + + # 10. 
Print summary + _print_summary(results, metrics, total_time, results_path, metrics_path, log_path) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py new file mode 100644 index 0000000..72d0a0b --- /dev/null +++ b/benchmarks/financebench/runner.py @@ -0,0 +1,279 @@ +"""Run AgenticSearch on FinanceBench questions. + +Supports two evaluation modes: +- **singleDoc**: each question searches only its target PDF directory. +- **sharedCorpus**: all questions search the full PDF corpus. + +After search, an optional LLM extraction step converts the verbose +briefing into a short factoid answer suitable for EM/F1. +""" +from __future__ import annotations + +import asyncio +import json as json_mod +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import ( + classify_answer, + compute_metrics, + exact_match, + evidence_recall, + f1_score, + normalize_answer, +) + +logger = logging.getLogger("financebench.runner") + +# ------------------------------------------------------------------ +# Answer extraction prompt (financial domain) +# ------------------------------------------------------------------ + +_EXTRACT_PROMPT = """\ +Given the financial question and a verbose response, extract ONLY the short factoid answer. +Rules: +- Output ONLY the answer value/phrase (1-20 words). No explanation. +- If the response says it cannot find the answer, output: unknown +- For monetary values, keep the currency format (e.g., $1,577.00) +- For percentages, keep the % sign (e.g., 15.3%) +- For yes/no questions, output: yes or no + +Question: {question} +Response: {response} + +Short answer:""" + + +# NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. 
+ + +# ------------------------------------------------------------------ +# LLM short-answer extraction +# ------------------------------------------------------------------ + + +async def _extract_short_answer( + question: str, + verbose: str, + llm: Any, +) -> str: + """Use *llm* to distil *verbose* into a short factoid answer.""" + prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) + try: + resp = await llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + return resp.content.strip() + except Exception: + logger.warning("Short-answer extraction failed; falling back to raw answer.") + return verbose + + +# ------------------------------------------------------------------ +# Page extraction helper +# ------------------------------------------------------------------ + + +def _try_extract_pages(telemetry: Dict[str, Any]) -> List[int]: + """Best-effort extraction of retrieved page numbers from telemetry. + + Current limitation: Sirchmunk's ``read_file_ids`` contains plain file + paths without page-level suffixes, so this function will typically + return an empty list. When empty, callers should treat evidence + recall as *unavailable* (``None``) rather than zero. 
+ """ + pages: list[int] = [] + for fid in telemetry.get("read_file_ids", []): + # Convention: page indices may be embedded in file IDs + if isinstance(fid, str) and "_page_" in fid: + try: + pages.append(int(fid.rsplit("_page_", 1)[-1])) + except (ValueError, IndexError): + pass + return pages + + +# ------------------------------------------------------------------ +# Single question execution +# ------------------------------------------------------------------ + + +async def run_single( + entry: Dict[str, Any], + loader: FinanceBenchLoader, + searcher: Any, + llm: Any, + cfg: FinanceBenchConfig, + semaphore: asyncio.Semaphore, +) -> Dict[str, Any]: + """Execute one FinanceBench question end-to-end.""" + fb_id = entry.get("financebench_id", "") + question = entry["question"] + gold = entry.get("answer", "") + gold_evidence = entry.get("evidence", []) + + async with semaphore: + t0 = time.time() + error: str | None = None + raw_answer = "" + answer = "" + telemetry: dict[str, Any] = {} + retrieved_pages: list[int] = [] + + try: + # Determine search paths based on eval mode + if cfg.eval_mode == "singleDoc": + pdf_path = loader.get_pdf_path(entry.get("doc_name", "")) + if pdf_path: + search_paths = [pdf_path] # pass the single PDF file directly + else: + logger.warning("PDF not found for %s, falling back to full corpus", entry.get("doc_name", "")) + search_paths = [cfg.pdf_dir] + else: + search_paths = [cfg.pdf_dir] + + result = await searcher.search( + query=question, + paths=search_paths, + mode=cfg.mode, + top_k_files=cfg.top_k_files, + max_token_budget=cfg.max_token_budget, + enable_dir_scan=cfg.enable_dir_scan, + return_context=True, + ) + + raw_answer = getattr(result, "answer", "") or str(result) + + # Collect telemetry + read_files = list(getattr(result, "read_file_ids", None) or set()) + telemetry = { + "read_file_ids": read_files, + "total_tokens": getattr(result, "total_llm_tokens", 0), + "loop_count": getattr(result, "loop_count", 0), + "llm_calls": 
len(getattr(result, "llm_usages", None) or []), + "num_files_read": len(read_files), + } + retrieved_pages = _try_extract_pages(telemetry) + + # Answer extraction + if cfg.extract_answer and raw_answer: + answer = await _extract_short_answer(question, raw_answer, llm) + answer = normalize_answer(answer) + else: + answer = normalize_answer(raw_answer) + + except Exception as exc: + error = str(exc) + logger.error("Error on %s: %s", fb_id, error) + + elapsed = time.time() - t0 + + # Delay between requests + if cfg.request_delay > 0: + await asyncio.sleep(cfg.request_delay) + + # --- Evaluation --- + is_no_result = not answer or answer.lower() in ("unknown", "") + em = exact_match(answer, gold) + f1 = f1_score(answer, gold) + classification = classify_answer(answer, gold, is_no_result=is_no_result) + if retrieved_pages: # only compute when page-level data is available + ev_recall = evidence_recall(retrieved_pages, gold_evidence) + else: + ev_recall = None # mark as unavailable, avoid false 0 + + return { + "financebench_id": fb_id, + "question": question, + "prediction": answer, + "raw_prediction": raw_answer, + "gold_answer": gold, + "company": entry.get("company", ""), + "doc_name": entry.get("doc_name", ""), + "question_type": entry.get("question_type", ""), + "question_reasoning": entry.get("question_reasoning", ""), + "elapsed": round(elapsed, 2), + "telemetry": telemetry, + "classification": classification, + "em": em, + "f1": round(f1, 4), + "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, + "error": error, + } + + +# ------------------------------------------------------------------ +# Batch execution +# ------------------------------------------------------------------ + + +async def run_batch( + samples: List[Dict[str, Any]], + cfg: FinanceBenchConfig, +) -> List[Dict[str, Any]]: + """Run all *samples* concurrently and persist results incrementally.""" + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import 
AgenticSearch + + llm = OpenAIChat( + api_key=cfg.llm_api_key, + base_url=cfg.llm_base_url, + model=cfg.llm_model, + ) + searcher = AgenticSearch(llm=llm, reuse_knowledge=False, verbose=False) + loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) + semaphore = asyncio.Semaphore(cfg.max_concurrent) + + # Prepare output directory / file + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + out_path = out_dir / f"financebench_{ts}.jsonl" + + results: list[dict] = [] + completed = 0 + total = len(samples) + + async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: + nonlocal completed + res = await run_single(entry, loader, searcher, llm, cfg, semaphore) + # Incremental save + with open(out_path, "a", encoding="utf-8") as fp: + fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") + completed += 1 + status = res["classification"] + logger.info( + "[%d/%d] %s %s EM=%s F1=%.2f %.1fs", + completed, + total, + res["financebench_id"], + status, + res["em"], + res["f1"], + res["elapsed"], + ) + return res + + tasks = [asyncio.create_task(_run_and_record(s)) for s in samples] + results = await asyncio.gather(*tasks) + + # Write aggregate metrics + metrics = compute_metrics(list(results)) + metrics_path = out_dir / f"financebench_{ts}_metrics.json" + with open(metrics_path, "w", encoding="utf-8") as fp: + json_mod.dump(metrics, fp, indent=2, ensure_ascii=False) + logger.info("Metrics saved to %s", metrics_path) + logger.info( + "Accuracy=%.2f%% Hallucination=%.2f%% Refusal=%.2f%%", + metrics.get("accuracy", 0), + metrics.get("hallucination_rate", 0), + metrics.get("refusal_rate", 0), + ) + + return list(results) From 4a0a01796bc51ae0869a8b78617aa0759cdc4c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 16:27:03 +0800 Subject: [PATCH 14/70] add llm judge for financebench --- benchmarks/financebench/analyze_results.py | 62 ++- 
benchmarks/financebench/config.py | 3 +- benchmarks/financebench/evaluate.py | 60 ++- benchmarks/financebench/judge.py | 420 +++++++++++++++++++++ benchmarks/financebench/run_benchmark.py | 37 +- benchmarks/financebench/runner.py | 37 +- 6 files changed, 597 insertions(+), 22 deletions(-) create mode 100644 benchmarks/financebench/judge.py diff --git a/benchmarks/financebench/analyze_results.py b/benchmarks/financebench/analyze_results.py index 24d2b64..a804284 100644 --- a/benchmarks/financebench/analyze_results.py +++ b/benchmarks/financebench/analyze_results.py @@ -69,16 +69,34 @@ def print_breakdown(title: str, breakdown: Dict[str, Dict[str, Any]]) -> None: breakdown: ``{group_name: {accuracy, hallucination_rate, ...}}``. """ print(f"\n=== Breakdown by {title} ===\n") - header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" - print(header) - print(" " + "-" * (len(header) - 2)) - for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): - acc = m.get("accuracy", 0) - hal = m.get("hallucination_rate", 0) - ref = m.get("refusal_rate", 0) - n = m.get("n", 0) - print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") + # Determine if judge data is available + has_judge = any(m.get("llm_judge_accuracy") is not None for m in breakdown.values()) + + if has_judge: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + jdg = m.get("llm_judge_accuracy") + jdg_str = f"{jdg:>6.1f}" if jdg is not None else " N/A" + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {jdg_str} {n:>4}") + else: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" + print(header) + print(" " + "-" * 
(len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") def _compute_company_breakdown( @@ -183,6 +201,12 @@ def print_comparison_with_sota(metrics: Dict[str, Any]) -> None: n = metrics.get("n", 0) coverage = min(100.0, n / 150.0 * 100) print(f" {'Sirchmunk (This Run)':<30} {f'{acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + + # Show Judge Accuracy in SOTA table if available + judge_acc = metrics.get("llm_judge_accuracy") + if judge_acc is not None: + print(f" {'Sirchmunk (Judge Acc)':<30} {f'{judge_acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + print(f"\n (This run evaluated {n} questions)") @@ -247,6 +271,12 @@ def main() -> None: print(f" Evidence Recall: N/A (page-level telemetry unavailable)") print(f" Avg Latency: {avg_latency:.1f}s") + # LLM Judge independent metrics + if metrics.get("llm_judge_accuracy") is not None: + print(f"\n --- LLM Judge (Independent Evaluation) ---") + print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") + print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + # --- Breakdowns --- if "by_question_type" in metrics: print_breakdown("Question Type", metrics["by_question_type"]) @@ -260,6 +290,20 @@ def main() -> None: # --- Error cases --- print_error_cases(results, max_show=args.max_errors) + # --- Judge-Rule Discrepancies --- + discrepancies = [r for r in results + if r.get("llm_judge_correct") is not None + and r.get("classification") != "correct" + and r.get("llm_judge_correct") is True] + if discrepancies: + print(f"\n=== Judge-Rule Discrepancies ({len(discrepancies)} cases) ===") + print(" (Cases where LLM Judge says correct but EM/F1 says wrong)") + for r in discrepancies[:10]: + print(f" {r.get('financebench_id', 'N/A')}: pred='{r.get('prediction', 
'')[:50]}' gold='{r.get('gold_answer', '')[:50]}'") + print(f" classification={r.get('classification')}, judge_reasoning={r.get('llm_judge_reasoning', '')[:80]}") + if len(discrepancies) > 10: + print(f" ... and {len(discrepancies) - 10} more discrepancy(ies) not shown.") + # --- SOTA comparison --- print_comparison_with_sota(metrics) diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index f2e0fdb..f51ea36 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -33,8 +33,9 @@ class FinanceBenchConfig: # Evaluation eval_mode: str = "singleDoc" # singleDoc / sharedCorpus - enable_llm_judge: bool = True # TODO: LLM Judge not yet implemented, reserved for future use + enable_llm_judge: bool = True # Use LLM to judge semantic equivalence (independent metric) extract_answer: bool = True + judge_f1_threshold: float = 0.8 # F1 threshold for 'correct' classification # Concurrency max_concurrent: int = 3 diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index 688cf41..3e78636 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -34,6 +34,26 @@ "no relevant information", "data not found", "unknown", + "i'm not able to", + "i am not able to", + "the document does not contain", + "the document doesn't contain", + "this information is not disclosed", + "not disclosed", + "could not find", + "couldn't find", + "no mention of", + "no information about", + "not provided in", + "not found in the document", + "i was unable to", + "unable to determine", + "unable to find", + "unable to locate", + "there is no data", + "no data available", + "not available in", + "not specified", ] _F1_CORRECT_THRESHOLD: float = 0.8 @@ -99,16 +119,29 @@ def _normalize_financial_value(text: str) -> str: - ``15.3%`` → ``15.3%`` - ``$1577`` → ``1577`` - ``1,577`` → ``1577`` + - ``($500)`` → ``-500`` + - ``-$500`` → ``-500`` """ s = text.strip() + # Handle 
accounting bracket notation for negatives: ($500) → -$500 + if s.startswith("(") and s.endswith(")"): + s = "-" + s[1:-1] + + # Handle negative sign: remember it, strip it for processing + negative = False + if s.startswith("-"): + negative = True + s = s[1:] + # Detect if value looks numeric (possibly with $ / % / commas) stripped_for_check = _RE_DOLLAR.sub("", s) stripped_for_check = stripped_for_check.replace(",", "").rstrip("%").strip() try: float(stripped_for_check) except ValueError: - return s # Not a numeric value – return as-is + # Not a numeric value – restore negative sign and return as-is + return ("-" + s) if negative else s # Remove dollar sign s = _RE_DOLLAR.sub("", s) @@ -129,6 +162,10 @@ def _normalize_financial_value(text: str) -> str: if has_pct: s = s + "%" + # Re-attach negative sign + if negative and not s.startswith("-"): + s = "-" + s + return s @@ -289,6 +326,18 @@ def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: if avg_ev_recall is not None: overall["evidence_recall"] = round(avg_ev_recall, 4) + # --- LLM Judge metrics (independent dimension, NOT fallback) --- + judge_results = [r for r in results if r.get("llm_judge_correct") is not None] + if judge_results: + judge_correct = sum(1 for r in judge_results if r["llm_judge_correct"]) + overall["llm_judge_accuracy"] = round(judge_correct / len(judge_results) * 100, 2) + overall["llm_judge_count"] = len(judge_results) + overall["llm_judge_correct"] = judge_correct + else: + overall["llm_judge_accuracy"] = None + overall["llm_judge_count"] = 0 + overall["llm_judge_correct"] = 0 + # --- Breakdowns --- overall["by_question_type"] = _breakdown(results, "question_type") overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") @@ -311,7 +360,7 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A 1 for r in items if r.get("classification") == "hallucination" ) g_refusal = sum(1 for r in items if r.get("classification") == 
"refusal") - out[group] = { + group_dict: dict[str, Any] = { "n": g_n, "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, @@ -320,4 +369,11 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A "hallucination": g_halluc, "refusal": g_refusal, } + # LLM Judge breakdown + g_judge = [r for r in items if r.get("llm_judge_correct") is not None] + if g_judge: + g_jc = sum(1 for r in g_judge if r["llm_judge_correct"]) + group_dict["llm_judge_accuracy"] = round(g_jc / len(g_judge) * 100, 2) + group_dict["llm_judge_count"] = len(g_judge) + out[group] = group_dict return out diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py new file mode 100644 index 0000000..e52b6e6 --- /dev/null +++ b/benchmarks/financebench/judge.py @@ -0,0 +1,420 @@ +"""LLM-based semantic equivalence judge for FinanceBench. + +The judge evaluates whether a model's prediction is semantically +equivalent to the gold answer, operating as an **independent** +evaluation dimension alongside EM/F1 — not as a fallback. + +This provides a more nuanced correctness signal for financial QA, +where formatting differences (e.g., $1.5B vs $1,500M) can cause +EM/F1 to undercount correct answers. +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +_JUDGE_PROMPT = """\ +You are an expert financial analyst and auditor evaluating answer correctness \ +with **zero tolerance for numerical or factual errors**. + +Question: {question} +Gold Answer: {gold} +Model Prediction: {prediction} + +Task: Determine if the model's prediction is **semantically equivalent** \ +to the gold answer in the context of this financial question. 
+ +═══════════════════════════════════════════════ +EQUIVALENT — only when ALL of the following hold: +═══════════════════════════════════════════════ + +1. **Numerical precision (ZERO TOLERANCE)**: + - Values must be mathematically identical after unit conversion. + - $1.5B = $1,500M = $1,500,000K = $1,500,000,000 ✓ + - $1,577 ≠ $1,580 ✗ (rounding is NOT acceptable) + - 15.3% = 15.30% = 0.153 ✓ but 15.3% ≠ 15% ✗ + - $1.5M ≠ $1.5B ✗ (unit mismatch is a critical error) + +2. **Negative / bracket notation**: + - ($500) = -$500 = -500 ✓ + - ($500) ≠ $500 ✗ (sign matters) + +3. **Time period / fiscal year**: + - FY2018 = fiscal year 2018 = 2018 ✓ + - FY2018 ≠ FY2019 ✗ (different fiscal year — NEVER equivalent) + - Q3 2019 ≠ Q4 2019 ✗ (different quarter) + - "year ended December 2018" = FY2018 ✓ + +4. **Currency formatting**: + - $1,577.00 = $1577 = 1577 ✓ (same value, format differs) + +5. **Financial term equivalences (accepted)**: + - net income = net profit ✓ + - CAPEX = capital expenditure ✓ + - EPS = earnings per share ✓ + - EBITDA = earnings before interest, taxes, depreciation and amortization ✓ + - YoY = year-over-year ✓ + - COGS = cost of goods sold ✓ + - D&A = depreciation and amortization ✓ + +6. **Financial term distinctions (NOT interchangeable)**: + - revenue ≠ net revenue ≠ gross revenue (unless context is clear) + - operating income ≠ net income + - gross profit ≠ net profit + - total assets ≠ net assets + +7. **Prediction with extra context**: + - If prediction contains the correct answer with additional supporting \ + detail, treat as equivalent (e.g., "Revenue was $1,577M in FY2018" \ + vs "$1,577M" — equivalent, provided the value is correct). + +═══════════════════════════════════════════════ +NOT EQUIVALENT — if ANY of the following hold: +═══════════════════════════════════════════════ + +1. Different numerical values (even slightly: $1,577 ≠ $1,580) +2. Different time periods or fiscal years +3. Different companies or entities +4. 
Opposite trend direction (increased ≠ decreased, growth ≠ decline) +5. Unit mismatch ($1.5M ≠ $1.5B) +6. Missing or wrong sign (positive ≠ negative) +7. Prediction is vague or hedging where gold is precise +8. Prediction is a refusal or states it cannot find the answer +9. Near-approximate values that are not mathematically equal after unit conversion + +═══════════════════════════════════════════════ +CONSERVATIVE JUDGMENT POLICY +═══════════════════════════════════════════════ + +- **When in doubt, judge as NOT equivalent.** Financial accuracy demands \ + precision; a false positive (incorrectly marking wrong answer as correct) \ + is far worse than a false negative. +- If you are less than 80% confident the answers are equivalent, \ + judge as NOT equivalent. +- Set confidence to reflect your actual certainty (0.0 = no idea, \ + 1.0 = absolutely certain). + +═══════════════════════════════════════════════ +FEW-SHOT EXAMPLES +═══════════════════════════════════════════════ + +Example 1 — EQUIVALENT (format difference): + Gold: "$1,577" | Prediction: "$1,577.00 million" + → {{"equivalent": true, "confidence": 0.95, "reasoning": "Same value $1,577M, trailing zeros are formatting."}} + +Example 2 — EQUIVALENT (abbreviation): + Gold: "$1.5 billion" | Prediction: "$1,500M" + → {{"equivalent": true, "confidence": 0.97, "reasoning": "$1.5B = $1,500M, correct unit conversion."}} + +Example 3 — NOT EQUIVALENT (different value): + Gold: "$1,577" | Prediction: "$1,580" + → {{"equivalent": false, "confidence": 0.99, "reasoning": "Values differ: 1577 ≠ 1580. 
No rounding tolerance."}} + +Example 4 — NOT EQUIVALENT (different fiscal year): + Gold: "FY2018" | Prediction: "FY2019" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Different fiscal years."}} + +Example 5 — NOT EQUIVALENT (unit mismatch): + Gold: "$1.5 million" | Prediction: "$1.5 billion" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Unit mismatch: million ≠ billion."}} + +Example 6 — EQUIVALENT (negative notation): + Gold: "-$500" | Prediction: "($500)" + → {{"equivalent": true, "confidence": 0.98, "reasoning": "Same negative value, bracket = negative."}} + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"equivalent": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + + +# Refusal detection phrases (subset for quick judge-side check) +_REFUSAL_INDICATORS: frozenset[str] = frozenset( + { + "i cannot", + "i can't", + "unable to", + "not able to", + "i don't know", + "i do not know", + "unknown", + "no results found", + "cannot determine", + "insufficient data", + "data not found", + "could not find", + "couldn't find", + "unable to determine", + "unable to find", + } +) + + +class FinanceBenchLLMJudge: + """LLM-based judge for semantic equivalence in financial QA. + + Operates as an independent evaluation dimension — NOT as a + fallback for EM/F1. Each question gets a separate judge verdict + that is tracked in its own metrics. + """ + + _CONFIDENCE_THRESHOLD: float = 0.7 + _MAX_RETRIES: int = 2 + + def __init__(self, llm: Any) -> None: + self._llm = llm + self._cache: Dict[tuple, Dict[str, Any]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def judge( + self, + prediction: str, + gold_answer: str, + question: str = "", + ) -> Dict[str, Any]: + """Judge whether prediction is semantically equivalent to gold. + + Args: + prediction: Model's answer text. 
+ gold_answer: Ground-truth answer text. + question: The original question (for context). + + Returns: + { + "equivalent": bool, + "confidence": float (0-1), + "reasoning": str, + "cached": bool, + "error": Optional[str] + } + """ + # --- Refusal short-circuit (saves LLM call) --- + if self._is_refusal(prediction): + return { + "equivalent": False, + "confidence": 1.0, + "reasoning": "Prediction is a refusal — skipped LLM judge.", + "cached": False, + "error": None, + } + + # --- Quick exact-match shortcut --- + from evaluate import normalize_answer + + if normalize_answer(prediction) == normalize_answer(gold_answer): + return { + "equivalent": True, + "confidence": 1.0, + "reasoning": "Normalized exact match", + "cached": False, + "error": None, + } + + # --- Check cache (key includes question for context-sensitivity) --- + cache_key = ( + question.strip().lower(), + prediction.strip().lower(), + gold_answer.strip().lower(), + ) + if cache_key in self._cache: + result = dict(self._cache[cache_key]) + result["cached"] = True + return result + + # --- Call LLM with retry --- + prompt = _JUDGE_PROMPT.format( + question=question or "N/A", + gold=gold_answer, + prediction=prediction, + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + raw = resp.content.strip() + result = self._parse_response(raw) + if result.get("error") is None: + break # success + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Judge call failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + # --- Apply 
confidence threshold (conservative) --- + if ( + result.get("error") is None + and result["equivalent"] + and result["confidence"] < self._CONFIDENCE_THRESHOLD + ): + result["equivalent"] = False + result["reasoning"] = ( + f"Overridden to NOT equivalent: confidence " + f"{result['confidence']:.2f} < threshold " + f"{self._CONFIDENCE_THRESHOLD} — conservative policy. " + f"Original reasoning: {result['reasoning']}" + ) + + result.setdefault("cached", False) + result.setdefault("error", None) + + # Cache successful results only + if result["error"] is None: + self._cache[cache_key] = { + k: v for k, v in result.items() if k != "cached" + } + + return result + + # ------------------------------------------------------------------ + # Parsing + # ------------------------------------------------------------------ + + def _parse_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response with robust fallback heuristics.""" + # --- Try direct JSON parse --- + parsed = self._try_parse_json(raw) + if parsed is not None: + return self._validated_result(parsed, raw) + + # --- Fallback: keyword detection (conservative) --- + lower = raw.lower() + + # Look for explicit true/false patterns with word boundaries + true_match = re.search( + r'"equivalent"\s*:\s*true\b', lower + ) + false_match = re.search( + r'"equivalent"\s*:\s*false\b', lower + ) + + if false_match and not true_match: + return { + "equivalent": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (NOT equivalent): {raw[:200]}", + } + elif true_match and not false_match: + # Conservative: lower confidence for keyword-only parse + return { + "equivalent": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (equivalent): {raw[:200]}", + } + + # --- Cannot parse → conservative default --- + logger.warning("Cannot parse judge response: %s", raw[:200]) + return { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } 
+ + def _try_parse_json(self, raw: str) -> Optional[Dict[str, Any]]: + """Attempt multiple JSON extraction strategies.""" + strategies = [ + raw.strip(), + # Strip markdown code fences + re.sub(r"```(?:json)?\s*\n?", "", raw).strip().rstrip("`").strip(), + # Extract first {...} block + self._extract_json_block(raw), + ] + + for text in strategies: + if not text: + continue + # Fix common LLM JSON quirks + text = self._fix_json_quirks(text) + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + continue + return None + + @staticmethod + def _extract_json_block(raw: str) -> Optional[str]: + """Extract the first {...} JSON object from raw text.""" + match = re.search(r"\{[^{}]*\}", raw, re.DOTALL) + return match.group(0) if match else None + + @staticmethod + def _fix_json_quirks(text: str) -> str: + """Fix common non-standard JSON from LLMs.""" + # Replace single quotes with double quotes (basic heuristic) + # Only if the text doesn't already have double quotes for keys + if "'" in text and '"' not in text: + text = text.replace("'", '"') + # Remove trailing commas before closing braces + text = re.sub(r",\s*}", "}", text) + text = re.sub(r",\s*]", "]", text) + return text + + def _validated_result( + self, obj: Dict[str, Any], raw: str + ) -> Dict[str, Any]: + """Build a validated result dict from parsed JSON, clamping values.""" + equivalent = bool(obj.get("equivalent", False)) + + # Clamp confidence to [0.0, 1.0] + try: + confidence = float(obj.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + confidence = max(0.0, min(1.0, confidence)) + + reasoning = str(obj.get("reasoning", "")) + + return { + "equivalent": equivalent, + "confidence": confidence, + "reasoning": reasoning, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_refusal(text: str) -> bool: + """Quick check whether 
*text* looks like a refusal / non-answer.""" + if not text or not text.strip(): + return True + lower = text.strip().lower() + if lower in ("unknown", "n/a", "none", ""): + return True + for phrase in _REFUSAL_INDICATORS: + if phrase in lower: + return True + return False + + @property + def cache_size(self) -> int: + """Return the number of cached judge results.""" + return len(self._cache) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index 28e99d7..ef0df8f 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -115,6 +115,13 @@ def _print_summary( print(f" Evidence Recall: N/A (page-level telemetry unavailable)") print(f" Avg Latency: {avg_latency:.1f}s") print(f" Total Time: {total_time:.1f}s") + + # LLM Judge independent metrics + if metrics.get("llm_judge_accuracy") is not None: + print(f"\n --- LLM Judge (Independent) ---") + print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") + print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + print(f"\n Results: {results_path}") print(f" Metrics: {metrics_path}") print(f" Log: {log_path}") @@ -122,14 +129,28 @@ def _print_summary( # Breakdown by question_type by_qt = metrics.get("by_question_type") if by_qt: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") - print(" " + "-" * 52) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + # Determine if judge data is available + has_judge = any(m.get("llm_judge_accuracy") is not None for m in by_qt.values()) + if has_judge: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}") + print(" " + "-" * 59) + for qt, m in sorted(by_qt.items()): + qt_acc = 
m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + qt_judge = m.get("llm_judge_accuracy") + qt_judge_str = f"{qt_judge:>6.1f}" if qt_judge is not None else " N/A" + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_judge_str} {qt_n:>4}") + else: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") + print(" " + "-" * 52) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") print("=" * 60) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 72d0a0b..64d709f 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -111,6 +111,7 @@ async def run_single( llm: Any, cfg: FinanceBenchConfig, semaphore: asyncio.Semaphore, + judge: Any = None, ) -> Dict[str, Any]: """Execute one FinanceBench question end-to-end.""" fb_id = entry.get("financebench_id", "") @@ -188,6 +189,25 @@ async def run_single( else: ev_recall = None # mark as unavailable, avoid false 0 + # LLM Judge — independent evaluation dimension + # Skip judge for refusals (no point calling LLM on non-answers) + llm_judge_correct = None + llm_judge_reasoning = None + if judge is not None and classification != "refusal": + try: + judge_result = await judge.judge( + prediction=answer, + gold_answer=gold, + question=question, + ) + llm_judge_correct = judge_result.get("equivalent", False) + llm_judge_reasoning = judge_result.get("reasoning", "") + except Exception as e: + logger.warning("LLM Judge failed for %s: %s", fb_id, e) + elif judge is not None and classification == "refusal": + llm_judge_correct = False + llm_judge_reasoning = "Skipped: prediction classified as refusal" + return { "financebench_id": fb_id, "question": question, 
@@ -204,6 +224,8 @@ async def run_single( "em": em, "f1": round(f1, 4), "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, + "llm_judge_correct": llm_judge_correct, # None if judge disabled + "llm_judge_reasoning": llm_judge_reasoning, "error": error, } @@ -230,6 +252,13 @@ async def run_batch( loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) + # Initialise LLM Judge (uses the same test model) + judge = None + if cfg.enable_llm_judge: + from judge import FinanceBenchLLMJudge + judge = FinanceBenchLLMJudge(llm=llm) + logger.info("LLM Judge enabled (independent evaluation dimension)") + # Prepare output directory / file out_dir = Path(cfg.output_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -242,14 +271,17 @@ async def run_batch( async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: nonlocal completed - res = await run_single(entry, loader, searcher, llm, cfg, semaphore) + res = await run_single(entry, loader, searcher, llm, cfg, semaphore, judge=judge) # Incremental save with open(out_path, "a", encoding="utf-8") as fp: fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") completed += 1 status = res["classification"] + judge_tag = "" + if res.get("llm_judge_correct") is not None: + judge_tag = " [judge:\u2713]" if res["llm_judge_correct"] else " [judge:\u2717]" logger.info( - "[%d/%d] %s %s EM=%s F1=%.2f %.1fs", + "[%d/%d] %s %s EM=%s F1=%.2f %.1fs%s", completed, total, res["financebench_id"], @@ -257,6 +289,7 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: res["em"], res["f1"], res["elapsed"], + judge_tag, ) return res From 613c099af0653fc95934f18ffc7f8569cd232404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 17:10:22 +0800 Subject: [PATCH 15/70] Adapt older knowledge cluster data structure --- src/sirchmunk/storage/knowledge_storage.py | 138 +++++++++++++++------ 1 file changed, 102 
insertions(+), 36 deletions(-) diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index 0f09071..c74e05a 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -107,9 +107,11 @@ def _load_from_parquet(self): variable-length ``FLOAT[]`` from Parquet's list encoding, breaking ``list_cosine_similarity`` which requires matching fixed-size types. - Handles schema evolution gracefully: if the parquet file has fewer - columns than the current schema (e.g., missing ``merge_count``), - missing columns are filled with defaults instead of failing. + Handles schema evolution gracefully with adaptive column mapping: + - Forward compatible: old parquet (more cols) → new table (fewer cols), + extra columns in parquet are silently ignored. + - Backward compatible: new parquet (fewer cols) → old table (more cols), + missing columns are filled with defaults. Also records the file's modification time so that ``_check_and_reload()`` can detect external changes later. 
@@ -121,37 +123,62 @@ def _load_from_parquet(self): self.db.drop_table(self.table_name, if_exists=True) # Create table with explicit schema (preserves FLOAT[384]) self._create_table() - # Detect parquet columns to handle schema evolution - try: - pq_cols = self.db.fetch_all( - f"SELECT name FROM parquet_schema('{self.parquet_file}')" - ) - pq_col_names = {row[0] for row in pq_cols} - except Exception: - pq_col_names = None - - if pq_col_names is not None: - # Build column-by-column SELECT with defaults for missing cols - schema_cols = list(self._get_schema_columns()) - select_parts = [] - for col_name in schema_cols: - if col_name in pq_col_names: - select_parts.append(col_name) - elif col_name == "merge_count": - select_parts.append("0 AS merge_count") - else: - select_parts.append(f"NULL AS {col_name}") - select_clause = ", ".join(select_parts) - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT {select_clause} FROM read_parquet('{self.parquet_file}')" + + # Adaptive column mapping: detect parquet & table columns + parquet_cols = self._get_parquet_columns(self.parquet_file) + table_cols = self._get_table_columns() + + if not parquet_cols or not table_cols: + logger.warning( + "Could not detect columns for adaptive mapping, " + "skipping parquet load" ) else: - # Fallback: try direct SELECT * (works when schemas match) - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT * FROM read_parquet('{self.parquet_file}')" - ) + parquet_col_set = set(parquet_cols) + table_col_set = set(table_cols) + # Compute common columns (preserve table column order) + common_cols = [c for c in table_cols if c in parquet_col_set] + + if not common_cols: + logger.warning( + "No common columns between parquet and table, " + "skipping parquet load" + ) + else: + # Log column mismatches as warnings + ignored_cols = parquet_col_set - table_col_set + missing_cols = table_col_set - parquet_col_set + if ignored_cols: + logger.warning( + "Parquet has extra columns 
(ignored): %s", + ignored_cols, + ) + if missing_cols: + logger.warning( + "Table has extra columns (filled with defaults): %s", + missing_cols, + ) + + # Build INSERT with explicit column lists + # For common cols: select directly from parquet + # For missing cols (in table but not in parquet): use defaults + insert_cols = list(table_cols) # all table columns + select_parts = [] + for col_name in table_cols: + if col_name in parquet_col_set: + select_parts.append(col_name) + elif col_name == "merge_count": + select_parts.append("0 AS merge_count") + else: + select_parts.append(f"NULL AS {col_name}") + + cols_str = ", ".join(insert_cols) + select_clause = ", ".join(select_parts) + self.db.execute( + f"INSERT INTO {self.table_name} ({cols_str}) " + f"SELECT {select_clause} " + f"FROM read_parquet('{self.parquet_file}')" + ) count = self.db.get_table_count(self.table_name) # Record mtime for stale-detection @@ -163,10 +190,13 @@ def _load_from_parquet(self): self._parquet_loaded_mtime = 0.0 logger.info("Created new knowledge clusters table") except Exception as e: - logger.error(f"Failed to load from parquet: {e}") - # Try to recreate table - self.db.drop_table(self.table_name, if_exists=True) - self._create_table() + logger.warning(f"Failed to load from parquet (non-blocking): {e}") + # Try to recreate table so retrieval can still work + try: + self.db.drop_table(self.table_name, if_exists=True) + self._create_table() + except Exception as recreate_err: + logger.warning(f"Failed to recreate table after load failure: {recreate_err}") self._parquet_loaded_mtime = 0.0 def _get_schema_columns(self) -> List[str]: @@ -181,6 +211,42 @@ def _get_schema_columns(self) -> List[str]: "embedding_text_hash", ] + def _get_parquet_columns(self, parquet_path: str) -> List[str]: + """Get column names from a parquet file's schema. + + Uses DuckDB's ``parquet_schema()`` function. The returned metadata + rows use a ``name`` field (not ``column_name``). 
+ + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + f"SELECT name FROM parquet_schema('{parquet_path}') " + f"WHERE name != 'duckdb_schema'" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read parquet schema: {e}") + return [] + + def _get_table_columns(self) -> List[str]: + """Get column names from the current DuckDB table. + + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + "SELECT column_name FROM information_schema.columns " + f"WHERE table_name = '{self.table_name}' " + "ORDER BY ordinal_position" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read table columns: {e}") + return [] + def _check_and_reload(self): """Check if the parquet file was modified externally and reload if so. From 6858418583efdee30ee5fb3da1e5047300cdae1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 18:08:46 +0800 Subject: [PATCH 16/70] update finance bench readme --- benchmarks/financebench/README.md | 53 ++++++++++++++++++++++------- benchmarks/financebench/evaluate.py | 4 +-- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 23bd67d..da6752f 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -20,33 +20,62 @@ with **150 expert-annotated questions** across **40+ US public companies** (10-K - **EM / F1**: Exact Match and token-level F1 with financial value normalisation - **Evidence Recall**: Retrieved pages vs gold evidence pages -## Quick Start +## Prerequisites + +### 1. Install Sirchmunk -### 1. Setup +Make sure Sirchmunk is installed and accessible: ```bash -cd benchmarks/financebench +pip install -e . 
+``` -# Copy and edit the config file -cp .env.example .env.financebench -# Edit .env.financebench — set your LLM_API_KEY at minimum +### 2. Prepare Corpus + +Download the FinanceBench dataset (PDF files and JSONL) and place them in the appropriate directory. +Update the paths in your `.env.financebench`: + +- `FB_PDF_DIR` — path to the directory containing the 10-K/10-Q PDF files +- `FB_QUESTIONS_FILE` — path to `financebench_open_source.jsonl` + +### 3. Initialize Workspace + +Initialize the Sirchmunk workspace pointing to the PDF corpus directory: + +```bash +sirchmunk init +``` -# Download FinanceBench data -# Place financebench_open_source.jsonl in ./data/ -# Place PDF corpus (41 files) in ./data/pdfs/ +### 4. Compile Knowledge Base + +Compile the corpus to build the knowledge base for retrieval: + +```bash +sirchmunk compile --paths /path/to/financebench/pdf_files ``` -### 2. Run +> **Note:** The compile step may take some time depending on the corpus size and your LLM provider's rate limits. For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. + +### 5. Configure Environment + +```bash +cp .env.example .env.financebench +# Edit .env.financebench with your API keys and paths +``` + +## Quick Start + +### 1. Run ```bash # Run full benchmark (150 questions) python run_benchmark.py # Run with custom config and question limit -python run_benchmark.py --env .env.financebench --limit 20 +python run_benchmark.py --env .env.custom --limit 20 ``` -### 3. Analyze +### 2. 
Analyze ```bash # Analyze a completed run diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index 3e78636..e22bf07 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -349,11 +349,11 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A """Compute per-group accuracy / hallucination / refusal breakdown.""" groups: dict[str, list[dict]] = defaultdict(list) for r in results: - group = r.get(key, "unknown") + group = r.get(key) or "unknown" groups[group].append(r) out: dict[str, dict] = {} - for group, items in sorted(groups.items()): + for group, items in sorted(groups.items(), key=lambda x: (x[0] is None, x[0] or "")): g_n = len(items) g_correct = sum(1 for r in items if r.get("classification") == "correct") g_halluc = sum( From 9441ef23af039db48e0e035aa4d8e4dac9223498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 19:20:44 +0800 Subject: [PATCH 17/70] refactor config for finbench --- .gitignore | 3 +- benchmarks/financebench/README.md | 27 +++++++++-- benchmarks/financebench/config.py | 59 ++++++++++++++++++------ benchmarks/financebench/run_benchmark.py | 14 ++++++ benchmarks/financebench/runner.py | 3 +- 5 files changed, 86 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index dbd34eb..f79f03f 100644 --- a/.gitignore +++ b/.gitignore @@ -270,4 +270,5 @@ benchmarks/*/data/ benchmarks/*/.env* benchmarks/*/logs/ benchmarks/*/results/ -benchmarks/*/output/ \ No newline at end of file +benchmarks/*/output/ +benchmarks/*/.work/ diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index da6752f..4751508 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -40,21 +40,26 @@ Update the paths in your `.env.financebench`: ### 3. 
Initialize Workspace -Initialize the Sirchmunk workspace pointing to the PDF corpus directory: +Initialize the Sirchmunk workspace with an experiment-isolated work path: ```bash -sirchmunk init +cd benchmarks/financebench +sirchmunk init --work-path ./.work ``` +This creates a `.work/` directory under the experiment folder, keeping knowledge base +and cache isolated from the default `~/.sirchmunk`. + ### 4. Compile Knowledge Base -Compile the corpus to build the knowledge base for retrieval: +Compile the PDF corpus into the experiment workspace: ```bash -sirchmunk compile --paths /path/to/financebench/pdf_files +sirchmunk compile --work-path ./.work --paths ``` -> **Note:** The compile step may take some time depending on the corpus size and your LLM provider's rate limits. For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. +> **Note:** The compile step may take some time depending on the corpus size. +> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. ### 5. Configure Environment @@ -65,6 +70,18 @@ cp .env.example .env.financebench ## Quick Start +### Configuration Priority + +Configuration loads in this order (later overrides earlier): + +1. **Dataclass defaults** — hard-coded in `FinanceBenchConfig` +2. **Platform .env** — `.work/.env` (created by `sirchmunk init`) +3. **Experiment .env** — `.env.financebench` +4. **Command-line** — `--limit N`, `--env ` + +To reuse platform LLM config, leave `LLM_*` commented in `.env.financebench`. +To override, uncomment and set different values. + ### 1. 
Run ```bash diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index f51ea36..5c390ce 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -6,6 +6,27 @@ from pathlib import Path +def _parse_env_file(path: str) -> dict[str, str]: + """Parse a .env file into a dict, handling comments, blank lines, and quotes.""" + result: dict[str, str] = {} + p = Path(path) + if not p.exists(): + return result + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + k, v = line.split("=", 1) + v = v.strip() + # Strip surrounding quotes + if len(v) >= 2 and v[0] == v[-1] and v[0] in ('"', "'"): + v = v[1:-1] + result[k.strip()] = v + return result + + @dataclass class FinanceBenchConfig: """All settings for a FinanceBench evaluation run.""" @@ -41,23 +62,34 @@ class FinanceBenchConfig: max_concurrent: int = 3 request_delay: float = 0.5 + # Experiment isolation + work_path: str = "./.work" # Isolated workspace for this experiment + @classmethod def from_env(cls, env_path: str = ".env.financebench") -> "FinanceBenchConfig": - """Load config from .env file with ``os.environ`` fallback.""" - # Read .env file - env_vars: dict[str, str] = {} - p = Path(env_path) - if p.exists(): - for line in p.read_text(encoding="utf-8").splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - if "=" in line: - k, v = line.split("=", 1) - env_vars[k.strip()] = v.strip() + """Load config with layer inheritance. + + Priority (highest to lowest): + 1. Experiment .env (.env.financebench) + 2. Platform .env (/.env, if exists) + 3. os.environ + 4. 
Dataclass defaults + """ + # Step 0: Pre-read experiment env to determine work_path + experiment_vars = _parse_env_file(env_path) + work_path = experiment_vars.get( + "FB_WORK_PATH", os.environ.get("FB_WORK_PATH", "./.work") + ) + + # Step 1: Load platform-level env (/.env) + platform_env_path = Path(work_path) / ".env" + platform_vars = _parse_env_file(str(platform_env_path)) + + # Step 2: Merge — experiment > platform > os.environ > defaults + merged = {**platform_vars, **experiment_vars} def _get(key: str, default: str = "") -> str: - return env_vars.get(key, os.environ.get(key, default)) + return merged.get(key, os.environ.get(key, default)) def _bool(key: str, default: bool = False) -> bool: v = _get(key, str(default)).lower() @@ -97,4 +129,5 @@ def _float(key: str, default: float = 0.0) -> float: extract_answer=_bool("FB_EXTRACT_ANSWER", True), max_concurrent=_int("FB_MAX_CONCURRENT", 3), request_delay=_float("FB_REQUEST_DELAY", 0.5), + work_path=work_path, ) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index ef0df8f..c9f5b26 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -187,6 +187,20 @@ def main() -> None: log_path = setup_logging(cfg.output_dir) logger = logging.getLogger("financebench") + # Print config source info + work_env = Path(cfg.work_path) / ".env" + logger.info("=" * 50) + logger.info("FinanceBench Configuration") + logger.info("=" * 50) + logger.info(" Experiment env : %s", args.env) + logger.info(" Platform env : %s (%s)", work_env, "found" if work_env.exists() else "not found") + logger.info(" Work path : %s", Path(cfg.work_path).resolve()) + logger.info(" LLM : %s @ %s", cfg.llm_model, cfg.llm_base_url) + logger.info(" Eval mode : %s", cfg.eval_mode) + logger.info(" Search mode : %s, Top-K: %d", cfg.mode, cfg.top_k_files) + logger.info(" LLM Judge : %s", "enabled" if cfg.enable_llm_judge else "disabled") + logger.info("=" * 50) + # 3. 
Load data loader = FinanceBenchLoader(cfg.data_dir, cfg.pdf_dir) questions = loader.load_questions() diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 64d709f..b95f7ca 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -248,7 +248,8 @@ async def run_batch( base_url=cfg.llm_base_url, model=cfg.llm_model, ) - searcher = AgenticSearch(llm=llm, reuse_knowledge=False, verbose=False) + work_path = str(Path(cfg.work_path).resolve()) + searcher = AgenticSearch(llm=llm, work_path=work_path, reuse_knowledge=False, verbose=False) loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) From f1f86fab5ce18c245ad46d2de817eb66bab444b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 10:41:50 +0800 Subject: [PATCH 18/70] refactor financebench readme --- benchmarks/financebench/README.md | 161 +++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 27 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 4751508..d6c95b0 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -22,65 +22,172 @@ with **150 expert-annotated questions** across **40+ US public companies** (10-K ## Prerequisites -### 1. Install Sirchmunk +### Step 1: Install Sirchmunk -Make sure Sirchmunk is installed and accessible: +Install Sirchmunk from the repository root so that the `sirchmunk` CLI is available: ```bash +# From repository root pip install -e . ``` -### 2. Prepare Corpus +Verify the installation: -Download the FinanceBench dataset (PDF files and JSONL) and place them in the appropriate directory. 
-Update the paths in your `.env.financebench`: +```bash +sirchmunk --version +``` + +### Step 2: Prepare Dataset + +Download the [FinanceBench](https://huggingface.co/datasets/PatronusAI/financebench) +dataset and place the files under `benchmarks/financebench/data/`: -- `FB_PDF_DIR` — path to the directory containing the 10-K/10-Q PDF files -- `FB_QUESTIONS_FILE` — path to `financebench_open_source.jsonl` +``` +data/ +├── financebench_open_source.jsonl # 150 expert-annotated QA pairs +└── pdfs/ # 41 SEC-filing PDFs (10-K / 10-Q) + ├── 3M_2018_10K.pdf + ├── AMCOR_2023_10K.pdf + └── ... +``` -### 3. Initialize Workspace +Each PDF filename must match the `doc_name` field in the JSONL file. -Initialize the Sirchmunk workspace with an experiment-isolated work path: +### Step 3: Initialize Experiment Workspace + +Initialize an isolated workspace for this experiment. This keeps the knowledge base +and cache separate from the default `~/.sirchmunk`: ```bash cd benchmarks/financebench sirchmunk init --work-path ./.work ``` -This creates a `.work/` directory under the experiment folder, keeping knowledge base -and cache isolated from the default `~/.sirchmunk`. +This creates a `.work/` directory containing a **platform .env** file (`.work/.env`). + +**Configure the platform .env** (`.work/.env`): -### 4. Compile Knowledge Base +This file controls the LLM provider used by Sirchmunk's search engine. +You **must** set valid LLM credentials here before proceeding. 
-Compile the PDF corpus into the experiment workspace: +| Variable | Required | Description | Example | +|----------|----------|-------------|---------| +| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | +| `LLM_BASE_URL` | **Yes** | LLM API endpoint | `https://dashscope.aliyuncs.com/compatible-mode/v1` | +| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.5-plus` | +| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | ```bash -sirchmunk compile --work-path ./.work --paths +# Edit the platform .env +vi .work/.env ``` -> **Note:** The compile step may take some time depending on the corpus size. -> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. +### Step 4: Compile Knowledge Base -### 5. Configure Environment +Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: + +```bash +sirchmunk compile --work-path ./.work --paths ./data/pdfs +``` + +> **Note:** This step parses, chunks, and indexes all PDFs. +> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10–30 minutes. + +### Step 5: Configure Experiment + +Create the **experiment .env** from the template: ```bash cp .env.example .env.financebench -# Edit .env.financebench with your API keys and paths ``` -## Quick Start +**Configure the experiment .env** (`.env.financebench`): + +This file controls FinanceBench-specific evaluation parameters. + +#### Dataset Paths -### Configuration Priority +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_WORK_PATH` | No | Isolated workspace path | `./.work` | +| `FB_DATA_DIR` | **Yes** | Directory containing `financebench_open_source.jsonl` | `./data` | +| `FB_PDF_DIR` | **Yes** | Directory containing the 41 PDF files | `./data/pdfs` | +| `FB_OUTPUT_DIR` | No | Results output directory | `./output` | -Configuration loads in this order (later overrides earlier): +#### Dataset Settings -1. 
**Dataclass defaults** — hard-coded in `FinanceBenchConfig` -2. **Platform .env** — `.work/.env` (created by `sirchmunk init`) -3. **Experiment .env** — `.env.financebench` -4. **Command-line** — `--limit N`, `--env ` +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_LIMIT` | No | Number of questions to evaluate (`0` = all 150) | `0` | +| `FB_SEED` | No | Random seed for reproducibility | `42` | + +#### Search Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MODE` | No | Search mode: `FAST` or `DEEP` | `FAST` | +| `FB_TOP_K_FILES` | No | Max files returned per search | `5` | +| `FB_MAX_TOKEN_BUDGET` | No | Token budget for search context | `128000` | +| `FB_ENABLE_DIR_SCAN` | No | Enable directory-level scanning | `true` | + +#### Evaluation Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_EVAL_MODE` | No | `singleDoc` (per-PDF) or `sharedCorpus` (all PDFs) | `singleDoc` | +| `FB_ENABLE_LLM_JUDGE` | No | Enable LLM Judge for semantic equivalence | `true` | +| `FB_EXTRACT_ANSWER` | No | Extract short answer from verbose response | `true` | + +#### Concurrency Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MAX_CONCURRENT` | No | Max concurrent evaluation requests | `3` | +| `FB_REQUEST_DELAY` | No | Delay between requests in seconds | `0.5` | + +**Optional LLM Override**: If you want this experiment to use a **different** LLM +than the platform config, uncomment the `LLM_*` lines in `.env.financebench`. +Otherwise, the experiment inherits LLM settings from `.work/.env`. 
+ +```bash +# Edit the experiment .env +vi .env.financebench +``` + +## Configuration Architecture + +Configuration loads with layered inheritance (highest priority wins): + +``` +Priority (highest → lowest): +┌──────────────────────────────────┐ +│ Command-line args │ ← --limit N, --env +├──────────────────────────────────┤ +│ .env.financebench (experiment) │ ← FB_* params + optional LLM override +├──────────────────────────────────┤ +│ .work/.env (platform) │ ← LLM_API_KEY, LLM_MODEL_NAME, etc. +├──────────────────────────────────┤ +│ Environment variables │ ← os.environ fallback +├──────────────────────────────────┤ +│ Defaults │ ← Hard-coded in FinanceBenchConfig +└──────────────────────────────────┘ +``` -To reuse platform LLM config, leave `LLM_*` commented in `.env.financebench`. -To override, uncomment and set different values. +### What Goes Where? + +| Setting | Platform `.work/.env` | Experiment `.env.financebench` | +|---------|:---------------------:|:------------------------------:| +| LLM API Key | ✅ (required) | Only if overriding | +| LLM Model | ✅ (required) | Only if overriding | +| LLM Base URL | ✅ (required) | Only if overriding | +| LLM Timeout | Optional | Only if overriding | +| PDF directory | — | ✅ (required) | +| Data directory | — | ✅ (required) | +| Output directory | — | Optional | +| Eval mode | — | Optional | +| Search mode | — | Optional | +| LLM Judge | — | Optional | +| Concurrency | — | Optional | ### 1. 
Run From 0e46ef5641adb5e3857f527c04f34fe2128ef834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 10:52:09 +0800 Subject: [PATCH 19/70] update readme for finbench --- benchmarks/financebench/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index d6c95b0..e294c7b 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -60,7 +60,7 @@ and cache separate from the default `~/.sirchmunk`: ```bash cd benchmarks/financebench -sirchmunk init --work-path ./.work +sirchmunk init --work-path .work ``` This creates a `.work/` directory containing a **platform .env** file (`.work/.env`). @@ -70,12 +70,12 @@ This creates a `.work/` directory containing a **platform .env** file (`.work/.e This file controls the LLM provider used by Sirchmunk's search engine. You **must** set valid LLM credentials here before proceeding. -| Variable | Required | Description | Example | -|----------|----------|-------------|---------| -| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | +| Variable | Required | Description | Example | +|----------|----------|-------------|-----------------------------------------------------| +| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | | `LLM_BASE_URL` | **Yes** | LLM API endpoint | `https://dashscope.aliyuncs.com/compatible-mode/v1` | -| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.5-plus` | -| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | +| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.6-plus` | +| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | ```bash # Edit the platform .env @@ -87,7 +87,7 @@ vi .work/.env Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: ```bash -sirchmunk compile --work-path ./.work --paths ./data/pdfs +sirchmunk compile 
--work-path .work --paths data/pdfs ``` > **Note:** This step parses, chunks, and indexes all PDFs. From 2cf5c378cbd42d1aba82ded34db994a62f56b487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 13:59:49 +0800 Subject: [PATCH 20/70] enhance tree indexes usage for search pipeline --- benchmarks/financebench/README.md | 2 +- src/sirchmunk/search.py | 291 +++++++++++++++++++++++++++++- 2 files changed, 285 insertions(+), 8 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index e294c7b..95d04e7 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -82,7 +82,7 @@ You **must** set valid LLM credentials here before proceeding. vi .work/.env ``` -### Step 4: Compile Knowledge Base +### Step 4: Knowledge Compiling Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 9b7bf47..2e44449 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1427,6 +1427,9 @@ async def _search_deep( ) _llm_usage_start = len(self.llm_usages) + # --- Adaptive compile artifact detection (shared with FAST) --- + artifacts = self._detect_compile_artifacts() + # ============================================================== # Phase 0a: Direct document analysis (intent-gated short-circuit) # ============================================================== @@ -1460,7 +1463,9 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), - self._probe_compile_hints(initial_keywords if initial_keywords else [query]), + self._probe_compile_hints([query]), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts), # GAP 2: zero-LLM BM25 + self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap return_exceptions=True, ) 
@@ -1470,8 +1475,10 @@ async def _search_deep( spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) + summary_index_hits = phase1_results[6] if not isinstance(phase1_results[6], Exception) else [] + catalog_deep_hits = phase1_results[7] if not isinstance(phase1_results[7], Exception) else [] - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints", "summary_index", "catalog_deep"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") @@ -1508,6 +1515,8 @@ async def _search_deep( f"knowledge_files={len(knowledge_probe.file_paths)}, " f"tree_hits={len(tree_hits)}, " f"compile_hints={len(compile_hints.file_paths)}, " + f"summary_index={len(summary_index_hits)}, " + f"catalog_deep={len(catalog_deep_hits)}, " f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1583,7 +1592,7 @@ async def _search_deep( if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + compile_hints.file_paths + keyword_files, + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, dir_scan_files=dir_scan_files, knowledge_hits=extra_knowledge_files, ) @@ -1627,6 +1636,22 @@ async def _search_deep( answer: str = "" should_save: bool = True + # Inject catalog context for wiki-enhanced answer (GAP 4) + if artifacts and artifacts.catalog_map and cluster and cluster.content: + _catalog_ctx_parts = [] + for fp in (cluster.search_results or 
merged_files)[:3]: + ctx = self._build_answer_context(fp, artifacts) + if ctx: + _catalog_ctx_parts.append(ctx) + if _catalog_ctx_parts: + _catalog_context = "\n".join(_catalog_ctx_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n[Document Context]\n{_catalog_context}" + await self._logger.info( + f"[Phase 4] Injected catalog context for {len(_catalog_ctx_parts)} documents" + ) + if cluster and cluster.content: await self._logger.info("[Phase 4] Evidence sufficient, generating summary") answer, should_save, should_answer = await self._summarise_cluster(query, cluster) @@ -2007,6 +2032,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum character length for a catalog keyword token.""" _CATALOG_SUMMARY_TRUNCATE = 200 """Max chars of catalog summary shown in the listing.""" + _SUMMARY_INDEX_TOP_K = 3 + """Maximum files returned by proactive summary index BM25 probe.""" + _DEEP_CATALOG_TOP_K = 3 + """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- _TREE_SAMPLE_MAX_SECTIONS = 3 @@ -2019,6 +2048,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum number of tree roots to include in FAST Step 1 hints.""" _DEEP_PRE_NAV_MAX_FILES = 3 """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" + _FAST_TREE_PROBE_MAX_FILES = 2 + """Maximum files returned by active tree probing in FAST mode.""" _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" @@ -2106,10 +2137,33 @@ async def _search_fast( if tree_hints: prompt = prompt + tree_hints - resp = await self.llm.achat( + # Step 1 LLM call + compile hints + tree probe run in parallel + # (GAP 3: hints前置化, GAP 1: 树导航主动化) + _step1_llm_task = self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, ) + _compile_hints_task = self._probe_compile_hints([query]) + 
_tree_probe_task = self._probe_tree_for_fast(query, artifacts) + + _parallel_results = await asyncio.gather( + _step1_llm_task, _compile_hints_task, _tree_probe_task, + return_exceptions=True, + ) + resp = _parallel_results[0] + _early_compile_hints = _parallel_results[1] + _tree_probed_files = _parallel_results[2] + + if isinstance(resp, Exception): + await self._logger.warning(f"[FAST:Step1] LLM call failed: {resp}") + return f"Search analysis failed: {resp}", None, context + if isinstance(_early_compile_hints, Exception): + await self._logger.warning(f"[FAST:Step1] Compile hints pre-fetch failed: {_early_compile_hints}") + _early_compile_hints = CompileHints([], []) + if isinstance(_tree_probed_files, Exception): + await self._logger.warning(f"[FAST:Step1] Tree probe failed: {_tree_probed_files}") + _tree_probed_files = [] + self.llm_usages.append(resp.usage) if resp.usage and isinstance(resp.usage, dict): context.add_llm_tokens( @@ -2207,8 +2261,9 @@ async def _search_fast( all_kw_set.add(p) keyword_idfs.setdefault(p, 0.6) - # P4: compile hints from manifest + tree cache - compile_hints = await self._probe_compile_hints(primary + fallback) + # P4: compile hints — pre-fetched (query-level) + keyword-level supplement + _kw_compile_hints = await self._probe_compile_hints(primary + fallback) + compile_hints = self._merge_compile_hints(_early_compile_hints, _kw_compile_hints) for kw in compile_hints.extra_keywords: if kw not in all_kw_set: fallback.append(kw) @@ -2222,6 +2277,17 @@ async def _search_fast( if fp not in seen_hint_paths: seen_hint_paths.add(fp) compile_hint_files.append(fp) + # Active tree probe files: second priority (GAP 1) + for fp in (_tree_probed_files or []): + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + # Summary index BM25 files: proactive zero-LLM discovery (GAP 2) + _summary_hint_files = await self._probe_summary_index(query, artifacts) + for fp in _summary_hint_files: + if fp not in 
seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if soft_hit: for fp in soft_hit.file_paths: if fp not in seen_hint_paths: @@ -2235,7 +2301,10 @@ async def _search_fast( if compile_hint_files: await self._logger.info( f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files " - f"(catalog={len(catalog_routed_files)}, soft={len(soft_hit.file_paths) if soft_hit else 0}), " + f"(catalog={len(catalog_routed_files)}, " + f"tree={len(_tree_probed_files) if _tree_probed_files else 0}, " + f"summary={len(_summary_hint_files)}, " + f"soft={len(soft_hit.file_paths) if soft_hit else 0}), " f"{len(compile_hints.extra_keywords)} extra keywords" ) @@ -4105,6 +4174,214 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: extra_keywords=extra_keywords[:10], ) + @staticmethod + def _merge_compile_hints(base: "CompileHints", supplement: "CompileHints") -> "CompileHints": + """Merge two CompileHints, deduplicating file paths and keywords.""" + seen_fps = set(base.file_paths) + merged_fps = list(base.file_paths) + for fp in supplement.file_paths: + if fp not in seen_fps: + seen_fps.add(fp) + merged_fps.append(fp) + seen_kws = set(base.extra_keywords) + merged_kws = list(base.extra_keywords) + for kw in supplement.extra_keywords: + if kw not in seen_kws: + seen_kws.add(kw) + merged_kws.append(kw) + return CompileHints(file_paths=merged_fps[:15], extra_keywords=merged_kws[:10]) + + async def _probe_summary_index( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Zero-LLM file discovery via compile-time summary index (BM25 only). + + Uses the pre-built summary index's BM25 channel to find files whose + summaries are lexically similar to the query. No LLM or embedding + calls — pure local computation. + + Args: + query: User query string. + artifacts: Compile artifacts (uses summary_index field). + + Returns: + File paths of top-k matching documents, or empty list. 
+ """ + if artifacts is None or artifacts.summary_index is None: + return [] + + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if not query_tokens: + return [] + + # BM25-only search: pass query_embedding=None to skip embedding channel + results = artifacts.summary_index.search( + query_embedding=None, + query_tokens=query_tokens, + top_k=self._SUMMARY_INDEX_TOP_K, + ) + + file_paths = [ + fp for fp, score in results + if score > 0.0 and Path(fp).exists() + ] + + if file_paths: + await self._logger.info( + f"[SummaryIndex:BM25] Found {len(file_paths)} files " + f"from {artifacts.summary_index.num_entries} indexed docs" + ) + return file_paths + except Exception as exc: + await self._logger.warning(f"[SummaryIndex:BM25] Probe failed: {exc}") + return [] + + async def _probe_catalog_for_deep( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Zero-LLM file discovery via document catalog keyword overlap. + + Scores each catalog entry by counting query token overlap with the + document summary. Returns top-k file paths sorted by overlap score. + + Args: + query: User query string. + artifacts: Compile artifacts (uses catalog field). + + Returns: + File paths of top-k matching documents, or empty list. 
+ """ + if not artifacts or not artifacts.catalog: + return [] + + try: + query_tokens = self._tokenize_for_matching(query.lower()) + if not query_tokens: + return [] + + scored: List[Tuple[str, float]] = [] + for entry in artifacts.catalog: + fp = entry.get("path", "") + if not fp or not Path(fp).exists(): + continue + summary = (entry.get("summary", "") or "").lower() + name = (entry.get("name", "") or "").lower() + doc_tokens = self._tokenize_for_matching(f"{name} {summary}") + overlap = len(query_tokens & doc_tokens) + if overlap > 0: + # Normalize by query length to avoid bias toward long summaries + score = overlap / max(1, len(query_tokens)) + scored.append((fp, score)) + + if not scored: + return [] + + scored.sort(key=lambda x: x[1], reverse=True) + result_paths = [fp for fp, _ in scored[:self._DEEP_CATALOG_TOP_K]] + + if result_paths: + await self._logger.info( + f"[DEEP:CatalogProbe] Found {len(result_paths)} files " + f"from {len(artifacts.catalog)} catalog entries" + ) + return result_paths + except Exception as exc: + await self._logger.warning(f"[DEEP:CatalogProbe] Failed: {exc}") + return [] + + async def _probe_tree_for_fast( + self, query: str, artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Active tree-based file discovery for FAST mode (1 LLM call). + + Lightweight wrapper around tree root selection logic. When compiled + tree indices are available and cover more than 2 files, asks the LLM + to select the most relevant 1-2 documents from root summaries. + + Returns file paths of selected documents, or empty list when trees + are unavailable or cover too few files to justify an LLM call. 
+ """ + if not artifacts or len(artifacts.tree_available_paths) <= 2: + return [] + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return [] + + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + + trees: List[DocumentTree] = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + t = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + if t.root and t.file_path and Path(t.file_path).exists(): + trees.append(t) + except Exception: + continue + + if not trees: + return [] + + # Few trees: return all without LLM + if len(trees) <= self._FAST_TREE_PROBE_MAX_FILES: + return [t.file_path for t in trees] + + # LLM-driven selection among tree roots + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-{self._FAST_TREE_PROBE_MAX_FILES} most relevant documents:\n" + f"{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) + + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass + + if not selected_indices: + selected_indices = list(range(min(self._FAST_TREE_PROBE_MAX_FILES, len(trees)))) + + result_paths = [ + trees[idx].file_path + for idx in selected_indices[:self._FAST_TREE_PROBE_MAX_FILES] + if Path(trees[idx].file_path).exists() + ] + + if result_paths: + await self._logger.info( + f"[FAST:TreeProbe] Selected {len(result_paths)} files " + f"from {len(trees)} tree indices" + ) + return result_paths + except Exception as exc: + await self._logger.warning(f"[FAST:TreeProbe] Failed: {exc}") + return [] + @staticmethod async def _async_noop(default=None): """No-op coroutine used as placeholder in gather().""" From c0b0db5b1654ce510a185c6f0c80dae18b408d3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 14:11:45 +0800 Subject: [PATCH 21/70] fix issues --- src/sirchmunk/search.py | 205 ++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 112 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2e44449..7506c0c 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2050,6 +2050,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" _FAST_TREE_PROBE_MAX_FILES = 2 """Maximum files returned by active tree probing in FAST mode.""" + _DEEP_TREE_PROBE_MAX_FILES = 3 + """Maximum files returned by tree index probing in DEEP mode.""" + _TREE_ROOT_HINT_TRUNCATE = 150 + """Max chars of tree root summary in Step 1 structure hints.""" _LLM_FALLBACK_EVIDENCE = ( 
"[No relevant documents found]\n\n" @@ -3312,7 +3316,7 @@ def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: tree = indexer.load_tree(fp) if tree and tree.root and tree.root.summary: name = Path(fp).name - hints.append(f"[{i}] {name}: {tree.root.summary[:150]}") + hints.append(f"[{i}] {name}: {tree.root.summary[:self._TREE_ROOT_HINT_TRUNCATE]}") if not hints: return "" return "\nDocument structure hints:\n" + "\n".join(hints) + "\n" @@ -4017,80 +4021,110 @@ def _collect_cluster(c: KnowledgeCluster) -> None: except Exception: return empty - async def _probe_tree_index(self, query: str) -> List[str]: - """LLM-driven file discovery via compiled tree root summaries (PageIndex). + def _load_cached_trees(self) -> list: + """Load DocumentTree objects from the tree cache directory. - Loads all cached document trees, presents their root summaries to the - LLM, and asks it to select the most relevant 1-3 documents. For - selected trees, optionally drills one level deeper into children. - - Returns file paths of the most relevant documents. + Returns a list of ``DocumentTree`` instances whose file paths exist + on disk. Returns an empty list when the tree cache is absent or + contains no valid entries. 
""" tree_cache = self.work_path / ".cache" / "compile" / "trees" if not tree_cache.exists(): return [] - try: from sirchmunk.learnings.tree_indexer import DocumentTree - trees: List[DocumentTree] = [] - for tree_file in sorted(tree_cache.glob("*.json"))[:50]: + trees = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: try: t = DocumentTree.from_json( tree_file.read_text(encoding="utf-8") ) - if t.root and t.file_path: + if t.root and t.file_path and Path(t.file_path).exists(): trees.append(t) except Exception: continue + return trees + except Exception: + return [] - if not trees: - return [] + async def _llm_select_from_trees( + self, query: str, trees: list, max_select: int, + ) -> List[str]: + """LLM-driven file selection from tree root summaries. - # If few trees, return all without LLM - if len(trees) <= 2: - return [t.file_path for t in trees if Path(t.file_path).exists()] + Presents root summaries to the LLM and returns the selected file + paths. When the number of trees is at most *max_select*, returns + all paths without an LLM call. - # LLM-driven selection among tree roots - listing = "\n".join( - f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" - for i, t in enumerate(trees) - ) - prompt = ( - f'Given the query: "{query}"\n\n' - f"Select the 1-3 most relevant documents (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" - ) - resp = await self.llm.achat([{"role": "user", "content": prompt}]) - self.llm_usages.append(resp.usage) + Args: + query: User query string. + trees: List of ``DocumentTree`` objects (pre-loaded). + max_select: Maximum number of files to select. 
- selected_indices: List[int] = [] - try: - raw = resp.content.strip() - m = re.search(r"\[[\d\s,]+\]", raw) - if m: - selected_indices = [ - idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) - ] - except (json.JSONDecodeError, TypeError): - pass + Returns: + Selected file paths, or empty list. + """ + if not trees: + return [] + if len(trees) <= max_select: + return [t.file_path for t in trees] - if not selected_indices: - selected_indices = list(range(min(2, len(trees)))) + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: " + f"{(t.root.summary or '')[:self._CATALOG_SUMMARY_TRUNCATE]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-{max_select} most relevant documents " + f"(by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) - result_paths: List[str] = [] - for idx in selected_indices: - fp = trees[idx].file_path - if Path(fp).exists(): - result_paths.append(fp) + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass - if result_paths: + if not selected_indices: + selected_indices = list(range(min(max_select, len(trees)))) + + return [ + trees[idx].file_path + for idx in selected_indices[:max_select] + if Path(trees[idx].file_path).exists() + ] + + async def _probe_tree_index(self, query: str) -> List[str]: + """LLM-driven file discovery via compiled tree root summaries (PageIndex). + + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant documents. Returns file + paths of the most relevant documents. 
+ """ + try: + trees = self._load_cached_trees() + if not trees: + return [] + result = await self._llm_select_from_trees( + query, trees, max_select=self._DEEP_TREE_PROBE_MAX_FILES, + ) + if result: await self._logger.info( - f"[Probe:TreeIndex] LLM selected {len(result_paths)} documents " + f"[Probe:TreeIndex] LLM selected {len(result)} documents " f"from {len(trees)} tree indices" ) - return result_paths + return result except Exception: return [] @@ -4302,9 +4336,9 @@ async def _probe_tree_for_fast( ) -> List[str]: """Active tree-based file discovery for FAST mode (1 LLM call). - Lightweight wrapper around tree root selection logic. When compiled - tree indices are available and cover more than 2 files, asks the LLM - to select the most relevant 1-2 documents from root summaries. + When compiled tree indices are available and cover more than 2 files, + asks the LLM to select the most relevant 1-2 documents from root + summaries. Delegates to the shared ``_llm_select_from_trees`` helper. Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. 
@@ -4312,72 +4346,19 @@ async def _probe_tree_for_fast( if not artifacts or len(artifacts.tree_available_paths) <= 2: return [] - tree_cache = self.work_path / ".cache" / "compile" / "trees" - if not tree_cache.exists(): - return [] - try: - from sirchmunk.learnings.tree_indexer import DocumentTree - - trees: List[DocumentTree] = [] - for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: - try: - t = DocumentTree.from_json( - tree_file.read_text(encoding="utf-8") - ) - if t.root and t.file_path and Path(t.file_path).exists(): - trees.append(t) - except Exception: - continue - + trees = self._load_cached_trees() if not trees: return [] - - # Few trees: return all without LLM - if len(trees) <= self._FAST_TREE_PROBE_MAX_FILES: - return [t.file_path for t in trees] - - # LLM-driven selection among tree roots - listing = "\n".join( - f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" - for i, t in enumerate(trees) - ) - prompt = ( - f'Given the query: "{query}"\n\n' - f"Select the 1-{self._FAST_TREE_PROBE_MAX_FILES} most relevant documents:\n" - f"{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + result = await self._llm_select_from_trees( + query, trees, max_select=self._FAST_TREE_PROBE_MAX_FILES, ) - resp = await self.llm.achat([{"role": "user", "content": prompt}]) - self.llm_usages.append(resp.usage) - - selected_indices: List[int] = [] - try: - raw = resp.content.strip() - m = re.search(r"\[[\d\s,]+\]", raw) - if m: - selected_indices = [ - idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) - ] - except (json.JSONDecodeError, TypeError): - pass - - if not selected_indices: - selected_indices = list(range(min(self._FAST_TREE_PROBE_MAX_FILES, len(trees)))) - - result_paths = [ - trees[idx].file_path - for idx in selected_indices[:self._FAST_TREE_PROBE_MAX_FILES] - if Path(trees[idx].file_path).exists() - ] - - if result_paths: + if result: await self._logger.info( - f"[FAST:TreeProbe] Selected {len(result_paths)} files " + f"[FAST:TreeProbe] Selected {len(result)} files " f"from {len(trees)} tree indices" ) - return result_paths + return result except Exception as exc: await self._logger.warning(f"[FAST:TreeProbe] Failed: {exc}") return [] From e8184d06de2f8361fe6374fe31002d7404b039a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 14:39:21 +0800 Subject: [PATCH 22/70] update tree index --- src/sirchmunk/learnings/tree_indexer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index abf5459..26787eb 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -19,7 +19,7 @@ from sirchmunk.utils.file_utils import get_fast_hash # File-size threshold: skip tree indexing for small files -_TREE_MIN_CHARS = 20_000 # 20 K characters (lowered from 50K for broader coverage) +_TREE_MIN_CHARS = 10_000 # 10 K characters (lowered from 20K for broader coverage) # Adaptive depth thresholds: (min_chars, max_depth) — evaluated top-down; # **must** 
be sorted by min_chars descending so the first match wins. From dc27ed9eeccea93a720f2dc00d4749c238a7a440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 16:18:52 +0800 Subject: [PATCH 23/70] update finbench readme --- benchmarks/financebench/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 95d04e7..9c23648 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -91,7 +91,7 @@ sirchmunk compile --work-path .work --paths data/pdfs ``` > **Note:** This step parses, chunks, and indexes all PDFs. -> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10–30 minutes. +> For FinanceBench's all PDFs, expect hours of processing time, depending on your LLM speed and compute resources. ### Step 5: Configure Experiment From 8723b85c9394d7d1cecd7af2e7d348b3e8526757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 16:59:00 +0800 Subject: [PATCH 24/70] update finbench readme --- benchmarks/financebench/README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 9c23648..9bb0134 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -93,6 +93,18 @@ sirchmunk compile --work-path .work --paths data/pdfs > **Note:** This step parses, chunks, and indexes all PDFs. > For FinanceBench's all PDFs, expect hours of processing time, depending on your LLM speed and compute resources. +#### Shallow Compile (Recommended for First Run) + +Use `--shallow` to skip tree indexing and only generate Summary + Topics. +This reduces LLM calls dramatically and achieves **5–9× speedup**: + +```bash +sirchmunk compile --work-path .work --paths data/pdfs --shallow +``` + +> **Tip:** `--shallow` is ideal for quickly compiling a large corpus on the first pass. 
+> You can run a normal (full) compile later to incrementally add tree indexes. + ### Step 5: Configure Experiment Create the **experiment .env** from the template: From ca9a609d61564277c9384d9fd193d0685546a7b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 19:15:48 +0800 Subject: [PATCH 25/70] update should answer thres --- src/sirchmunk/llm/prompts.py | 6 +- src/sirchmunk/search.py | 193 +++++++++++++++++++++++++++++++++-- 2 files changed, 188 insertions(+), 11 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 8df111d..27338a2 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -189,7 +189,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". @@ -437,7 +437,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. 
Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". @@ -476,7 +476,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 7506c0c..a900978 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -86,6 +86,14 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 +# Common English stop-words filtered out during keyword coverage computation. 
+_STOP_WORDS: frozenset = frozenset({ + "the", "is", "a", "an", "of", "in", "for", "to", "and", "or", + "what", "how", "which", "does", "was", "were", "has", "have", "had", + "do", "did", "are", "be", "been", "by", "with", "from", "this", + "that", "it", "its", "on", "at", "as", "not", "no", +}) + @dataclass class SoftClusterHit: @@ -951,6 +959,105 @@ def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: return summary, should_save, should_answer + # ------------------------------------------------------------------ + # Multi-factor evidence acceptance helpers + # ------------------------------------------------------------------ + + @staticmethod + def _compute_keyword_coverage(query: str, evidence: str) -> float: + """Compute the fraction of query keywords found in the evidence text. + + Tokenises *query* into lowercase alpha-numeric words (length >= 2), + removes common English stop-words, then checks presence in + lower-cased *evidence*. + + Returns: + Coverage ratio in [0.0, 1.0]. Returns 0.0 when no valid + keywords can be extracted from *query*. + """ + tokens = re.findall(r'\b[a-z0-9]{2,}\b', query.lower()) + keywords = [t for t in tokens if t not in _STOP_WORDS] + if not keywords: + return 0.0 + evidence_lower = evidence.lower() + matched = sum(1 for kw in keywords if kw in evidence_lower) + return matched / len(keywords) + + @staticmethod + def _detect_numeric_evidence(query: str, evidence: str) -> bool: + """Detect whether *evidence* contains structured numeric data relevant to *query*. + + Returns True when *query* implies a numeric/financial intent AND + *evidence* contains numeric patterns (currency amounts, percentages, + financial figures). + """ + query_lower = query.lower() + has_intent = any( + kw in query_lower + for kw in AgenticSearch._NUMERIC_INTENT_KEYWORDS + ) + if not has_intent: + return False + has_numeric = bool( + re.search( + r'[\$\u20ac\u00a3]\s?\d' + r'|(? 
Tuple[bool, str]: + """Multi-factor decision on whether to accept retrieved evidence. + + Combines the LLM's own SHOULD_ANSWER judgment with heuristic + signals (evidence length, keyword coverage, numeric-data presence) + to reduce false-negative rejections of valid evidence. + + Returns: + A tuple of (*accept*, *reason*) where *accept* is the final + boolean decision and *reason* is a human-readable string + documenting which factor(s) determined the outcome. + """ + # Factor 1: LLM direct acceptance + if llm_should_answer: + return True, "llm_accepted" + + # Factor 2: Heuristic override — length + keyword coverage + evidence_len = len(evidence) if evidence else 0 + kw_coverage = ( + AgenticSearch._compute_keyword_coverage(query, evidence) + if evidence else 0.0 + ) + + if ( + evidence_len >= AgenticSearch._EVIDENCE_MIN_ACCEPT_LENGTH + and kw_coverage >= AgenticSearch._EVIDENCE_KEYWORD_COVERAGE_THRESHOLD + ): + return True, ( + f"heuristic_override(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" + ) + + # Factor 3: Numeric evidence detection + if AgenticSearch._detect_numeric_evidence(query, evidence or ""): + return True, ( + f"numeric_evidence(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" + ) + + # All factors negative + return False, ( + f"rejected(llm=false, len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f}, numeric=false)" + ) + @staticmethod def _extract_and_validate_multi_level_keywords( llm_resp: str, @@ -1655,7 +1762,17 @@ async def _search_deep( if cluster and cluster.content: await self._logger.info("[Phase 4] Evidence sufficient, generating summary") answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: + + # --- Multi-factor evidence acceptance --- + cluster_evidence = str(cluster.content) if cluster and cluster.content else "" + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, cluster_evidence, should_answer, + ) + await self._logger.info( + 
f"[Phase 4] Evidence acceptance: {accepted} ({accept_reason})" + ) + + if not accepted: if llm_fallback: await self._logger.info( "[Phase 4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" @@ -1703,7 +1820,17 @@ async def _search_deep( # Final DEEP decision is always made in the summary call. answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: + + # --- Multi-factor evidence acceptance --- + final_cluster_evidence = str(cluster.content) if cluster and cluster.content else "" + final_accepted, final_reason = self._evaluate_evidence_acceptance( + query, final_cluster_evidence, should_answer, + ) + await self._logger.info( + f"[Phase 4] Final evidence acceptance: {final_accepted} ({final_reason})" + ) + + if not final_accepted: if llm_fallback: await self._logger.info( "[Phase 4] Final summary gate rejected evidence, llm_fallback=True → LLM fallback" @@ -2055,6 +2182,25 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _TREE_ROOT_HINT_TRUNCATE = 150 """Max chars of tree root summary in Step 1 structure hints.""" + # --- Self-correction expanded sampling --- + _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 + """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 3).""" + _SELF_CORRECT_EXPANDED_SECTIONS: int = 5 + """Expanded tree sample sections for same-file re-sampling (default uses 3).""" + + # --- Evidence acceptance thresholds --- + _EVIDENCE_MIN_ACCEPT_LENGTH: int = 1500 + """Minimum evidence character length for heuristic override.""" + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.6 + """Minimum keyword coverage ratio for heuristic override.""" + _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ + "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", + "cash", "debt", "equity", "eps", "dpo", "growth", "rate", + "percentage", "amount", "total", "net", "gross", "cost", "expense", + "sales", "fy", "fiscal", + }) + """Keywords 
indicating numeric/financial intent in a query.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. " @@ -2573,12 +2719,20 @@ async def _rga_evidence() -> str: answer_resp.content or "" ) + # --- Multi-factor evidence acceptance (P2+P3+P4) --- + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, evidence, should_answer, + ) + await self._logger.info( + f"[FAST:Step4] Evidence acceptance: {accepted} ({accept_reason})" + ) + # ============================================================== # Step 5: Self-correction retry (conditional, ≤1 extra LLM call) # When the answer gate rejects the first attempt, try alternative # evidence sources before giving up. # ============================================================== - if not should_answer: + if not accepted: retry_evidence = await self._fast_self_correct( query, best_files, catalog_routed_files, context, ) @@ -2598,11 +2752,19 @@ async def _rga_evidence() -> str: context.add_llm_tokens( retry_resp.usage.get("total_tokens", 0), usage=retry_resp.usage, ) - answer, should_save, should_answer = self._parse_summary_response( + answer, should_save, retry_should_answer = self._parse_summary_response( retry_resp.content or "" ) + retry_accepted, retry_reason = self._evaluate_evidence_acceptance( + query, retry_evidence, retry_should_answer, + ) + await self._logger.info( + f"[FAST:Step5] Retry evidence acceptance: {retry_accepted} ({retry_reason})" + ) + if retry_accepted: + accepted = True - if not should_answer: + if not accepted: if llm_fallback: await self._logger.info( "[FAST:Step5] Retry also rejected, llm_fallback=True → LLM fallback" @@ -3637,7 +3799,7 @@ async def _tree_guided_sample( return evidence async def _navigate_tree_for_evidence( - self, file_path: str, query: str, + self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections 
and read leaf content. @@ -3653,7 +3815,7 @@ async def _navigate_tree_for_evidence( return None try: - leaves = await indexer.navigate(tree, query, max_results=3) + leaves = await indexer.navigate(tree, query, max_results=max_results) except Exception: return None @@ -3699,7 +3861,8 @@ async def _fast_self_correct( ) -> Optional[str]: """Attempt to gather alternative evidence when the first answer is rejected. - Three strategies tried in order: + Four strategies tried in order: + D) Re-sample the same primary file with expanded parameters (deeper sampling). A) Tree-navigate a 2nd catalog-routed file not yet tried. B) Retrieve the most semantically similar compiled cluster's content. C) Tree-navigate the 2nd-best rga file if available. @@ -3708,6 +3871,20 @@ async def _fast_self_correct( """ first_file = best_files[0]["path"] if best_files else "" + # Strategy D: Re-sample the SAME primary file with expanded parameters. + # The file was correct but the initial sampling may have missed key sections. 
+ if first_file: + expanded_tree_ev = await self._navigate_tree_for_evidence( + first_file, query, + max_results=self._SELF_CORRECT_EXPANDED_NAV_RESULTS, + ) + if expanded_tree_ev and len(expanded_tree_ev.strip()) > 50: + await self._logger.info( + "[FAST:SelfCorrect] Strategy D succeeded: " + "expanded same-file tree navigation" + ) + return expanded_tree_ev + # Strategy A: 2nd catalog-routed file via tree navigation for fp in catalog_routed_files: if fp == first_file: From 34c181eaf1a01a10bf824908c6f18c1131d7078e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 20:22:44 +0800 Subject: [PATCH 26/70] fix eval for finbench in runner --- benchmarks/financebench/runner.py | 71 ++++++++++++++++++++++++++++--- src/sirchmunk/search.py | 6 +-- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index b95f7ca..86404f2 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -12,10 +12,11 @@ import asyncio import json as json_mod import logging +import re import time from datetime import datetime from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from config import FinanceBenchConfig from data_loader import FinanceBenchLoader @@ -38,7 +39,12 @@ Given the financial question and a verbose response, extract ONLY the short factoid answer. Rules: - Output ONLY the answer value/phrase (1-20 words). No explanation. -- If the response says it cannot find the answer, output: unknown +- If the response contains ANY concrete data (dollar amounts, percentages, numbers, + company names, yes/no conclusions), extract that data even if the response also + expresses uncertainty or says it could not find a "complete" answer. +- A partial answer with real data is ALWAYS better than "unknown". 
+- Output "unknown" ONLY when the response contains absolutely no useful factual + information (e.g., a pure apology with zero data points). - For monetary values, keep the currency format (e.g., $1,577.00) - For percentages, keep the % sign (e.g., 15.3%) - For yes/no questions, output: yes or no @@ -48,6 +54,16 @@ Short answer:""" +# Regex pattern for extracting financial numeric data as fallback +_NUMERIC_EXTRACTION_PATTERN = ( + r'\$[\d,]+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|K)?' + r'|\d+(?:,\d{3})+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)?' + r'|\d+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)' +) + +# Sentinel values indicating extraction found no useful answer +_UNKNOWN_SENTINELS = frozenset({"unknown", "n/a", ""}) + # NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. @@ -57,12 +73,26 @@ # ------------------------------------------------------------------ -async def _extract_short_answer( +def _extract_numeric_fallback(text: str) -> Optional[str]: + """Extract financial figures from *text* using regex patterns. + + Looks for currency amounts ($xxx), percentages, and large numbers + with units (million, billion, etc.). + + Returns the first match or ``None``. 
+ """ + match = re.search(_NUMERIC_EXTRACTION_PATTERN, text) + if match: + return match.group(0).strip() + return None + + +async def _llm_extract( question: str, verbose: str, llm: Any, -) -> str: - """Use *llm* to distil *verbose* into a short factoid answer.""" +) -> Optional[str]: + """Layer-1: use LLM to distil *verbose* into a short factoid answer.""" prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) try: resp = await llm.achat( @@ -71,8 +101,35 @@ async def _extract_short_answer( ) return resp.content.strip() except Exception: - logger.warning("Short-answer extraction failed; falling back to raw answer.") - return verbose + logger.warning("LLM extraction failed; will try regex fallback.") + return None + + +async def _extract_short_answer( + question: str, + verbose: str, + llm: Any, +) -> str: + """Extract a concise answer from verbose LLM analysis. + + Uses a three-layer extraction strategy: + 1. LLM-based extraction with improved prompt + 2. Regex-based numeric/financial data extraction as fallback + 3. 
Returns 'unknown' only when no useful data is found + """ + # Layer 1: LLM extraction + answer = await _llm_extract(question, verbose, llm) + if answer and answer.strip().lower() not in _UNKNOWN_SENTINELS: + return answer.strip() + + # Layer 2: Regex fallback for numeric/financial data + numeric_answer = _extract_numeric_fallback(verbose) + if numeric_answer: + logger.info("Regex fallback extracted: %s", numeric_answer) + return numeric_answer + + # Layer 3: No useful data found + return "unknown" # ------------------------------------------------------------------ diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index a900978..c2f30f4 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2189,9 +2189,9 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Expanded tree sample sections for same-file re-sampling (default uses 3).""" # --- Evidence acceptance thresholds --- - _EVIDENCE_MIN_ACCEPT_LENGTH: int = 1500 + _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" - _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.6 + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.5 """Minimum keyword coverage ratio for heuristic override.""" _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", @@ -3231,7 +3231,7 @@ async def _fast_sample_evidence( # Diagnostic logging when falling back to snippet mode if not hit_lines and match_objects: - await self._logger.warning( + await self._logger.info( f"[FAST] No line_number in {len(match_objects)} match(es) for {fname}, " f"falling back to snippet mode" ) From 2b4714efd270376c7b53a41a131facab858f02ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 21:03:07 +0800 Subject: [PATCH 27/70] refactor metrics as LLM judge for finbench --- benchmarks/financebench/evaluate.py | 264 ++++------------------- benchmarks/financebench/judge.py | 
195 +++++++++++++++-- benchmarks/financebench/run_benchmark.py | 67 ++---- benchmarks/financebench/runner.py | 240 +++++---------------- 4 files changed, 300 insertions(+), 466 deletions(-) diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index e22bf07..d9614f3 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -1,63 +1,21 @@ -"""FinanceBench evaluation metrics. +"""FinanceBench evaluation metrics — LLM Judge driven. -Implements the three-class scoring scheme from the FinanceBench paper -(Islam et al., 2023): **correct**, **hallucination**, **refusal**. +All correctness evaluation (Accuracy, Coverage) is driven by the LLM Judge. +This module aggregates per-question judge results into benchmark-level metrics. -Financial-value normalisation handles currency symbols, thousand separators, -trailing zeros, and percentage signs so that ``$1,577.00`` matches ``1577``. +The ``normalize_answer`` helper is retained for quick short-circuit checks +inside the judge (exact-match bypass before calling the LLM). 
""" from __future__ import annotations import re -from collections import Counter, defaultdict +from collections import defaultdict from typing import Any, Dict, List # ------------------------------------------------------------------ # Constants # ------------------------------------------------------------------ -_REFUSAL_PHRASES: list[str] = [ - "i cannot", - "i can't", - "i could not", - "i couldn't", - "no results found", - "unable to", - "not able to", - "i don't know", - "i do not know", - "information is not available", - "not enough information", - "cannot determine", - "cannot be determined", - "insufficient data", - "no relevant information", - "data not found", - "unknown", - "i'm not able to", - "i am not able to", - "the document does not contain", - "the document doesn't contain", - "this information is not disclosed", - "not disclosed", - "could not find", - "couldn't find", - "no mention of", - "no information about", - "not provided in", - "not found in the document", - "i was unable to", - "unable to determine", - "unable to find", - "unable to locate", - "there is no data", - "no data available", - "not available in", - "not specified", -] - -_F1_CORRECT_THRESHOLD: float = 0.8 - # Markdown / wrapper patterns compiled once _RE_BOLD = re.compile(r"\*\*(.+?)\*\*") _RE_ITALIC = re.compile(r"\*(.+?)\*") @@ -169,109 +127,6 @@ def _normalize_financial_value(text: str) -> str: return s -# ------------------------------------------------------------------ -# Matching helpers -# ------------------------------------------------------------------ - - -def exact_match(prediction: str, gold: str) -> bool: - """Return ``True`` when normalised strings are identical.""" - return normalize_answer(prediction) == normalize_answer(gold) - - -def f1_score(prediction: str, gold: str) -> float: - """Compute token-level F1 between *prediction* and *gold*. - - Tokenisation is simple whitespace splitting after normalisation. 
- Each token is further normalised as a financial value so that - ``$1577`` matches ``1577`` at the token level. - Returns 0.0 when either side is empty. - """ - pred_tokens = [_normalize_financial_value(t) for t in normalize_answer(prediction).split()] - gold_tokens = [_normalize_financial_value(t) for t in normalize_answer(gold).split()] - if not pred_tokens or not gold_tokens: - return 0.0 - - common = Counter(pred_tokens) & Counter(gold_tokens) - num_common = sum(common.values()) - if num_common == 0: - return 0.0 - - precision = num_common / len(pred_tokens) - recall = num_common / len(gold_tokens) - return 2 * precision * recall / (precision + recall) - - -# ------------------------------------------------------------------ -# Three-class classification -# ------------------------------------------------------------------ - - -def classify_answer( - prediction: str, - gold: str, - *, - is_no_result: bool = False, - f1_threshold: float = _F1_CORRECT_THRESHOLD, -) -> str: - """Classify a prediction into ``correct``, ``refusal``, or ``hallucination``. - - Classification logic (faithful to FinanceBench paper): - 1. If the system explicitly refused (``is_no_result=True``) or the - prediction contains a refusal phrase → **refusal**. - 2. If EM passes or token-level F1 ≥ *f1_threshold* → **correct**. - 3. Otherwise → **hallucination**. 
- """ - norm_pred = normalize_answer(prediction) - - # --- Refusal --- - if is_no_result: - return "refusal" - pred_lower = norm_pred.lower() - for phrase in _REFUSAL_PHRASES: - if phrase in pred_lower: - return "refusal" - - # --- Correct --- - if exact_match(prediction, gold): - return "correct" - if f1_score(prediction, gold) >= f1_threshold: - return "correct" - - # --- Hallucination --- - return "hallucination" - - -# ------------------------------------------------------------------ -# Evidence recall -# ------------------------------------------------------------------ - - -def evidence_recall( - retrieved_pages: List[int], - gold_evidence: List[Dict[str, Any]], -) -> float: - """Compute page-level evidence recall. - - ``gold_evidence`` entries carry ``evidence_page_num`` (0-indexed). - Returns 1.0 when there is no gold evidence (vacuously true). - """ - if not gold_evidence: - return 1.0 - - gold_pages = { - int(e["evidence_page_num"]) - for e in gold_evidence - if "evidence_page_num" in e - } - if not gold_pages: - return 1.0 - - retrieved_set = set(retrieved_pages) - hits = gold_pages & retrieved_set - return len(hits) / len(gold_pages) - - # ------------------------------------------------------------------ # Aggregate metrics # ------------------------------------------------------------------ @@ -280,100 +135,75 @@ def evidence_recall( def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: """Aggregate per-question results into benchmark-level metrics. - Expected keys per result dict: ``classification``, ``em``, ``f1``, - ``elapsed``, ``telemetry``, ``question_type``, ``question_reasoning``, - ``evidence_recall`` (optional). + All correctness evaluation is driven by LLM Judge results stored in + each result dict (``judge_correct``, ``coverage``). - Returns a dict with overall stats plus breakdowns by *question_type* - and *question_reasoning*. + Returns a dict with overall stats plus breakdown by *question_type*. 
""" n = len(results) if n == 0: return {"n": 0} - # --- Overall counts --- - correct = sum(1 for r in results if r.get("classification") == "correct") - halluc = sum(1 for r in results if r.get("classification") == "hallucination") - refusal = sum(1 for r in results if r.get("classification") == "refusal") + # --- Accuracy (Judge) --- + judge_correct = sum(1 for r in results if r.get("judge_correct")) - em_sum = sum(1 for r in results if r.get("em")) - f1_sum = sum(r.get("f1", 0.0) for r in results) + # --- Coverage (Judge) --- + coverage_true = sum(1 for r in results if r.get("coverage")) + # --- Latency --- latencies = [r["elapsed"] for r in results if "elapsed" in r] avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + total_time = sum(latencies) - token_counts = [ + # --- Token usage --- + search_tokens = sum( r.get("telemetry", {}).get("total_tokens", 0) for r in results - ] - avg_tokens = sum(token_counts) / len(token_counts) if token_counts else 0 - - ev_recalls = [r["evidence_recall"] for r in results if r.get("evidence_recall") is not None] - avg_ev_recall = sum(ev_recalls) / len(ev_recalls) if ev_recalls else None + ) + judge_tokens = sum(r.get("judge_tokens", 0) for r in results) + total_tokens = search_tokens + judge_tokens + avg_tokens_per_question = total_tokens / n if n else 0 - overall = { + overall: Dict[str, Any] = { "n": n, - "accuracy": round(correct / n * 100, 2), - "hallucination_rate": round(halluc / n * 100, 2), - "refusal_rate": round(refusal / n * 100, 2), - "correct": correct, - "hallucination": halluc, - "refusal": refusal, - "avg_em": em_sum / n, - "avg_f1": f1_sum / n, + "accuracy": round(judge_correct / n * 100, 2), + "coverage": round(coverage_true / n * 100, 2), "avg_latency": round(avg_latency, 2), - "avg_tokens": round(avg_tokens, 1), + "total_time_seconds": round(total_time, 2), + "token_usage": { + "total_tokens": total_tokens, + "search_tokens": search_tokens, + "judge_tokens": judge_tokens, + 
"avg_tokens_per_question": round(avg_tokens_per_question, 1), + }, + "judge_correct": judge_correct, + "coverage_true": coverage_true, + "by_question_type": _breakdown(results, "question_type"), } - if avg_ev_recall is not None: - overall["evidence_recall"] = round(avg_ev_recall, 4) - - # --- LLM Judge metrics (independent dimension, NOT fallback) --- - judge_results = [r for r in results if r.get("llm_judge_correct") is not None] - if judge_results: - judge_correct = sum(1 for r in judge_results if r["llm_judge_correct"]) - overall["llm_judge_accuracy"] = round(judge_correct / len(judge_results) * 100, 2) - overall["llm_judge_count"] = len(judge_results) - overall["llm_judge_correct"] = judge_correct - else: - overall["llm_judge_accuracy"] = None - overall["llm_judge_count"] = 0 - overall["llm_judge_correct"] = 0 - - # --- Breakdowns --- - overall["by_question_type"] = _breakdown(results, "question_type") - overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") return overall -def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, Any]]: - """Compute per-group accuracy / hallucination / refusal breakdown.""" +def _breakdown( + results: List[Dict[str, Any]], key: str +) -> Dict[str, Dict[str, Any]]: + """Compute per-group accuracy / coverage breakdown.""" groups: dict[str, list[dict]] = defaultdict(list) for r in results: group = r.get(key) or "unknown" groups[group].append(r) out: dict[str, dict] = {} - for group, items in sorted(groups.items(), key=lambda x: (x[0] is None, x[0] or "")): + for group, items in sorted( + groups.items(), key=lambda x: (x[0] is None, x[0] or "") + ): g_n = len(items) - g_correct = sum(1 for r in items if r.get("classification") == "correct") - g_halluc = sum( - 1 for r in items if r.get("classification") == "hallucination" - ) - g_refusal = sum(1 for r in items if r.get("classification") == "refusal") - group_dict: dict[str, Any] = { + g_correct = sum(1 for r in items if 
r.get("judge_correct")) + g_coverage = sum(1 for r in items if r.get("coverage")) + out[group] = { "n": g_n, "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, - "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, - "refusal_rate": round(g_refusal / g_n * 100, 2) if g_n else 0.0, - "correct": g_correct, - "hallucination": g_halluc, - "refusal": g_refusal, + "coverage": round(g_coverage / g_n * 100, 2) if g_n else 0.0, + "judge_count": g_n, + "judge_correct": g_correct, } - # LLM Judge breakdown - g_judge = [r for r in items if r.get("llm_judge_correct") is not None] - if g_judge: - g_jc = sum(1 for r in g_judge if r["llm_judge_correct"]) - group_dict["llm_judge_accuracy"] = round(g_jc / len(g_judge) * 100, 2) - group_dict["llm_judge_count"] = len(g_judge) - out[group] = group_dict return out diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py index e52b6e6..8140669 100644 --- a/benchmarks/financebench/judge.py +++ b/benchmarks/financebench/judge.py @@ -1,12 +1,11 @@ -"""LLM-based semantic equivalence judge for FinanceBench. +"""LLM-based judge for FinanceBench evaluation. -The judge evaluates whether a model's prediction is semantically -equivalent to the gold answer, operating as an **independent** -evaluation dimension alongside EM/F1 — not as a fallback. +The judge drives **all** evaluation decisions: +- **Accuracy**: whether the prediction is semantically equivalent to the gold answer. +- **Coverage**: whether the prediction contains any information relevant to the question. -This provides a more nuanced correctness signal for financial QA, -where formatting differences (e.g., $1.5B vs $1,500M) can cause -EM/F1 to undercount correct answers. +This replaces the previous EM/F1 rule-driven pipeline with a single LLM-based +evaluation authority, providing more nuanced correctness signals for financial QA. 
""" from __future__ import annotations @@ -155,19 +154,50 @@ class FinanceBenchLLMJudge: - """LLM-based judge for semantic equivalence in financial QA. + """LLM-based judge driving all FinanceBench evaluation. - Operates as an independent evaluation dimension — NOT as a - fallback for EM/F1. Each question gets a separate judge verdict - that is tracked in its own metrics. + Provides two evaluation axes: + - ``judge()``: semantic equivalence (Accuracy). + - ``judge_coverage()``: information relevance (Coverage). + + Token usage from every LLM call is tracked and returned. """ _CONFIDENCE_THRESHOLD: float = 0.7 _MAX_RETRIES: int = 2 + # Coverage evaluation prompt + _COVERAGE_PROMPT: str = """\ +You are evaluating whether a system's response contains ANY useful information \ +relevant to the given financial question. + +Question: {question} +System Response: {prediction} + +Task: Determine if the response contains relevant, useful information. + +═══════════════════════════════════════════════ +HAS COVERAGE (has_coverage = true) — when ANY of: +═══════════════════════════════════════════════ +1. Contains specific financial data (dollar amounts, percentages, ratios) +2. Contains relevant factual statements about the company or topic +3. Contains partial but concrete information related to the question +4. Provides a direct answer (even if potentially incorrect) + +═══════════════════════════════════════════════ +NO COVERAGE (has_coverage = false) — when ALL of: +═══════════════════════════════════════════════ +1. Response is a refusal ("I cannot", "No results found", etc.) +2. Response contains no concrete data related to the question +3. 
Response is empty, purely apologetic, or only contains generic filler + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"has_coverage": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + def __init__(self, llm: Any) -> None: self._llm = llm self._cache: Dict[tuple, Dict[str, Any]] = {} + self._total_tokens_used: int = 0 # ------------------------------------------------------------------ # Public API @@ -192,7 +222,8 @@ async def judge( "confidence": float (0-1), "reasoning": str, "cached": bool, - "error": Optional[str] + "error": Optional[str], + "tokens_used": int, } """ # --- Refusal short-circuit (saves LLM call) --- @@ -203,6 +234,7 @@ async def judge( "reasoning": "Prediction is a refusal — skipped LLM judge.", "cached": False, "error": None, + "tokens_used": 0, } # --- Quick exact-match shortcut --- @@ -215,6 +247,7 @@ async def judge( "reasoning": "Normalized exact match", "cached": False, "error": None, + "tokens_used": 0, } # --- Check cache (key includes question for context-sensitivity) --- @@ -237,6 +270,7 @@ async def judge( result: Dict[str, Any] | None = None last_error: str | None = None + tokens_used: int = 0 for attempt in range(1, self._MAX_RETRIES + 1): try: @@ -244,6 +278,7 @@ async def judge( messages=[{"role": "user", "content": prompt}], stream=False, ) + tokens_used = self._extract_tokens(resp) raw = resp.content.strip() result = self._parse_response(raw) if result.get("error") is None: @@ -283,6 +318,8 @@ async def judge( result.setdefault("cached", False) result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used # Cache successful results only if result["error"] is None: @@ -414,6 +451,140 @@ def _is_refusal(text: str) -> bool: return True return False + async def judge_coverage( + self, + prediction: str, + question: str, + ) -> Dict[str, Any]: + """Evaluate whether *prediction* contains relevant information for *question*. 
+ + Returns: + { + "has_coverage": bool, + "confidence": float (0-1), + "reasoning": str, + "tokens_used": int, + "error": Optional[str], + } + """ + # --- Refusal short-circuit --- + if self._is_refusal(prediction): + return { + "has_coverage": False, + "confidence": 1.0, + "reasoning": "Explicit refusal detected.", + "tokens_used": 0, + "error": None, + } + + prompt = self._COVERAGE_PROMPT.format( + question=question or "N/A", + prediction=prediction[:4000], + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + tokens_used: int = 0 + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + tokens_used = self._extract_tokens(resp) + raw = resp.content.strip() + result = self._parse_coverage_response(raw) + if result.get("error") is None: + break + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Coverage judge failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Coverage judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used + return result + + # ------------------------------------------------------------------ + # Coverage response parsing + # ------------------------------------------------------------------ + + def _parse_coverage_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response for coverage evaluation.""" + parsed = self._try_parse_json(raw) + if parsed is not None: + has_coverage = bool(parsed.get("has_coverage", False)) + try: + confidence = float(parsed.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + 
confidence = max(0.0, min(1.0, confidence)) + reasoning = str(parsed.get("reasoning", "")) + return { + "has_coverage": has_coverage, + "confidence": confidence, + "reasoning": reasoning, + } + + # Fallback: keyword detection + lower = raw.lower() + true_match = re.search(r'"has_coverage"\s*:\s*true\b', lower) + false_match = re.search(r'"has_coverage"\s*:\s*false\b', lower) + + if false_match and not true_match: + return { + "has_coverage": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (no coverage): {raw[:200]}", + } + elif true_match and not false_match: + return { + "has_coverage": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (has coverage): {raw[:200]}", + } + + logger.warning("Cannot parse coverage response: %s", raw[:200]) + return { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } + + # ------------------------------------------------------------------ + # Token tracking + # ------------------------------------------------------------------ + + @staticmethod + def _extract_tokens(resp: Any) -> int: + """Extract total token count from an LLM response.""" + usage = getattr(resp, "usage", None) + if isinstance(usage, dict): + return int(usage.get("total_tokens", 0)) + return 0 + + @property + def total_tokens_used(self) -> int: + """Cumulative tokens consumed by all judge calls.""" + return self._total_tokens_used + @property def cache_size(self) -> int: """Return the number of cached judge results.""" diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index c9f5b26..cf7b30a 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -94,33 +94,28 @@ def _print_summary( """Print a human-readable run summary to stdout.""" n = len(results) acc = metrics.get("accuracy", 0) - hallu = metrics.get("hallucination_rate", 0) - refuse = metrics.get("refusal_rate", 0) - 
avg_em = metrics.get("avg_em", 0) - avg_f1 = metrics.get("avg_f1", 0) - ev_recall = metrics.get("evidence_recall") + cov = metrics.get("coverage", 0) avg_latency = metrics.get("avg_latency", 0) + token_usage = metrics.get("token_usage", {}) + total_tokens = token_usage.get("total_tokens", 0) + search_tokens = token_usage.get("search_tokens", 0) + judge_tokens = token_usage.get("judge_tokens", 0) + avg_tokens_q = token_usage.get("avg_tokens_per_question", 0) + print("\n" + "=" * 60) print(f"FinanceBench Results ({n} questions)") print("=" * 60) - print(f" Accuracy: {acc:.1f}%") - print(f" Hallucination Rate: {hallu:.1f}%") - print(f" Refusal Rate: {refuse:.1f}%") - print(f" Avg EM: {avg_em:.3f}") - print(f" Avg F1: {avg_f1:.3f}") - if ev_recall is not None: - print(f" Evidence Recall: {ev_recall:.3f}") - else: - print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Accuracy (Judge): {acc:.1f}%") + print(f" Coverage (Judge): {cov:.1f}%") print(f" Avg Latency: {avg_latency:.1f}s") print(f" Total Time: {total_time:.1f}s") - # LLM Judge independent metrics - if metrics.get("llm_judge_accuracy") is not None: - print(f"\n --- LLM Judge (Independent) ---") - print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") - print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + print(f"\n --- Token Usage ---") + print(f" Total Tokens: {total_tokens:>,}") + print(f" Search Tokens: {search_tokens:>,}") + print(f" Judge Tokens: {judge_tokens:>,}") + print(f" Avg per Question: {avg_tokens_q:>,.0f}") print(f"\n Results: {results_path}") print(f" Metrics: {metrics_path}") @@ -129,28 +124,13 @@ def _print_summary( # Breakdown by question_type by_qt = metrics.get("by_question_type") if by_qt: - # Determine if judge data is available - has_judge = any(m.get("llm_judge_accuracy") is not None for m in by_qt.values()) - if has_judge: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} 
{'N':>4}") - print(" " + "-" * 59) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - qt_judge = m.get("llm_judge_accuracy") - qt_judge_str = f"{qt_judge:>6.1f}" if qt_judge is not None else " N/A" - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_judge_str} {qt_n:>4}") - else: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") - print(" " + "-" * 52) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + print(f"\n {'Question Type':<28} {'Acc%':>6} {'Cover%':>7} {'N':>5}") + print(" " + "-" * 48) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_cov = m.get("coverage", 0) + qt_n = m.get("n", 0) + print(f" {qt:<28} {qt_acc:>5.1f} {qt_cov:>7.1f} {qt_n:>5}") print("=" * 60) @@ -222,11 +202,9 @@ def main() -> None: # 6. Print run config logger.info( - "Config: mode=%s, eval_mode=%s, extract_answer=%s, " - "llm_judge=%s, concurrent=%d, model=%s", + "Config: mode=%s, eval_mode=%s, llm_judge=%s, concurrent=%d, model=%s", cfg.mode, cfg.eval_mode, - cfg.extract_answer, cfg.enable_llm_judge, cfg.max_concurrent, cfg.llm_model, @@ -246,7 +224,6 @@ def main() -> None: "eval_mode": cfg.eval_mode, "model": cfg.llm_model, "top_k_files": cfg.top_k_files, - "extract_answer": cfg.extract_answer, } # 9. Save results (JSONL) + metrics (JSON) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 86404f2..7e2f115 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -4,157 +4,24 @@ - **singleDoc**: each question searches only its target PDF directory. - **sharedCorpus**: all questions search the full PDF corpus. 
-After search, an optional LLM extraction step converts the verbose -briefing into a short factoid answer suitable for EM/F1. +All evaluation (Accuracy + Coverage) is driven by LLM Judge. """ from __future__ import annotations import asyncio import json as json_mod import logging -import re import time from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from config import FinanceBenchConfig from data_loader import FinanceBenchLoader -from evaluate import ( - classify_answer, - compute_metrics, - exact_match, - evidence_recall, - f1_score, - normalize_answer, -) +from evaluate import compute_metrics logger = logging.getLogger("financebench.runner") -# ------------------------------------------------------------------ -# Answer extraction prompt (financial domain) -# ------------------------------------------------------------------ - -_EXTRACT_PROMPT = """\ -Given the financial question and a verbose response, extract ONLY the short factoid answer. -Rules: -- Output ONLY the answer value/phrase (1-20 words). No explanation. -- If the response contains ANY concrete data (dollar amounts, percentages, numbers, - company names, yes/no conclusions), extract that data even if the response also - expresses uncertainty or says it could not find a "complete" answer. -- A partial answer with real data is ALWAYS better than "unknown". -- Output "unknown" ONLY when the response contains absolutely no useful factual - information (e.g., a pure apology with zero data points). -- For monetary values, keep the currency format (e.g., $1,577.00) -- For percentages, keep the % sign (e.g., 15.3%) -- For yes/no questions, output: yes or no - -Question: {question} -Response: {response} - -Short answer:""" - -# Regex pattern for extracting financial numeric data as fallback -_NUMERIC_EXTRACTION_PATTERN = ( - r'\$[\d,]+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|K)?' 
- r'|\d+(?:,\d{3})+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)?' - r'|\d+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)' -) - -# Sentinel values indicating extraction found no useful answer -_UNKNOWN_SENTINELS = frozenset({"unknown", "n/a", ""}) - - -# NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. - - -# ------------------------------------------------------------------ -# LLM short-answer extraction -# ------------------------------------------------------------------ - - -def _extract_numeric_fallback(text: str) -> Optional[str]: - """Extract financial figures from *text* using regex patterns. - - Looks for currency amounts ($xxx), percentages, and large numbers - with units (million, billion, etc.). - - Returns the first match or ``None``. - """ - match = re.search(_NUMERIC_EXTRACTION_PATTERN, text) - if match: - return match.group(0).strip() - return None - - -async def _llm_extract( - question: str, - verbose: str, - llm: Any, -) -> Optional[str]: - """Layer-1: use LLM to distil *verbose* into a short factoid answer.""" - prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) - try: - resp = await llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=False, - ) - return resp.content.strip() - except Exception: - logger.warning("LLM extraction failed; will try regex fallback.") - return None - - -async def _extract_short_answer( - question: str, - verbose: str, - llm: Any, -) -> str: - """Extract a concise answer from verbose LLM analysis. - - Uses a three-layer extraction strategy: - 1. LLM-based extraction with improved prompt - 2. Regex-based numeric/financial data extraction as fallback - 3. 
Returns 'unknown' only when no useful data is found - """ - # Layer 1: LLM extraction - answer = await _llm_extract(question, verbose, llm) - if answer and answer.strip().lower() not in _UNKNOWN_SENTINELS: - return answer.strip() - - # Layer 2: Regex fallback for numeric/financial data - numeric_answer = _extract_numeric_fallback(verbose) - if numeric_answer: - logger.info("Regex fallback extracted: %s", numeric_answer) - return numeric_answer - - # Layer 3: No useful data found - return "unknown" - - -# ------------------------------------------------------------------ -# Page extraction helper -# ------------------------------------------------------------------ - - -def _try_extract_pages(telemetry: Dict[str, Any]) -> List[int]: - """Best-effort extraction of retrieved page numbers from telemetry. - - Current limitation: Sirchmunk's ``read_file_ids`` contains plain file - paths without page-level suffixes, so this function will typically - return an empty list. When empty, callers should treat evidence - recall as *unavailable* (``None``) rather than zero. 
- """ - pages: list[int] = [] - for fid in telemetry.get("read_file_ids", []): - # Convention: page indices may be embedded in file IDs - if isinstance(fid, str) and "_page_" in fid: - try: - pages.append(int(fid.rsplit("_page_", 1)[-1])) - except (ValueError, IndexError): - pass - return pages - # ------------------------------------------------------------------ # Single question execution @@ -174,24 +41,24 @@ async def run_single( fb_id = entry.get("financebench_id", "") question = entry["question"] gold = entry.get("answer", "") - gold_evidence = entry.get("evidence", []) async with semaphore: t0 = time.time() error: str | None = None raw_answer = "" - answer = "" telemetry: dict[str, Any] = {} - retrieved_pages: list[int] = [] try: # Determine search paths based on eval mode if cfg.eval_mode == "singleDoc": pdf_path = loader.get_pdf_path(entry.get("doc_name", "")) if pdf_path: - search_paths = [pdf_path] # pass the single PDF file directly + search_paths = [pdf_path] else: - logger.warning("PDF not found for %s, falling back to full corpus", entry.get("doc_name", "")) + logger.warning( + "PDF not found for %s, falling back to full corpus", + entry.get("doc_name", ""), + ) search_paths = [cfg.pdf_dir] else: search_paths = [cfg.pdf_dir] @@ -217,14 +84,6 @@ async def run_single( "llm_calls": len(getattr(result, "llm_usages", None) or []), "num_files_read": len(read_files), } - retrieved_pages = _try_extract_pages(telemetry) - - # Answer extraction - if cfg.extract_answer and raw_answer: - answer = await _extract_short_answer(question, raw_answer, llm) - answer = normalize_answer(answer) - else: - answer = normalize_answer(raw_answer) except Exception as exc: error = str(exc) @@ -236,39 +95,42 @@ async def run_single( if cfg.request_delay > 0: await asyncio.sleep(cfg.request_delay) - # --- Evaluation --- - is_no_result = not answer or answer.lower() in ("unknown", "") - em = exact_match(answer, gold) - f1 = f1_score(answer, gold) - classification = 
classify_answer(answer, gold, is_no_result=is_no_result) - if retrieved_pages: # only compute when page-level data is available - ev_recall = evidence_recall(retrieved_pages, gold_evidence) - else: - ev_recall = None # mark as unavailable, avoid false 0 - - # LLM Judge — independent evaluation dimension - # Skip judge for refusals (no point calling LLM on non-answers) - llm_judge_correct = None - llm_judge_reasoning = None - if judge is not None and classification != "refusal": + # --- LLM Judge evaluation (Accuracy + Coverage) --- + judge_correct = False + judge_reasoning = "" + judge_tokens = 0 + has_coverage = False + coverage_reasoning = "" + + if judge is not None: + # Accuracy evaluation try: judge_result = await judge.judge( - prediction=answer, + prediction=raw_answer, gold_answer=gold, question=question, ) - llm_judge_correct = judge_result.get("equivalent", False) - llm_judge_reasoning = judge_result.get("reasoning", "") + judge_correct = judge_result.get("equivalent", False) + judge_reasoning = judge_result.get("reasoning", "") + judge_tokens += judge_result.get("tokens_used", 0) + except Exception as e: + logger.warning("LLM Judge (accuracy) failed for %s: %s", fb_id, e) + + # Coverage evaluation + try: + coverage_result = await judge.judge_coverage( + prediction=raw_answer, + question=question, + ) + has_coverage = coverage_result.get("has_coverage", False) + coverage_reasoning = coverage_result.get("reasoning", "") + judge_tokens += coverage_result.get("tokens_used", 0) except Exception as e: - logger.warning("LLM Judge failed for %s: %s", fb_id, e) - elif judge is not None and classification == "refusal": - llm_judge_correct = False - llm_judge_reasoning = "Skipped: prediction classified as refusal" + logger.warning("LLM Judge (coverage) failed for %s: %s", fb_id, e) return { "financebench_id": fb_id, "question": question, - "prediction": answer, "raw_prediction": raw_answer, "gold_answer": gold, "company": entry.get("company", ""), @@ -277,12 
+139,11 @@ async def run_single( "question_reasoning": entry.get("question_reasoning", ""), "elapsed": round(elapsed, 2), "telemetry": telemetry, - "classification": classification, - "em": em, - "f1": round(f1, 4), - "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, - "llm_judge_correct": llm_judge_correct, # None if judge disabled - "llm_judge_reasoning": llm_judge_reasoning, + "judge_correct": judge_correct, + "judge_reasoning": judge_reasoning, + "coverage": has_coverage, + "coverage_reasoning": coverage_reasoning, + "judge_tokens": judge_tokens, "error": error, } @@ -310,12 +171,12 @@ async def run_batch( loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) - # Initialise LLM Judge (uses the same test model) + # Initialise LLM Judge judge = None if cfg.enable_llm_judge: from judge import FinanceBenchLLMJudge judge = FinanceBenchLLMJudge(llm=llm) - logger.info("LLM Judge enabled (independent evaluation dimension)") + logger.info("LLM Judge enabled (drives Accuracy + Coverage)") # Prepare output directory / file out_dir = Path(cfg.output_dir) @@ -334,20 +195,16 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: with open(out_path, "a", encoding="utf-8") as fp: fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") completed += 1 - status = res["classification"] - judge_tag = "" - if res.get("llm_judge_correct") is not None: - judge_tag = " [judge:\u2713]" if res["llm_judge_correct"] else " [judge:\u2717]" + acc_tag = "\u2713" if res["judge_correct"] else "\u2717" + cov_tag = "cov" if res["coverage"] else "no-cov" logger.info( - "[%d/%d] %s %s EM=%s F1=%.2f %.1fs%s", + "[%d/%d] %s [acc:%s] [%s] %.1fs", completed, total, res["financebench_id"], - status, - res["em"], - res["f1"], + acc_tag, + cov_tag, res["elapsed"], - judge_tag, ) return res @@ -361,10 +218,9 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: 
json_mod.dump(metrics, fp, indent=2, ensure_ascii=False) logger.info("Metrics saved to %s", metrics_path) logger.info( - "Accuracy=%.2f%% Hallucination=%.2f%% Refusal=%.2f%%", + "Accuracy=%.2f%% Coverage=%.2f%%", metrics.get("accuracy", 0), - metrics.get("hallucination_rate", 0), - metrics.get("refusal_rate", 0), + metrics.get("coverage", 0), ) return list(results) From a184e862d5bab63ce1eb1ae7781c93d733703377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 21:08:02 +0800 Subject: [PATCH 28/70] update config --- benchmarks/financebench/config.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index 5c390ce..68fe2a1 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -54,9 +54,7 @@ class FinanceBenchConfig: # Evaluation eval_mode: str = "singleDoc" # singleDoc / sharedCorpus - enable_llm_judge: bool = True # Use LLM to judge semantic equivalence (independent metric) - extract_answer: bool = True - judge_f1_threshold: float = 0.8 # F1 threshold for 'correct' classification + enable_llm_judge: bool = True # LLM Judge drives Accuracy + Coverage evaluation # Concurrency max_concurrent: int = 3 @@ -126,7 +124,6 @@ def _float(key: str, default: float = 0.0) -> float: enable_dir_scan=_bool("FB_ENABLE_DIR_SCAN", True), eval_mode=_get("FB_EVAL_MODE", "singleDoc"), enable_llm_judge=_bool("FB_ENABLE_LLM_JUDGE", True), - extract_answer=_bool("FB_EXTRACT_ANSWER", True), max_concurrent=_int("FB_MAX_CONCURRENT", 3), request_delay=_float("FB_REQUEST_DELAY", 0.5), work_path=work_path, From eb43fdd6a718cf6d22b492a3b877d30f9e85ffdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 22:54:47 +0800 Subject: [PATCH 29/70] refactor doc extractor --- requirements/core.txt | 1 + src/sirchmunk/learnings/compiler.py | 13 +- src/sirchmunk/learnings/toc_extractor.py | 846 +++++++++++++++++----- 
src/sirchmunk/utils/document_extractor.py | 398 ++++++++++ src/sirchmunk/utils/file_utils.py | 26 +- 5 files changed, 1101 insertions(+), 183 deletions(-) create mode 100644 src/sirchmunk/utils/document_extractor.py diff --git a/requirements/core.txt b/requirements/core.txt index 1848a37..6cff25b 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -5,6 +5,7 @@ openai genson pillow pypdf +pdfminer.six pandas parquet numpy diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 2f8983a..4e31441 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -32,7 +32,8 @@ ) from sirchmunk.storage.knowledge_storage import KnowledgeStorage from sirchmunk.utils import LogCallback, create_logger -from sirchmunk.utils.file_utils import fast_extract, get_fast_hash +from sirchmunk.utils.document_extractor import DocumentExtractor +from sirchmunk.utils.file_utils import get_fast_hash # Concurrency cap for LLM-heavy file processing _DEFAULT_CONCURRENCY = 3 @@ -539,7 +540,9 @@ async def _compile_single_file( try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") - extraction = await fast_extract(file_path=entry.path) + extraction = await DocumentExtractor.extract( + entry.path, DocumentExtractor.ENHANCED, + ) content = extraction.content if not content or len(content.strip()) < 100: result.error = "Insufficient text content" @@ -550,11 +553,13 @@ async def _compile_single_file( and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) ) - # Phase 0.5: TOC extraction (zero LLM calls) + # Phase 0.5: TOC extraction (layers 1-3 are zero LLM calls) toc_entries = None if use_tree: from sirchmunk.learnings.toc_extractor import TOCExtractor - toc_entries = TOCExtractor.extract(entry.path, content) + toc_entries = await TOCExtractor.extract( + entry.path, content, + ) if toc_entries: await self._log.info( f"[Compile] Extracted TOC with {len(toc_entries)} entries " diff --git 
a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py index 85f3b8e..1516cfd 100644 --- a/src/sirchmunk/learnings/toc_extractor.py +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -1,220 +1,589 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -TOC (Table of Contents) extractor — pure local operations, zero LLM calls. +TOC (Table of Contents) extractor — multi-layer fallback strategy. Extracts hierarchical table-of-contents structures from various document -formats (PDF, Markdown, DOCX, HTML) using native format features (bookmarks, -heading styles, heading tags). The extracted TOCEntry list is consumed by -the tree indexer to accelerate tree construction. +formats (PDF, Markdown, DOCX, HTML) using a layered approach: + + Layer 1 — pypdf native outline (highest confidence, zero cost) + Layer 2 — pdfminer.six detailed parsing (fallback for pypdf) + Layer 3 — Text heading pattern detection (for documents without bookmarks) + Layer 4 — LLM-assisted inference (optional, last resort) + +The extracted TOCEntry list is consumed by the tree indexer to accelerate +tree construction. """ +import json +import logging import re from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional +from typing import Any, ClassVar, List, Optional -# Minimum number of TOC entries required to form a meaningful structure -_MIN_TOC_ENTRIES = 3 +logger = logging.getLogger(__name__) # Known heading-style prefixes across locales (English, Chinese, etc.) _HEADING_STYLE_PREFIXES = ("Heading", "heading", "\u6807\u9898") # "标题" = Chinese +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + @dataclass class TOCEntry: - """Single entry in an extracted table of contents.""" + """Single entry in an extracted table of contents. + + Attributes: + title: Section title text. 
+ level: Heading depth (1 = top-level section, 2 = subsection, …). + char_start: Character offset in the extracted full text. + char_end: End character offset (exclusive), or None if unresolved. + page_start: 1-indexed page number, or None if unknown. + page_end: End page number (inclusive), or None. + children: Nested sub-entries forming a tree. + source: Which extraction layer produced this entry + ("pypdf", "pdfminer", "heading", "markdown", "docx", + "html", "llm"). + """ title: str - level: int # 0=root, 1=section, 2=subsection - char_start: int # Character offset in extracted text + level: int # 1=section, 2=subsection, … + char_start: int = 0 char_end: Optional[int] = None page_start: Optional[int] = None page_end: Optional[int] = None children: List["TOCEntry"] = field(default_factory=list) + source: str = "" -class TOCExtractor: - """Extract TOC structure from documents using native format features. - - All methods are static — no instance state required. Each extraction - method handles one file format and returns a flat or nested list of - ``TOCEntry`` objects. The main ``extract()`` entry point dispatches - by file extension and resolves character positions against the - extracted text content. - - Design constraints: - - Pure local operations, zero LLM calls - - Exceptions handled internally; failure returns None +@dataclass +class TocResult: + """Complete TOC extraction result with quality metadata. + + Attributes: + entries: Ordered list of TOCEntry objects. + source: Primary extraction method that produced the result. + confidence: Estimated quality score (0.0–1.0). + page_count: Total pages in the source document, if known. """ - @staticmethod - def extract(file_path: str, content: str) -> Optional[List[TOCEntry]]: - """Main entry point: extract TOC entries from a file. - - Dispatches to format-specific extractors based on file extension, - then resolves character positions in the extracted text content. 
- - Args: - file_path: Absolute path to the source file. - content: Extracted text content of the file. + entries: List[TOCEntry] = field(default_factory=list) + source: str = "" + confidence: float = 0.0 + page_count: Optional[int] = None - Returns: - List of TOCEntry with resolved char positions, or None if - the file format is unsupported or fewer than _MIN_TOC_ENTRIES - entries are found. - """ - ext = Path(file_path).suffix.lower() - entries: Optional[List[TOCEntry]] = None - if ext == ".pdf": - entries = TOCExtractor._extract_pdf_toc(file_path) - elif ext in (".md", ".markdown"): - entries = TOCExtractor._extract_markdown_toc(content) - elif ext in (".docx",): - entries = TOCExtractor._extract_docx_toc(file_path) - elif ext in (".html", ".htm"): - entries = TOCExtractor._extract_html_toc(content) - else: - return None +# --------------------------------------------------------------------------- +# Layer 1: pypdf native outline +# --------------------------------------------------------------------------- - if not entries: - return None - # Flatten nested children for total count check - total = TOCExtractor._count_entries(entries) - if total < _MIN_TOC_ENTRIES: - return None +class PypdfOutlineExtractor: + """Layer 1: Extract TOC from PDF native outline/bookmarks using pypdf. - # Resolve character positions in extracted text - entries = TOCExtractor._resolve_char_positions(entries, content) - return entries + Highest confidence (0.9) — relies on the PDF producer embedding + explicit bookmarks. Zero external cost. + """ @staticmethod - def _extract_pdf_toc(file_path: str) -> Optional[List[TOCEntry]]: - """Extract TOC from PDF bookmarks/outline using pypdf. - - Recursively parses the nested bookmark structure from - ``PdfReader.outline``. + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from PDF outline. Args: file_path: Path to the PDF file. Returns: - List of TOCEntry with page_start populated, or None on failure. 
+ TocResult with entries and page_count populated, + or an empty TocResult on failure. """ try: from pypdf import PdfReader - reader = PdfReader(file_path) + reader = PdfReader(str(file_path)) outline = reader.outline + page_count = len(reader.pages) + if not outline: - return None + return TocResult(source="pypdf", page_count=page_count) entries: List[TOCEntry] = [] - TOCExtractor._parse_pdf_outline(reader, outline, entries, level=1) - return entries if entries else None - except Exception: - return None + PypdfOutlineExtractor._parse_outline( + reader, outline, entries, level=1, + ) + + if not entries: + return TocResult(source="pypdf", page_count=page_count) + + return TocResult( + entries=entries, + source="pypdf", + confidence=0.9, + page_count=page_count, + ) + except Exception as exc: + logger.debug("pypdf outline extraction failed: %s", exc) + return TocResult(source="pypdf") @staticmethod - def _parse_pdf_outline( - reader: "PdfReader", - outline_items: List, + def _parse_outline( + reader: Any, + outline_items: list, entries: List[TOCEntry], level: int, ) -> None: - """Recursively parse pypdf outline items into TOCEntry list. - - Args: - reader: PdfReader instance for page number resolution. - outline_items: Nested list of outline Destination objects. - entries: Accumulator list to append entries to. - level: Current nesting level (1=top-level section). 
- """ + """Recursively parse pypdf outline items into TOCEntry list.""" for item in outline_items: if isinstance(item, list): - # Nested list means sub-bookmarks — attach to last entry + # Nested list → sub-bookmarks; attach to last entry if entries: - sub_entries: List[TOCEntry] = [] - TOCExtractor._parse_pdf_outline( - reader, item, sub_entries, level=level + 1, + sub: List[TOCEntry] = [] + PypdfOutlineExtractor._parse_outline( + reader, item, sub, level=level + 1, ) - entries[-1].children.extend(sub_entries) + entries[-1].children.extend(sub) else: - TOCExtractor._parse_pdf_outline( + PypdfOutlineExtractor._parse_outline( reader, item, entries, level=level, ) else: - # Single bookmark destination try: title = item.title if hasattr(item, "title") else str(item) - page_num = None + page_num: Optional[int] = None try: - page_num = reader.get_destination_page_number(item) + # get_destination_page_number returns 0-indexed + raw = reader.get_destination_page_number(item) + if raw is not None: + page_num = raw + 1 # convert to 1-indexed except Exception: pass - entry = TOCEntry( + entries.append(TOCEntry( title=title.strip(), level=level, char_start=0, page_start=page_num, - ) - entries.append(entry) + source="pypdf", + )) except Exception: continue - @staticmethod - def _extract_markdown_toc(content: str) -> Optional[List[TOCEntry]]: - """Extract TOC from Markdown heading syntax (# / ## / ###). - Matches ATX-style headings: lines beginning with 1-6 '#' characters - followed by whitespace and the heading text. +# --------------------------------------------------------------------------- +# Layer 2: pdfminer.six detailed parsing +# --------------------------------------------------------------------------- + + +class PdfminerOutlineExtractor: + """Layer 2: Extract TOC using pdfminer.six for more detailed parsing. + + Falls back here when pypdf yields insufficient entries. + Confidence 0.85 — pdfminer exposes more detail but requires + manual page-number resolution. 
+ """ + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC using pdfminer's outline parser. Args: - content: Markdown text content. + file_path: Path to the PDF file. Returns: - List of TOCEntry with level derived from '#' count, or None. + TocResult with entries populated, or empty on failure. """ try: - pattern = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) - matches = pattern.findall(content) - if not matches: + from pdfminer.pdfdocument import PDFDocument, PDFNoOutlines + from pdfminer.pdfpage import PDFPage + from pdfminer.pdfparser import PDFParser + from pdfminer.psparser import LIT + + fp = open(str(file_path), "rb") + try: + parser = PDFParser(fp) + document = PDFDocument(parser) + + # Build page-object-id → 1-indexed page number mapping + pages = list(PDFPage.create_pages(document)) + page_count = len(pages) + objid_to_pagenum = { + page.pageid: idx + 1 + for idx, page in enumerate(pages) + } + + entries: List[TOCEntry] = [] + try: + for level, title, dest, action, _se in document.get_outlines(): + page_num = PdfminerOutlineExtractor._resolve_page( + dest, action, objid_to_pagenum, document, + ) + entries.append(TOCEntry( + title=str(title).strip() if title else "", + level=level, + char_start=0, + page_start=page_num, + source="pdfminer", + )) + except PDFNoOutlines: + pass + + if not entries: + return TocResult(source="pdfminer", page_count=page_count) + + return TocResult( + entries=entries, + source="pdfminer", + confidence=0.85, + page_count=page_count, + ) + finally: + fp.close() + except Exception as exc: + logger.debug("pdfminer outline extraction failed: %s", exc) + return TocResult(source="pdfminer") + + @staticmethod + def _resolve_page( + dest: Any, + action: Any, + objid_to_pagenum: dict, + document: Any, + ) -> Optional[int]: + """Resolve a pdfminer destination/action to a 1-indexed page number.""" + try: + from pdfminer.pdfparser import PDFStream + from pdfminer.pdftypes import resolve1 + + # Try dest 
first + target = dest + if target is None and action is not None: + # GoTo action: action dict may have a 'D' key + if isinstance(action, dict): + target = action.get("D") + + if target is None: return None - entries: List[TOCEntry] = [] - for hashes, title in matches: + # Resolve indirect objects + target = resolve1(target) + + if isinstance(target, list) and len(target) > 0: + page_ref = resolve1(target[0]) + if hasattr(page_ref, "objid"): + return objid_to_pagenum.get(page_ref.objid) + elif hasattr(target, "objid"): + return objid_to_pagenum.get(target.objid) + except Exception: + pass + return None + + +# --------------------------------------------------------------------------- +# Layer 3: Text heading pattern detection +# --------------------------------------------------------------------------- + + +class HeadingTocExtractor: + """Layer 3: Infer TOC from document text structure (heading patterns). + + Handles Markdown headings, numbered sections, and common structural + keywords. Confidence 0.6 — heuristic-based, lower precision. + """ + + # Regex for Markdown ATX headings: # Title, ## Subtitle, … + _MD_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"^(#{1,6})\s+(.+)$", re.MULTILINE, + ) + + # Regex for numbered section patterns: "1.", "1.1", "1.1.1", … + _NUMBERED_RE: ClassVar[re.Pattern] = re.compile( + r"^(\d+(?:\.\d+)*)[.\s]+(.+)$", re.MULTILINE, + ) + + # Common structural keywords (case-insensitive) + _STRUCTURAL_KEYWORDS: ClassVar[tuple] = ( + "ITEM", "PART", "CHAPTER", "SECTION", "ARTICLE", + "APPENDIX", "EXHIBIT", "SCHEDULE", "ANNEX", + ) + + # Max characters for a candidate heading line + _MAX_HEADING_LINE_LEN: ClassVar[int] = 120 + + @staticmethod + def extract(content: str, mime_type: str = "") -> TocResult: + """Infer TOC from text content by detecting heading patterns. + + Tries strategies in order: + 1. Markdown ATX headings (``#`` syntax) + 2. Numbered section patterns (``1.``, ``1.1``, …) + 3. 
Structural keyword detection (ITEM, PART, CHAPTER, …) + + Args: + content: Full extracted text of the document. + mime_type: Optional MIME type hint (unused currently). + + Returns: + TocResult with char_position-based entries. + """ + if not content or len(content.strip()) < 50: + return TocResult(source="heading") + + # Strategy 1: Markdown headings + entries = HeadingTocExtractor._extract_markdown_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.7, + ) + + # Strategy 2: Numbered sections + entries = HeadingTocExtractor._extract_numbered_sections(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.6, + ) + + # Strategy 3: Structural keywords + heuristic + entries = HeadingTocExtractor._extract_structural_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.5, + ) + + return TocResult(source="heading") + + @staticmethod + def _extract_markdown_headings(content: str) -> List[TOCEntry]: + """Extract headings from Markdown ATX syntax (# / ## / ###).""" + matches = list(HeadingTocExtractor._MD_HEADING_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + hashes, title = m.group(1), m.group(2).strip() + if title: entries.append(TOCEntry( - title=title.strip(), + title=title, level=len(hashes), - char_start=0, + char_start=m.start(), + source="heading", )) - return entries if entries else None - except Exception: - return None + return entries + + @staticmethod + def _extract_numbered_sections(content: str) -> List[TOCEntry]: + """Extract numbered section headings (1., 1.1, 1.1.1, …).""" + matches = list(HeadingTocExtractor._NUMBERED_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + number_part = m.group(1) + title_part = m.group(2).strip() + # Line length check — skip long lines (likely not headings) + 
line_len = m.end() - m.start() + if line_len > HeadingTocExtractor._MAX_HEADING_LINE_LEN: + continue + if not title_part: + continue + level = number_part.count(".") + 1 + entries.append(TOCEntry( + title=f"{number_part} {title_part}", + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + @staticmethod + def _extract_structural_headings(content: str) -> List[TOCEntry]: + """Detect common structural keywords as section boundaries.""" + # Build pattern: ITEM 1, PART I, CHAPTER 1, etc. + kw_pattern = "|".join(HeadingTocExtractor._STRUCTURAL_KEYWORDS) + pattern = re.compile( + rf"^({kw_pattern})\s+(\w+[\w .:\-]*)$", + re.MULTILINE | re.IGNORECASE, + ) + matches = list(pattern.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + keyword = m.group(1).upper() + rest = m.group(2).strip() + title = f"{keyword} {rest}" + # Determine level based on keyword + if keyword in ("PART", "CHAPTER"): + level = 1 + elif keyword in ("ITEM", "SECTION", "ARTICLE"): + level = 2 + else: + level = 3 + entries.append(TOCEntry( + title=title, + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + +# --------------------------------------------------------------------------- +# Layer 4: LLM-assisted inference (optional) +# --------------------------------------------------------------------------- + + +class LlmTocExtractor: + """Layer 4: Use LLM to infer TOC from document content. + + This is the last-resort fallback. Requires an ``llm_caller`` that + supports ``await llm_caller.achat(messages)``. If no caller is + provided, returns an empty result immediately. + + Confidence 0.7 — LLM may hallucinate structure. 
+ """ + + # Maximum characters sent to the LLM to stay within token limits + _MAX_CONTENT_CHARS: ClassVar[int] = 8_000 + + _PROMPT_TEMPLATE: ClassVar[str] = ( + "Analyze the following document excerpt and extract its " + "hierarchical table of contents (TOC) structure.\n\n" + "Return a JSON array where each element has:\n" + ' - "title": section title text\n' + ' - "level": integer heading depth (1=top, 2=sub, 3=subsub)\n\n' + "Only include actual section/chapter headings, not every paragraph.\n" + "Return ONLY the JSON array, no other text.\n\n" + "Document excerpt:\n---\n{content}\n---" + ) + + @staticmethod + async def extract( + content: str, + llm_caller: Any | None = None, + ) -> TocResult: + """Infer TOC using LLM analysis. + + Args: + content: Full extracted text of the document. + llm_caller: An object with ``achat(messages)`` method. + If None, returns an empty result. + + Returns: + TocResult with LLM-inferred entries. + """ + if llm_caller is None: + return TocResult(source="llm") + + if not content or len(content.strip()) < 100: + return TocResult(source="llm") + + try: + # Truncate content to fit token budget + truncated = content[:LlmTocExtractor._MAX_CONTENT_CHARS] + prompt = LlmTocExtractor._PROMPT_TEMPLATE.format(content=truncated) + + resp = await llm_caller.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + + entries = LlmTocExtractor._parse_response(raw, content) + if not entries: + return TocResult(source="llm") + + return TocResult( + entries=entries, + source="llm", + confidence=0.7, + ) + except Exception as exc: + logger.debug("LLM TOC extraction failed: %s", exc) + return TocResult(source="llm") @staticmethod - def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: - """Extract TOC from DOCX heading styles using python-docx. 
+ def _parse_response(raw: str, content: str) -> List[TOCEntry]: + """Parse LLM JSON response into TOCEntry list with char_positions.""" + # Strip markdown code fences if present + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.split("\n") + # Remove first and last fence lines + lines = [l for l in lines if not l.strip().startswith("```")] + cleaned = "\n".join(lines) + + try: + items = json.loads(cleaned) + except (json.JSONDecodeError, TypeError): + return [] + + if not isinstance(items, list): + return [] + + content_lower = content.lower() + search_from = 0 + entries: List[TOCEntry] = [] + + for item in items: + if not isinstance(item, dict): + continue + title = str(item.get("title", "")).strip() + level = int(item.get("level", 1)) + if not title: + continue + + # Try to locate title in content for char_position + pos = content_lower.find(title.lower(), search_from) + if pos >= 0: + char_start = pos + search_from = pos + len(title) + else: + # Fallback: try from beginning + pos = content_lower.find(title.lower()) + char_start = pos if pos >= 0 else search_from + + entries.append(TOCEntry( + title=title, + level=max(1, min(level, 6)), + char_start=char_start, + source="llm", + )) + + return entries - Reads paragraphs with heading style names (English ``Heading``, - Chinese ``\u6807\u9898``, etc.), extracting the heading level from the style - name suffix (e.g., ``Heading 1`` -> level 1). + +# --------------------------------------------------------------------------- +# Format-specific extractors (non-PDF) +# --------------------------------------------------------------------------- + + +class DocxTocExtractor: + """Extract TOC from DOCX heading styles using python-docx.""" + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from DOCX heading styles. Args: file_path: Path to the DOCX file. Returns: - List of TOCEntry with level from heading style, or None. + TocResult with entries from heading styles. 
""" try: import docx - doc = docx.Document(file_path) + doc = docx.Document(str(file_path)) entries: List[TOCEntry] = [] for para in doc.paragraphs: style_name = para.style.name or "" - # Match heading styles across locales ("Heading 1", "标题 1", etc.) matched_prefix = "" for prefix in _HEADING_STYLE_PREFIXES: if style_name.startswith(prefix): @@ -233,47 +602,213 @@ def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: title=title, level=level, char_start=0, + source="docx", )) - return entries if entries else None - except Exception: - return None - @staticmethod - def _extract_html_toc(content: str) -> Optional[List[TOCEntry]]: - """Extract TOC from HTML heading tags (

<h1> through <h6>

). + if not entries: + return TocResult(source="docx") + return TocResult(entries=entries, source="docx", confidence=0.85) + except Exception as exc: + logger.debug("DOCX TOC extraction failed: %s", exc) + return TocResult(source="docx") + - Uses regex to match heading tags and strips inner HTML tags - from the title text. +class HtmlTocExtractor: + """Extract TOC from HTML heading tags (

<h1>–<h6>).""" + + _HTML_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"<h([1-6])[^>]*>(.*?)</h\1>", + re.IGNORECASE | re.DOTALL, + ) + + @staticmethod + def extract(content: str) -> TocResult: + """Extract TOC from HTML heading tags. Args: content: HTML text content. Returns: - List of TOCEntry with level from tag number, or None. + TocResult with entries from <h1>–<h6>

tags. """ try: - pattern = re.compile( - r"<h([1-6])[^>]*>(.*?)</h\1>", - re.IGNORECASE | re.DOTALL, - ) - matches = pattern.findall(content) + matches = HtmlTocExtractor._HTML_HEADING_RE.findall(content) if not matches: - return None + return TocResult(source="html") entries: List[TOCEntry] = [] for level_str, raw_title in matches: - # Strip HTML tags from title title = re.sub(r"<[^>]+>", "", raw_title).strip() if title: entries.append(TOCEntry( title=title, level=int(level_str), char_start=0, + source="html", )) - return entries if entries else None - except Exception: + + if not entries: + return TocResult(source="html") + return TocResult(entries=entries, source="html", confidence=0.8) + except Exception as exc: + logger.debug("HTML TOC extraction failed: %s", exc) + return TocResult(source="html") + + +# --------------------------------------------------------------------------- +# Orchestrator: multi-layer fallback +# --------------------------------------------------------------------------- + + +class TOCExtractor: + """Orchestrates multi-layer TOC extraction with fallback strategy. + + All methods are static/classmethod — no instance state required. + The main ``extract()`` entry point dispatches by file extension and + applies the layered fallback for PDF files. + + Layer priority for PDFs: + 1. pypdf native outline (confidence 0.9) + 2. pdfminer.six detailed parsing (confidence 0.85) + 3. Text heading detection (confidence 0.5–0.7) + 4.
LLM-assisted inference (confidence 0.7, optional) + + Design constraints: + - Layers 1–3 are pure-local, zero LLM calls + - Layer 4 is optional (requires llm_caller) + - Each layer is independently try-excepted; failure never blocks + subsequent layers + """ + + # Minimum entries to consider a TOC extraction successful + _MIN_ENTRIES_THRESHOLD: ClassVar[int] = 3 + + @classmethod + async def extract( + cls, + file_path: str, + content: str, + *, + llm_caller: Any | None = None, + ) -> Optional[List[TOCEntry]]: + """Extract TOC using layered fallback strategy. + + Tries extraction methods in order of reliability. Falls back to + the next layer when the current layer yields fewer than + ``_MIN_ENTRIES_THRESHOLD`` entries. + + Args: + file_path: Absolute path to the source file. + content: Extracted text content of the file. + llm_caller: Optional LLM caller for Layer 4. + + Returns: + List of TOCEntry with resolved char positions, or None if + no layer produced enough entries. + """ + ext = Path(file_path).suffix.lower() + + result: Optional[TocResult] = None + + if ext == ".pdf": + result = await cls._extract_pdf_layered( + file_path, content, llm_caller, + ) + elif ext in (".md", ".markdown"): + heading_result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(heading_result): + result = heading_result + elif ext in (".docx",): + result = DocxTocExtractor.extract(file_path) + elif ext in (".html", ".htm"): + result = HtmlTocExtractor.extract(content) + else: + return None + + if result is None or not cls._is_sufficient(result): return None + entries = result.entries + total = cls._count_entries(entries) + if total < cls._MIN_ENTRIES_THRESHOLD: + return None + + # Resolve character positions in the extracted text + entries = cls._resolve_char_positions(entries, content) + return entries + + @classmethod + async def _extract_pdf_layered( + cls, + file_path: str, + content: str, + llm_caller: Any | None, + ) -> Optional[TocResult]: + """Apply layered 
extraction for PDF files. + + Args: + file_path: Path to the PDF file. + content: Extracted text content. + llm_caller: Optional LLM caller for Layer 4. + + Returns: + Best TocResult from the layer cascade, or None. + """ + # Layer 1: pypdf + result = PypdfOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 1 (pypdf): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 2: pdfminer.six + result = PdfminerOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 2 (pdfminer): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 3: heading detection from content + if content: + result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 3 (heading): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 4: LLM-assisted (optional) + if llm_caller is not None and content: + result = await LlmTocExtractor.extract(content, llm_caller) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 4 (LLM): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + logger.debug( + "TOC extraction: no layer produced sufficient entries for %s", + Path(file_path).name, + ) + return None + + @classmethod + def _is_sufficient(cls, result: Optional[TocResult]) -> bool: + """Check whether a TocResult has enough entries to be useful.""" + if result is None: + return False + return len(result.entries) >= cls._MIN_ENTRIES_THRESHOLD + + # ------------------------------------------------------------------ # + # Character position resolution # + # ------------------------------------------------------------------ # + @staticmethod def _resolve_char_positions( entries: List[TOCEntry], @@ -311,26 +846,21 @@ def _resolve_char_positions( if not title_lower: entry.char_start = search_from continue - # 
Normalise whitespace for fuzzy matching (PDF extracts may - # insert extra spaces inside headings). + # Normalise whitespace for fuzzy matching title_normalised = re.sub(r"\s+", " ", title_lower) pos = content_lower.find(title_normalised, search_from) if pos < 0: - # Retry with the original (un-normalised) title pos = content_lower.find(title_lower, search_from) if pos >= 0: entry.char_start = pos search_from = pos + len(title_lower) else: - # Title not found after search_from; try from beginning pos = content_lower.find(title_normalised) if pos < 0: pos = content_lower.find(title_lower) if pos >= 0: entry.char_start = pos - # Do NOT reset search_from to avoid breaking order else: - # Last resort: place at current search frontier entry.char_start = search_from # Pass 2: resolve char_end as start of next entry (or len(content)) @@ -346,12 +876,7 @@ def _flatten_entries( entries: List[TOCEntry], flat: List[TOCEntry], ) -> None: - """Flatten nested TOCEntry tree into document-order list. - - Args: - entries: Nested entry list. - flat: Accumulator for flattened output. - """ + """Flatten nested TOCEntry tree into document-order list.""" for entry in entries: flat.append(entry) if entry.children: @@ -359,30 +884,7 @@ def _flatten_entries( @staticmethod def _count_entries(entries: List[TOCEntry]) -> int: - """Count total entries including nested children. - - Args: - entries: Nested entry list. - - Returns: - Total number of entries in the tree. - """ - count = 0 - for entry in entries: - count += 1 - if entry.children: - count += TOCExtractor._count_entries(entry.children) - return count - @staticmethod - def _count_entries(entries: List[TOCEntry]) -> int: - """Count total entries including nested children. - - Args: - entries: Nested entry list. - - Returns: - Total number of entries in the tree. 
- """ + """Count total entries including nested children.""" count = 0 for entry in entries: count += 1 diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py new file mode 100644 index 0000000..76e0f15 --- /dev/null +++ b/src/sirchmunk/utils/document_extractor.py @@ -0,0 +1,398 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unified document extraction facade over kreuzberg. + +Centralizes all kreuzberg interaction into a single module, providing a clean, +configurable interface for document text extraction with support for tables, +metadata, language detection, OCR, and page-range filtering. + +All other modules should import from here rather than from kreuzberg directly. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, List, Optional, Sequence, Union + +from loguru import logger + + +# --------------------------------------------------------------------------- +# Configuration profile +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionProfile: + """Immutable extraction configuration profile. + + Controls which kreuzberg features are enabled during document extraction. + Default values align with the legacy ``fast_extract()`` behavior + (plain text only, no extras). 
+ """ + + output_format: str = "plain" + """Output format: ``plain`` | ``markdown`` | ``html`` | ``djot``.""" + + extract_tables: bool = False + """Whether to extract and return tables.""" + + extract_metadata: bool = False + """Whether to return document metadata.""" + + detect_language: bool = False + """Whether to detect document language.""" + + ocr_enabled: bool = False + """Whether to enable OCR fallback.""" + + ocr_backend: str = "tesseract" + """OCR engine: ``tesseract`` | ``easyocr`` | ``paddleocr``.""" + + ocr_language: str = "eng" + """OCR language code (e.g. ``eng``, ``chi_sim``).""" + + page_start: Optional[int] = None + """Page range start (0-indexed). ``None`` means first page.""" + + page_end: Optional[int] = None + """Page range end (inclusive). ``None`` means last page.""" + + pdf_extract_images: bool = False + """Extract images embedded in PDF pages.""" + + pdf_extract_metadata: bool = False + """Extract PDF-level metadata (author, title, etc.).""" + + force_ocr: bool = False + """Force OCR for all pages, bypassing native text extraction. + + Maps directly to kreuzberg's ``ExtractionConfig.force_ocr``. + Note: kreuzberg does not offer a "fallback" OCR mode — + when set, OCR is always applied regardless of text layer presence. + """ + + pdf_password: Optional[str] = None + """Password for encrypted PDFs.""" + + max_concurrent: Optional[int] = None + """Max concurrency for batch extraction.""" + + +# --------------------------------------------------------------------------- +# Extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionOutput: + """Structured extraction result. + + Always contains ``content``. Other fields are populated based on the + :class:`ExtractionProfile` settings used during extraction. 
+ """ + + content: str + """Extracted text content.""" + + mime_type: str = "" + """MIME type of the source document.""" + + metadata: dict[str, Any] = field(default_factory=dict) + """Document metadata (empty when ``extract_metadata`` is disabled).""" + + tables: list[dict[str, Any]] = field(default_factory=list) + """Extracted tables (empty when ``extract_tables`` is disabled).""" + + detected_languages: dict[str, float] = field(default_factory=dict) + """Language → confidence mapping (empty when ``detect_language`` is disabled).""" + + page_count: Optional[int] = None + """Number of pages in the source document (if available).""" + + +# --------------------------------------------------------------------------- +# Document extractor facade +# --------------------------------------------------------------------------- + +class DocumentExtractor: + """Unified document extraction facade over kreuzberg. + + Provides a clean, configurable interface for document text extraction, + centralizing all kreuzberg interaction within a single module. 
+ + Usage:: + + # Basic extraction (identical to legacy fast_extract) + result = await DocumentExtractor.extract(path) + + # Enhanced extraction with tables and metadata + result = await DocumentExtractor.extract(path, DocumentExtractor.ENHANCED) + + # Custom profile + profile = ExtractionProfile(output_format="markdown", extract_tables=True) + result = await DocumentExtractor.extract(path, profile) + """ + + # Pre-defined profiles ------------------------------------------------- + + BASIC: ClassVar[ExtractionProfile] = ExtractionProfile() + """Plain-text extraction only — equivalent to legacy ``fast_extract()``.""" + + ENHANCED: ClassVar[ExtractionProfile] = ExtractionProfile( + output_format="markdown", + extract_tables=True, + extract_metadata=True, + pdf_extract_metadata=True, + force_ocr=True, + ) + """Rich extraction with tables, metadata, and OCR fallback.""" + + # Public API ----------------------------------------------------------- + + @staticmethod + async def extract( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from a single file. + + Args: + file_path: Path to the document. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput` with at least ``content`` populated. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: Propagates kreuzberg extraction errors after logging. 
+ """ + from kreuzberg import extract_file + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await extract_file(file_path=file_path, config=config) + return DocumentExtractor._convert_result(result, profile) + except Exception as exc: + logger.error( + "Document extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + async def extract_bytes( + data: bytes, + mime_type: str, + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from raw bytes. + + Args: + data: File content as bytes. + mime_type: MIME type of the data (required for format detection). + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput`. + """ + from kreuzberg import extract_bytes as _extract_bytes + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await _extract_bytes(data=data, mime_type=mime_type, config=config) + return DocumentExtractor._convert_result(result, profile) + except Exception: + logger.error("Byte extraction failed for mime_type={}", mime_type) + raise + + @staticmethod + async def batch_extract( + file_paths: Sequence[Union[str, Path]], + profile: Optional[ExtractionProfile] = None, + ) -> List[ExtractionOutput]: + """Extract content from multiple files in parallel. + + Args: + file_paths: Sequence of document paths. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + List of :class:`ExtractionOutput`, one per input path. 
+ """ + from kreuzberg import batch_extract_files + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + results = await batch_extract_files(paths=list(file_paths), config=config) + return [ + DocumentExtractor._convert_result(r, profile) for r in results + ] + except Exception: + logger.error("Batch extraction failed for {} files", len(file_paths)) + raise + + # Internal helpers ----------------------------------------------------- + + @staticmethod + def _build_config(profile: ExtractionProfile): + """Build a kreuzberg ``ExtractionConfig`` from an :class:`ExtractionProfile`. + + Maps profile fields to the kreuzberg configuration objects that are + actually available in the installed version. + """ + from kreuzberg import ( + ExtractionConfig, + OcrConfig, + OutputFormat, + PageConfig, + PdfConfig, + ) + + # --- Output format --- + format_map = { + "plain": OutputFormat.PLAIN, + "markdown": OutputFormat.MARKDOWN, + "html": OutputFormat.HTML, + "djot": OutputFormat.DJOT, + } + output_format = format_map.get(profile.output_format, OutputFormat.PLAIN) + + # --- OCR config --- + ocr_config: Optional[OcrConfig] = None + if profile.ocr_enabled: + ocr_config = OcrConfig( + backend=profile.ocr_backend, + language=profile.ocr_language, + ) + + # --- Page config --- + page_config: Optional[PageConfig] = None + if profile.page_start is not None or profile.page_end is not None: + # kreuzberg PageConfig.extract_pages expects a list of page indices + pages: Optional[list[int]] = None + if profile.page_start is not None: + end = profile.page_end if profile.page_end is not None else profile.page_start + pages = list(range(profile.page_start, end + 1)) + page_config = PageConfig(extract_pages=pages) + + # --- PDF config --- + pdf_config: Optional[PdfConfig] = None + if any([ + profile.pdf_extract_images, + profile.pdf_extract_metadata, + profile.pdf_password, + ]): + passwords = [profile.pdf_password] if profile.pdf_password 
else None + pdf_config = PdfConfig( + extract_images=profile.pdf_extract_images, + extract_metadata=profile.pdf_extract_metadata, + passwords=passwords, + ) + + # --- Language detection --- + lang_config = None + if profile.detect_language: + from kreuzberg import LanguageDetectionConfig + lang_config = LanguageDetectionConfig(enabled=True) + + # --- Assemble ExtractionConfig --- + kwargs: dict[str, Any] = { + "output_format": output_format, + } + if ocr_config is not None: + kwargs["ocr"] = ocr_config + if profile.force_ocr: + kwargs["force_ocr"] = True + if page_config is not None: + kwargs["pages"] = page_config + if pdf_config is not None: + kwargs["pdf_options"] = pdf_config + if lang_config is not None: + kwargs["language_detection"] = lang_config + if profile.max_concurrent is not None: + kwargs["max_concurrent_extractions"] = profile.max_concurrent + + return ExtractionConfig(**kwargs) + + @staticmethod + def _convert_result( + result: "ExtractionResult", + profile: ExtractionProfile, + ) -> ExtractionOutput: + """Convert a kreuzberg ``ExtractionResult`` to :class:`ExtractionOutput`. + + Only populates optional fields when the corresponding profile flag is + enabled, keeping the output lean for basic extraction. 
+ """ + content: str = result.content or "" + mime_type: str = getattr(result, "mime_type", "") or "" + + # Metadata + metadata: dict[str, Any] = {} + if profile.extract_metadata: + raw_meta = getattr(result, "metadata", None) + if raw_meta is not None: + if isinstance(raw_meta, dict): + metadata = dict(raw_meta) + else: + # kreuzberg may return a non-dict metadata object + try: + metadata = dict(raw_meta) + except (TypeError, ValueError): + metadata = {"raw": str(raw_meta)} + + # Tables + tables: list[dict[str, Any]] = [] + if profile.extract_tables: + raw_tables = getattr(result, "tables", None) or [] + for t in raw_tables: + if isinstance(t, dict): + tables.append(t) + else: + # kreuzberg ExtractedTable has: cells, markdown, page_number + tables.append({ + "markdown": getattr(t, "markdown", ""), + "cells": getattr(t, "cells", []), + "page_number": getattr(t, "page_number", None), + }) + + # Language detection + detected_languages: dict[str, float] = {} + if profile.detect_language: + raw_langs = getattr(result, "detected_languages", None) + if raw_langs: + for entry in raw_langs: + if isinstance(entry, dict): + lang = entry.get("language", "") + conf = entry.get("confidence", 0.0) + else: + # kreuzberg DetectedLanguage object + lang = getattr(entry, "language", "") + conf = getattr(entry, "confidence", 0.0) + if lang: + detected_languages[lang] = float(conf) + + # Page count — prefer get_page_count() over get_chunk_count() + page_count: Optional[int] = None + get_page_count = getattr(result, "get_page_count", None) + if get_page_count and callable(get_page_count): + cnt = get_page_count() + if cnt is not None and cnt > 0: + page_count = cnt + + return ExtractionOutput( + content=content, + mime_type=mime_type, + metadata=metadata, + tables=tables, + detected_languages=detected_languages, + page_count=page_count, + ) diff --git a/src/sirchmunk/utils/file_utils.py b/src/sirchmunk/utils/file_utils.py index edbbc2d..df308fd 100644 --- 
a/src/sirchmunk/utils/file_utils.py +++ b/src/sirchmunk/utils/file_utils.py @@ -4,17 +4,29 @@ from pathlib import Path from typing import Union -from kreuzberg import ExtractionResult, extract_file from loguru import logger +from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionOutput, +) -async def fast_extract(file_path: Union[str, Path]) -> ExtractionResult: - """ - Automatically detects and extracts text content from various file formats like docx, pptx, pdf, xlsx. - """ - result: ExtractionResult = await extract_file(file_path=file_path) - return result +async def fast_extract(file_path: Union[str, Path]) -> ExtractionOutput: + """Extract text content from a file using kreuzberg. + + This is a backward-compatible wrapper around + :meth:`DocumentExtractor.extract` with the ``BASIC`` profile + (plain text, no extras). All callers that only need ``.content`` + continue to work unchanged. + + Args: + file_path: Path to the file to extract. + + Returns: + :class:`ExtractionOutput` with ``.content`` populated. 
+ """ + return await DocumentExtractor.extract(file_path) def get_fast_hash(file_path: Union[str, Path], sample_size: int = 8192): From b2c26bbc5321d5f7f4506cafd3431af6d74719d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 19 Apr 2026 20:58:23 +0800 Subject: [PATCH 30/70] enhance compiler for tree indexing --- config/env.example | 5 + src/sirchmunk/learnings/compiler.py | 5 +- src/sirchmunk/learnings/toc_extractor.py | 71 ++++++ src/sirchmunk/learnings/tree_indexer.py | 118 ++++++++-- src/sirchmunk/search.py | 262 ++++++++++++++++------- 5 files changed, 354 insertions(+), 107 deletions(-) diff --git a/config/env.example b/config/env.example index 8272d03..4b8dcd7 100644 --- a/config/env.example +++ b/config/env.example @@ -126,3 +126,8 @@ SIRCHMUNK_DEBUG=false # Maximum concurrent WebSocket connections (default: 100) SIRCHMUNK_MAX_WS_CONNECTIONS=100 + +# ===== Ablation Experiment Settings ===== +# Pure tree search mode (ablation experiment, default: false) +# When enabled, search relies solely on tree index navigation, skipping rga keyword search +# SIRCHMUNK_PURE_TREE_SEARCH=false diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 4e31441..ad2a115 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -559,6 +559,7 @@ async def _compile_single_file( from sirchmunk.learnings.toc_extractor import TOCExtractor toc_entries = await TOCExtractor.extract( entry.path, content, + total_pages=extraction.page_count, ) if toc_entries: await self._log.info( @@ -568,7 +569,9 @@ async def _compile_single_file( if use_tree: result.tree = await self._tree_indexer.build_tree( - entry.path, content, toc_entries=toc_entries, + entry.path, content, + toc_entries=toc_entries, + total_pages=extraction.page_count, ) # Record TOC / tree metrics on the result for manifest persistence diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py 
index 1516cfd..0197485 100644 --- a/src/sirchmunk/learnings/toc_extractor.py +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -683,6 +683,46 @@ class TOCExtractor: # Minimum entries to consider a TOC extraction successful _MIN_ENTRIES_THRESHOLD: ClassVar[int] = 3 + @staticmethod + def _build_hierarchy(flat_entries: List["TOCEntry"]) -> List["TOCEntry"]: + """Convert flat TocEntry list to nested tree using level field. + + Uses stack-based algorithm, O(n). When encountering a deeper level + entry, push it as a child of the current stack top; when same or + shallower, pop back to the corresponding level. + + Args: + flat_entries: Flat list of TOCEntry objects with ``level`` set. + + Returns: + List of top-level TOCEntry objects with ``children`` populated. + """ + if not flat_entries: + return [] + + roots: List[TOCEntry] = [] + # Stack holds (level, entry) pairs representing the current path + stack: List[TOCEntry] = [] + + for entry in flat_entries: + # Reset children to avoid stale data from prior processing + entry.children = [] + + # Pop stack until we find the parent (shallower level) + while stack and stack[-1].level >= entry.level: + stack.pop() + + if stack: + # Attach as child of the current stack top + stack[-1].children.append(entry) + else: + # No parent — this is a root-level entry + roots.append(entry) + + stack.append(entry) + + return roots + @classmethod async def extract( cls, @@ -690,6 +730,7 @@ async def extract( content: str, *, llm_caller: Any | None = None, + total_pages: Optional[int] = None, ) -> Optional[List[TOCEntry]]: """Extract TOC using layered fallback strategy. @@ -701,6 +742,8 @@ async def extract( file_path: Absolute path to the source file. content: Extracted text content of the file. llm_caller: Optional LLM caller for Layer 4. + total_pages: Total page count of the source document, if known. + Used to estimate ``page_start`` for Layer 3/4 entries. 
Returns: List of TOCEntry with resolved char positions, or None if @@ -709,11 +752,16 @@ async def extract( ext = Path(file_path).suffix.lower() result: Optional[TocResult] = None + # Track whether the result came from pypdf (Layer 1) which + # already produces a properly nested tree with children. + is_pypdf = False if ext == ".pdf": result = await cls._extract_pdf_layered( file_path, content, llm_caller, ) + if result is not None: + is_pypdf = result.source == "pypdf" elif ext in (".md", ".markdown"): heading_result = HeadingTocExtractor.extract(content) if cls._is_sufficient(heading_result): @@ -728,7 +776,30 @@ async def extract( if result is None or not cls._is_sufficient(result): return None + # Merge total_pages from TocResult if not explicitly provided + if total_pages is None and result.page_count: + total_pages = result.page_count + entries = result.entries + + # Post-processing for non-pypdf layers: rebuild hierarchy from + # flat level-annotated entries (Layer 2/3/4 and format extractors + # produce flat lists; pypdf already builds a nested tree). 
+ if not is_pypdf: + entries = cls._build_hierarchy(entries) + + # Estimate page_start for Layer 3/4 entries that lack it + if total_pages and content: + flat_all: List[TOCEntry] = [] + cls._flatten_entries(entries, flat_all) + content_len = len(content) + for entry in flat_all: + if entry.page_start is None and entry.char_start is not None: + entry.page_start = min( + total_pages, + max(1, round(entry.char_start / content_len * total_pages) + 1), + ) + total = cls._count_entries(entries) if total < cls._MIN_ENTRIES_THRESHOLD: return None diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 26787eb..8d93a2c 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -32,6 +32,12 @@ # Summary snippet length extracted from section content (chars) _TOC_NODE_SUMMARY_MAX_CHARS = 300 +# Marker substring length for fuzzy fallback matching in _resolve_positions +_MARKER_SUBSTRING_LEN = 32 + +# Maximum span ratio: filter out overly large spans (>80% of document) +_MAX_SPAN_RATIO = 0.8 + # Adaptive preview window for LLM structure analysis _TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) _TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) @@ -201,7 +207,9 @@ async def build_tree( # TOC-accelerated path: skip recursive LLM analysis if toc_entries: - root = await self._build_tree_from_toc(toc_entries, content) + root = await self._build_tree_from_toc( + toc_entries, content, total_pages=total_pages, + ) if root is not None: tree = DocumentTree( file_path=file_path, @@ -300,6 +308,8 @@ async def _build_tree_from_toc( self, toc_entries: List[Any], content: str, + *, + total_pages: Optional[int] = None, ) -> Optional[TreeNode]: """Build tree directly from extracted TOC entries, avoiding recursive LLM. @@ -309,25 +319,29 @@ async def _build_tree_from_toc( Args: toc_entries: List of TOCEntry from toc_extractor. content: Full extracted text of the document. 
+ total_pages: Total page count for page_range calculation. Returns: Root TreeNode, or None if no children could be created. """ seen_ids: set = set() children = self._toc_entries_to_nodes( - toc_entries, content, len(content), seen_ids, fallback_level=1, + toc_entries, content, len(content), seen_ids, + fallback_level=1, total_pages=total_pages, ) if not children: return None root_summary = await self._synthesize_root_summary(children) + root_page_range = (1, total_pages) if total_pages and total_pages > 0 else None return TreeNode( node_id=self._unique_node_id(0, seen_ids), title="Document", summary=root_summary, char_range=(0, len(content)), level=0, + page_range=root_page_range, children=children, ) @@ -338,15 +352,25 @@ def _toc_entries_to_nodes( parent_end: int, seen_ids: set, fallback_level: int, + total_pages: Optional[int] = None, ) -> List["TreeNode"]: """Recursively convert TOCEntry trees into TreeNode trees. Handles arbitrary nesting depth and guards against invalid - char_start / char_end values. + char_start / char_end values. Computes ``page_range`` using a + look-ahead algorithm when ``page_start`` is available on entries. + + Args: + entries: List of TOCEntry objects (may have children). + content: Full extracted text. + parent_end: End offset inherited from the parent node. + seen_ids: Set for unique node-id generation. + fallback_level: Default level when entry.level is 0. + total_pages: Total page count for page_range look-ahead. 
""" nodes: List[TreeNode] = [] content_len = len(content) - for entry in entries: + for i, entry in enumerate(entries): start = max(0, min(entry.char_start, content_len)) end = entry.char_end if entry.char_end and entry.char_end > start else parent_end end = min(end, content_len) @@ -355,11 +379,23 @@ def _toc_entries_to_nodes( nid = DocumentTreeIndexer._unique_node_id(start, seen_ids) level = entry.level if entry.level > 0 else fallback_level + # page_range: look-ahead algorithm + page_range = None + if hasattr(entry, 'page_start') and entry.page_start is not None: + # Find next sibling with page_start to determine page_end + page_end = total_pages or entry.page_start + for j in range(i + 1, len(entries)): + if hasattr(entries[j], 'page_start') and entries[j].page_start is not None: + page_end = entries[j].page_start + break + page_range = (entry.page_start, max(entry.page_start, page_end)) + child_nodes: List[TreeNode] = [] if entry.children: child_nodes = DocumentTreeIndexer._toc_entries_to_nodes( entry.children, content, end, seen_ids, fallback_level=level + 1, + total_pages=total_pages, ) node = TreeNode( @@ -368,6 +404,7 @@ def _toc_entries_to_nodes( summary=section_text.strip(), char_range=(start, end), level=level, + page_range=page_range, children=child_nodes, ) nodes.append(node) @@ -495,40 +532,65 @@ def _parse_sections( def _resolve_positions( items: List[Dict[str, Any]], full_text: str, ) -> List[Dict[str, Any]]: - """Resolve section start/end character offsets from marker text.""" + """Resolve section start/end character offsets from marker text. + + Two-pass algorithm: + Pass 1 — determine all start positions with tiered fallback: + exact match from prev_end -> substring match -> full-text fallback. + Pass 2 — set end[i] = start[i+1]; last end = text_len. + + Filters out invalid spans and overly large spans (> ``_MAX_SPAN_RATIO`` + of the document) to prevent accumulated positioning errors. 
+ """ + text_lower = full_text.lower() + text_len = len(full_text) resolved: List[Dict[str, Any]] = [] + + # Pass 1: determine all start positions prev_end = 0 - text_lower = full_text.lower() for item in items: title = item.get("title", "") - summary = item.get("summary", "") marker = item.get("start_marker", title) - pos = text_lower.find(marker.lower(), prev_end) if marker else -1 - start = pos if pos >= 0 else prev_end - - end_marker = item.get("end_marker", "") - if end_marker: - epos = text_lower.find(end_marker.lower(), start + 1) - end = epos if epos > start else min(start + 50000, len(full_text)) - else: - end = min(start + 50000, len(full_text)) + pos = -1 + if marker: + marker_lower = marker.lower() + # Level 1: exact match from prev_end + pos = text_lower.find(marker_lower, prev_end) + # Level 2: substring match (first N chars) from prev_end + if pos < 0 and len(marker_lower) > _MARKER_SUBSTRING_LEN: + pos = text_lower.find( + marker_lower[:_MARKER_SUBSTRING_LEN], prev_end, + ) + # Level 3: full text fallback from start + if pos < 0: + pos = text_lower.find(marker_lower, 0) + start = pos if pos >= 0 else prev_end resolved.append({ "title": title, - "summary": summary, + "summary": item.get("summary", ""), "start": start, - "end": end, + "end": text_len, # placeholder }) - prev_end = end + prev_end = ( + start + max(1, len(marker)) + if pos >= 0 + else prev_end + ) - # Fix gaps: each section ends where the next begins + # Pass 2: set end[i] = start[i+1], last end = text_len for i in range(len(resolved) - 1): resolved[i]["end"] = resolved[i + 1]["start"] if resolved: - resolved[-1]["end"] = len(full_text) + resolved[-1]["end"] = text_len - return [s for s in resolved if s["end"] > s["start"]] + # Filter out invalid spans and overly large spans + return [ + s for s in resolved + if s["end"] > s["start"] + and (s["end"] - s["start"]) / max(text_len, 1) < _MAX_SPAN_RATIO + ] async def _select_children( self, nodes: List[TreeNode], query: str, @@ -538,7 
+600,7 @@ async def _select_children( return nodes listing = "\n".join( - f"[{i}] {n.title}: {n.summary[:150]}" + f"[{i}] {n.title}{self._format_page_range(n.page_range)}: {n.summary[:150]}" for i, n in enumerate(nodes) ) prompt = ( @@ -604,6 +666,16 @@ def _max_node_depth(node: TreeNode) -> int: return node.level return max(DocumentTreeIndexer._max_node_depth(c) for c in node.children) + @staticmethod + def _format_page_range( + page_range: "Optional[Tuple[int, int]]", + ) -> str: + """Format a page_range tuple into a human-readable string for prompts.""" + if not page_range: + return "" + ps, pe = page_range + return f" [pages {ps}-{pe}]" if ps != pe else f" [page {ps}]" + @staticmethod def should_build_tree(file_path: str, content_length: int) -> bool: """Determine whether a file is eligible for tree indexing.""" diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index c2f30f4..74b2b85 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -86,6 +86,10 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 +# Pure tree search mode for ablation experiments. +# When enabled, search relies solely on tree index navigation, skipping rga keyword search. +_PURE_TREE_SEARCH: bool = os.getenv("SIRCHMUNK_PURE_TREE_SEARCH", "false").lower() == "true" + # Common English stop-words filtered out during keyword coverage computation. 
_STOP_WORDS: frozenset = frozenset({ "the", "is", "a", "an", "of", "in", "for", "to", "and", "or", @@ -1631,36 +1635,44 @@ async def _search_deep( # ============================================================== # Phase 2: Parallel retrieval — keyword search + dir_scan rank # ============================================================== - await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") - context.increment_loop() + keyword_files: List[str] = [] + dir_scan_files: List[str] = [] - phase2_tasks = [] + if _PURE_TREE_SEARCH: + # Pure tree search mode: skip rga and dir_scan, rely solely on tree hits + await self._logger.info("[Phase 2:PureTree] Skipping rga keyword search and dir_scan") + context.increment_loop() + else: + await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") + context.increment_loop() - if initial_keywords: - phase2_tasks.append( - self._retrieve_by_keywords( - initial_keywords, paths, - max_depth=max_depth, include=include, exclude=exclude, + phase2_tasks = [] + + if initial_keywords: + phase2_tasks.append( + self._retrieve_by_keywords( + initial_keywords, paths, + max_depth=max_depth, include=include, exclude=exclude, + ) ) - ) - else: - phase2_tasks.append(self._async_noop([])) + else: + phase2_tasks.append(self._async_noop([])) - if scan_result is not None and enable_dir_scan: - phase2_tasks.append( - self._rank_dir_scan_candidates(query, scan_result) - ) - else: - phase2_tasks.append(self._async_noop([])) + if scan_result is not None and enable_dir_scan: + phase2_tasks.append( + self._rank_dir_scan_candidates(query, scan_result) + ) + else: + phase2_tasks.append(self._async_noop([])) - phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) + phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) - keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] - dir_scan_files = 
phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] + keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] + dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] - for i, label in enumerate(["keyword_search", "dir_scan_rank"]): - if isinstance(phase2_results[i], Exception): - await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") + for i, label in enumerate(["keyword_search", "dir_scan_rank"]): + if isinstance(phase2_results[i], Exception): + await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") await self._logger.info( f"[Phase 2] Results: keyword_files={len(keyword_files)}, " @@ -1698,12 +1710,30 @@ async def _search_deep( extra_knowledge_files = knowledge_probe.file_paths if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files - merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, - dir_scan_files=dir_scan_files, - knowledge_hits=extra_knowledge_files, - ) - await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") + + if _PURE_TREE_SEARCH: + # Pure tree search: only use tree hits (+ soft-hit fallback if no tree hits) + pure_tree_files = list(tree_hits) + if not pure_tree_files and soft_hit: + pure_tree_files = soft_hit.file_paths + await self._logger.info( + f"[Phase 3:PureTree] No tree hits, using {len(pure_tree_files)} soft-hit files" + ) + merged_files = self._merge_file_paths( + keyword_files=pure_tree_files, + dir_scan_files=[], + knowledge_hits=[], + ) + await self._logger.info( + f"[Phase 3:PureTree] Merged {len(merged_files)} tree-only candidate files" + ) + else: + merged_files = self._merge_file_paths( + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, + dir_scan_files=dir_scan_files, + 
knowledge_hits=extra_knowledge_files, + ) + await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") cluster: Optional[KnowledgeCluster] = None if merged_files: @@ -2181,6 +2211,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum files returned by tree index probing in DEEP mode.""" _TREE_ROOT_HINT_TRUNCATE = 150 """Max chars of tree root summary in Step 1 structure hints.""" + _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 + """char_range spanning more than this ratio of the document is treated as invalid.""" # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 @@ -2484,58 +2516,88 @@ async def _search_fast( evidence = "" file_path: Optional[str] = None # set when best_files found - # High-confidence catalog routing: skip rga, use catalog directly - if catalog_routed_files and catalog_confidence == "high": - used_level = "catalog_route" - await self._logger.info( - f"[FAST:Step2] High-confidence catalog routing → " - f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" - ) - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in catalog_routed_files[:top_k_files] - ] - - if not best_files and primary: - best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - **rga_kwargs, - ) + # --- Pure tree search mode: skip rga, use tree probe results directly --- + if _PURE_TREE_SEARCH: + if _tree_probed_files: + used_level = "pure_tree" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in _tree_probed_files[:top_k_files] + ] + await self._logger.info( + f"[FAST:PureTree] Using {len(best_files)} tree-probed files: " + f"{[Path(p).name for p in _tree_probed_files[:top_k_files]]}" + ) + elif compile_hint_files: + # Tree probe returned nothing but compile hints have tree files + used_level = 
"pure_tree_hint" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + await self._logger.info( + f"[FAST:PureTree] No tree probes, falling back to " + f"{len(best_files)} compile-hint files" + ) + else: + await self._logger.warning( + "[FAST:PureTree] No tree probes available, returning empty" + ) + return _NO_RESULTS_MESSAGE, None, context + else: + # --- Original rga-based retrieval logic --- + # High-confidence catalog routing: skip rga, use catalog directly + if catalog_routed_files and catalog_confidence == "high": + used_level = "catalog_route" + await self._logger.info( + f"[FAST:Step2] High-confidence catalog routing → " + f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] - if not best_files and fallback: - used_level = "fallback" - await self._logger.info( - "[FAST:Step2] Primary miss, trying fine-grained fallback" - ) - best_files = await self._fast_find_best_file( - fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - **rga_kwargs, - ) + if not best_files and primary: + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, + ) - # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- - if not best_files and compile_hint_files: - used_level = "compile_hint" - await self._logger.info( - f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" - ) - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in compile_hint_files[:top_k_files] - ] + if not best_files and fallback: + used_level = "fallback" + await self._logger.info( + "[FAST:Step2] Primary miss, trying fine-grained fallback" + ) + 
best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, + ) - # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- - if not best_files and enable_dir_scan: - scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) - if scan_result is not None: - await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") - ranked_paths = await self._rank_dir_scan_candidates( - query, scan_result, top_k=10, include_medium=True, + # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- + if not best_files and compile_hint_files: + used_level = "compile_hint" + await self._logger.info( + f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" ) - if ranked_paths: - used_level = "dir_scan" - best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + + # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- + if not best_files and enable_dir_scan: + scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) + if scan_result is not None: + await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") + ranked_paths = await self._rank_dir_scan_candidates( + query, scan_result, top_k=10, include_medium=True, + ) + if ranked_paths: + used_level = "dir_scan" + best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] if not best_files: if llm_fallback: @@ -3745,14 +3807,24 @@ async def _tree_guided_sample( total_chars = 0 for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: start, end = leaf.char_range - if full_text and end > start: + if 
self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] + elif leaf.summary: + logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary else: - segment = leaf.summary or "" + continue segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue - header = f"[{fname} \u2192 {leaf.title}]" + page_info = "" + if leaf.page_range: + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + header = f"[{fname} → {leaf.title}{page_info}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > max_chars: remaining = max_chars - total_chars @@ -3798,6 +3870,20 @@ async def _tree_guided_sample( ) return evidence + def _is_valid_char_range( + self, start: int, end: int, text_len: int, + ) -> bool: + """Check whether a char_range is valid for slicing. + + A range is invalid when it covers more than + ``_CHAR_RANGE_MAX_SPAN_RATIO`` of the document (likely a + whole-document fallback) or when *end <= start*. 
+ """ + if start < 0 or end <= start or text_len <= 0: + return False + span_ratio = (end - start) / text_len + return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3834,12 +3920,22 @@ async def _navigate_tree_for_evidence( for leaf in leaves: start, end = leaf.char_range - if full_text and end > start: + if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] + elif leaf.summary: + logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary else: - segment = leaf.summary or "" + continue if segment.strip(): - header = f"[{fname} → {leaf.title}]" + page_info = "" + if leaf.page_range: + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + header = f"[{fname} → {leaf.title}{page_info}]" parts.append(f"{header}\n{segment[:3000]}") if not parts: From d4e8fe3a83ec1504776be330643d13448e54d083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 19 Apr 2026 22:14:17 +0800 Subject: [PATCH 31/70] fix table extraction --- src/sirchmunk/learnings/compiler.py | 104 ++++++++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 7 +- src/sirchmunk/search.py | 122 +++++++++++++++++++++- src/sirchmunk/utils/document_extractor.py | 13 +++ 4 files changed, 243 insertions(+), 3 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ad2a115..037070b 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -76,6 +76,8 @@ class FileManifestEntry: has_explicit_toc: bool = False # Whether a native TOC was extracted from the file tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) has_xlsx_digest: bool = False # Whether a pre-compiled Excel 
evidence digest exists + has_table_digest: bool = False # Whether PDF tables were extracted and stored + table_count: int = 0 # Number of tables in this file def to_dict(self) -> Dict[str, Any]: return { @@ -88,6 +90,8 @@ def to_dict(self) -> Dict[str, Any]: "has_explicit_toc": self.has_explicit_toc, "tree_node_count": self.tree_node_count, "has_xlsx_digest": self.has_xlsx_digest, + "has_table_digest": self.has_table_digest, + "table_count": self.table_count, } @classmethod @@ -102,6 +106,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": has_explicit_toc=data.get("has_explicit_toc", False), tree_node_count=data.get("tree_node_count", 0), has_xlsx_digest=data.get("has_xlsx_digest", False), + has_table_digest=data.get("has_table_digest", False), + table_count=data.get("table_count", 0), ) @@ -167,6 +173,8 @@ class FileCompileResult: has_explicit_toc: bool = False # Whether TOC was extracted from native structure tree_node_count: int = 0 # Number of nodes in the tree index has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists + has_table_digest: bool = False # Whether a pre-compiled table digest exists + table_count: int = 0 # Number of tables extracted @dataclass @@ -402,6 +410,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_explicit_toc=result.has_explicit_toc, tree_node_count=result.tree_node_count, has_xlsx_digest=result.has_xlsx_digest, + has_table_digest=result.has_table_digest, + table_count=result.table_count, ) # Phase 3: aggregate results into knowledge network @@ -609,6 +619,29 @@ async def _compile_single_file( except Exception: pass + # Persist table digest for documents with extracted tables + if extraction.tables: + try: + table_digest = self._build_table_digest(extraction.tables) + if table_digest: + digest_dir = self._compile_dir / "table_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + digest_path = digest_dir 
/ f"{file_hash}.json" + digest_path.write_text( + json.dumps(table_digest, ensure_ascii=False), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(extraction.tables) + except Exception: + pass + + # Annotate tree nodes with table counts for navigation hints + if result.tree and result.tree.root and extraction.tables: + self._annotate_tree_with_table_counts(result.tree.root, extraction.tables) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1130,6 +1163,77 @@ def _add_edge( WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) ) + def _build_table_digest( + self, tables: List[Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + """Build a structured table digest from extraction output. + + Returns a versioned JSON-serializable dict containing all tables + with their page numbers, markdown representation, and cell data. + Tables are indexed for page-range-based retrieval at search time. 
+ """ + if not tables: + return None + + digest_tables = [] + for idx, table in enumerate(tables): + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + continue + + # Compute row/col counts from cells (kreuzberg returns List[List[str]]) + row_count = 0 + col_count = 0 + if cells: + row_count = len(cells) + col_count = max((len(row) for row in cells if isinstance(row, (list, tuple))), default=0) + elif markdown: + # Estimate from markdown lines + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + row_count = max(0, len(lines) - 1) # exclude separator + col_count = lines[0].count("|") - 1 if lines else 0 + + digest_tables.append({ + "index": idx, + "page_number": table.get("page_number"), + "markdown": markdown, + "row_count": row_count, + "col_count": col_count, + "cells": cells, + }) + + if not digest_tables: + return None + + return { + "version": 1, + "table_count": len(digest_tables), + "tables": digest_tables, + } + + def _annotate_tree_with_table_counts( + self, + node: "TreeNode", + tables: List[Dict[str, Any]], + ) -> None: + """Annotate tree nodes with table count based on page_range overlap. + + For each node with a valid page_range, counts how many extracted + tables fall within that range and sets node.table_count accordingly. + """ + if node is None: + return + if node.page_range: + ps, pe = node.page_range + count = sum( + 1 for t in tables + if t.get("page_number") is not None and ps <= t["page_number"] <= pe + ) + node.table_count = count + for child in node.children: + self._annotate_tree_with_table_counts(child, tables) + @staticmethod def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: """Count total nodes in a DocumentTree (recursive). 
diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 8d93a2c..6895745 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -65,6 +65,7 @@ class TreeNode: level: int = 0 page_range: Optional[Tuple[int, int]] = None children: List["TreeNode"] = field(default_factory=list) + table_count: int = 0 # Number of tables associated with this node's page range def to_dict(self) -> Dict[str, Any]: return { @@ -75,6 +76,7 @@ def to_dict(self) -> Dict[str, Any]: "level": self.level, "page_range": list(self.page_range) if self.page_range else None, "children": [c.to_dict() for c in self.children], + "table_count": self.table_count, } @classmethod @@ -89,6 +91,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": level=data.get("level", 0), page_range=tuple(pr) if pr else None, children=children, + table_count=data.get("table_count", 0), ) @property @@ -600,7 +603,9 @@ async def _select_children( return nodes listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}: {n.summary[:150]}" + f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" + f": {n.summary[:150]}" for i, n in enumerate(nodes) ) prompt = ( diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 74b2b85..9374e27 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2671,6 +2671,18 @@ async def _rga_evidence() -> str: except Exception: pass + # 0.5 Table digest priority (pre-compiled PDF table evidence) + if ev is None and artifacts and artifacts.manifest_map: + _me = artifacts.manifest_map.get(fp) + if _me and getattr(_me, 'has_table_digest', False): + _all_tables = self._load_table_digest( + self.work_path, _me.file_hash, + ) + if _all_tables: + _table_ev = self._format_table_evidence(_all_tables) + if _table_ev: + ev = f"[{fn} - Table Evidence]\n{_table_ev}" + # 1. 
Tree-guided sampling FIRST for tree-indexed files if ( artifacts @@ -3810,7 +3822,7 @@ async def _tree_guided_sample( if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] elif leaf.summary: - logger.debug( + _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) @@ -3884,6 +3896,81 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _load_table_digest( + work_path: Path, file_hash: str, + ) -> Optional[List[Dict[str, Any]]]: + """Load pre-compiled table digest for a file. + + Returns the list of table entries from the digest JSON, or None + if no digest exists or loading fails. + """ + digest_path = ( + work_path / ".cache" / "compile" / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return None + try: + data = json.loads(digest_path.read_text(encoding="utf-8")) + return data.get("tables", []) + except Exception: + return None + + @staticmethod + def _filter_tables_by_page_range( + tables: List[Dict[str, Any]], + page_start: int, + page_end: int, + ) -> List[Dict[str, Any]]: + """Filter tables whose page_number falls within the given range (inclusive).""" + return [ + t for t in tables + if t.get("page_number") is not None + and page_start <= t["page_number"] <= page_end + ] + + @staticmethod + def _format_table_evidence( + tables: List[Dict[str, Any]], + max_chars: int = 3000, + ) -> str: + """Format table digest entries as LLM-friendly evidence text. + + Strategy: + - Small tables (<1000 chars): preserve full Markdown + - Large tables: truncate to max_chars with "(truncated)" note + - Each table prefixed with "[Table from page N]" + + Returns concatenated formatted table evidence string. 
+ """ + if not tables: + return "" + + parts: List[str] = [] + remaining = max_chars + + for table in tables: + if remaining <= 0: + break + + page = table.get("page_number", "?") + markdown = table.get("markdown", "") + + if not markdown: + continue + + header = f"[Table from page {page}]" + + if len(markdown) <= remaining: + parts.append(f"{header}\n{markdown}") + remaining -= len(markdown) + len(header) + 2 + else: + truncated = markdown[:remaining] + parts.append(f"{header}\n{truncated}\n(truncated)") + remaining = 0 + + return "\n\n".join(parts) + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3923,7 +4010,7 @@ async def _navigate_tree_for_evidence( if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] elif leaf.summary: - logger.debug( + _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) @@ -3941,6 +4028,37 @@ async def _navigate_tree_for_evidence( if not parts: return None + # Supplement with table evidence if available + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(file_path) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + if _all_tables and leaves: + _seen_pages: set = set() + for leaf in leaves: + if leaf.page_range: + ps, pe = leaf.page_range + page_key = (ps, pe) + if page_key in _seen_pages: + continue + _seen_pages.add(page_key) + leaf_tables = self._filter_tables_by_page_range( + _all_tables, ps, pe, + ) + if leaf_tables: + table_text = self._format_table_evidence( + leaf_tables, max_chars=2000, + ) + if table_text: + parts.append( + f"[Tables pp.{ps}-{pe}]\n{table_text}" + ) + except Exception: + pass + evidence = "\n\n".join(parts) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " diff --git 
a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index 76e0f15..d72d397 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -303,6 +303,17 @@ def _build_config(profile: ExtractionProfile): from kreuzberg import LanguageDetectionConfig lang_config = LanguageDetectionConfig(enabled=True) + # --- Layout detection for table extraction --- + layout_config = None + if profile.extract_tables: + try: + from kreuzberg import LayoutDetectionConfig + layout_config = LayoutDetectionConfig() + except ImportError: + # kreuzberg <= 4.2.x extracts tables by default; + # filtering is handled in _convert_result(). + pass + # --- Assemble ExtractionConfig --- kwargs: dict[str, Any] = { "output_format": output_format, @@ -319,6 +330,8 @@ def _build_config(profile: ExtractionProfile): kwargs["language_detection"] = lang_config if profile.max_concurrent is not None: kwargs["max_concurrent_extractions"] = profile.max_concurrent + if layout_config is not None: + kwargs["layout"] = layout_config return ExtractionConfig(**kwargs) From 1d550bf9d73e5bc043f8a7ce91b8742c8859ef04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 00:53:53 +0800 Subject: [PATCH 32/70] fix warning --- src/sirchmunk/learnings/compiler.py | 4 ++ src/sirchmunk/utils/document_extractor.py | 72 +++++++++++++++++++++-- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 037070b..6f65e12 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -1194,6 +1194,10 @@ def _build_table_digest( row_count = max(0, len(lines) - 1) # exclude separator col_count = lines[0].count("|") - 1 if lines else 0 + # Skip pseudo-tables: single-column or insufficient structure + if col_count <= 1: + continue + digest_tables.append({ "index": idx, "page_number": table.get("page_number"), diff 
--git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index d72d397..d114b7d 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -143,9 +143,16 @@ class DocumentExtractor: extract_tables=True, extract_metadata=True, pdf_extract_metadata=True, - force_ocr=True, + force_ocr=False, ) - """Rich extraction with tables, metadata, and OCR fallback.""" + """Rich extraction with tables, metadata, and layout-based table detection. + + ``force_ocr`` is disabled because: + - Most documents (e.g. 10-K, 10-Q PDFs) already contain a native text layer. + - kreuzberg automatically falls back to OCR for scanned / image-only pages. + - Forcing OCR triggers Tesseract ObjectCache leak warnings in concurrent use + and significantly slows down compilation with no quality benefit. + """ # Public API ----------------------------------------------------------- @@ -174,7 +181,21 @@ async def extract( try: result = await extract_file(file_path=file_path, config=config) - return DocumentExtractor._convert_result(result, profile) + output = DocumentExtractor._convert_result(result, profile) + # Fallback: kreuzberg 4.9.1 returns page_count=0 when force_ocr=True; + # use pypdf to get the real page count when missing. 
+ if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(file_path) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + return output except Exception as exc: logger.error( "Document extraction failed for {}: {}", @@ -232,15 +253,54 @@ async def batch_extract( try: results = await batch_extract_files(paths=list(file_paths), config=config) - return [ + outputs = [ DocumentExtractor._convert_result(r, profile) for r in results ] + # Apply page_count fallback for each output + fixed: List[ExtractionOutput] = [] + for output, fp in zip(outputs, file_paths): + if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(fp) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + fixed.append(output) + return fixed except Exception: logger.error("Batch extraction failed for {} files", len(file_paths)) raise # Internal helpers ----------------------------------------------------- + @staticmethod + def _fallback_page_count( + file_path: Union[str, Path], + ) -> Optional[int]: + """Get page count via pypdf when kreuzberg fails to report it. + + kreuzberg >= 4.9.1 returns ``get_page_count() == 0`` when + ``force_ocr=True`` is set. This fallback uses pypdf (already a + transitive dependency) for a lightweight page-count-only read. + + Returns: + Page count, or None for non-PDF files or on error. 
+ """ + if Path(file_path).suffix.lower() != ".pdf": + return None + try: + from pypdf import PdfReader + reader = PdfReader(str(file_path)) + count = len(reader.pages) + return count if count > 0 else None + except Exception: + return None + @staticmethod def _build_config(profile: ExtractionProfile): """Build a kreuzberg ``ExtractionConfig`` from an :class:`ExtractionProfile`. @@ -308,9 +368,11 @@ def _build_config(profile: ExtractionProfile): if profile.extract_tables: try: from kreuzberg import LayoutDetectionConfig + # kreuzberg >= 4.5.0: model-based table detection (RT-DETR v2) + # Default: table_model="tatr", apply_heuristics=True layout_config = LayoutDetectionConfig() except ImportError: - # kreuzberg <= 4.2.x extracts tables by default; + # kreuzberg < 4.5.0: tables extracted via heuristics only; # filtering is handled in _convert_result(). pass From 384d345d8c7f21e1ca435d3143e620f46b71ad79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 02:02:52 +0800 Subject: [PATCH 33/70] enhance compiler --- src/sirchmunk/learnings/compiler.py | 175 ++++++++++++++- src/sirchmunk/learnings/tree_indexer.py | 282 ++++++++++++++++++++++-- src/sirchmunk/search.py | 50 +++-- 3 files changed, 458 insertions(+), 49 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 6f65e12..531cb28 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -638,9 +638,12 @@ async def _compile_single_file( except Exception: pass - # Annotate tree nodes with table counts for navigation hints + # Integrate tables into tree: annotate counts + create table child nodes if result.tree and result.tree.root and extraction.tables: - self._annotate_tree_with_table_counts(result.tree.root, extraction.tables) + self._integrate_tables_into_tree( + result.tree.root, extraction.tables, + content=content, total_pages=extraction.page_count, + ) except Exception as exc: result.error = 
str(exc) @@ -1216,27 +1219,175 @@ def _build_table_digest( "tables": digest_tables, } - def _annotate_tree_with_table_counts( + def _integrate_tables_into_tree( self, node: "TreeNode", tables: List[Dict[str, Any]], + content: str, + *, + total_pages: Optional[int] = None, + _counter: Optional[List[int]] = None, ) -> None: - """Annotate tree nodes with table count based on page_range overlap. + """Integrate tables into tree: annotate counts AND create table child nodes for leaf nodes. - For each node with a valid page_range, counts how many extracted - tables fall within that range and sets node.table_count accordingly. + For each node with a valid page_range, counts how many valid extracted + tables fall within that range (excluding pseudo-tables with col_count <= 1). + For leaf nodes with matching tables, creates dedicated TreeNode children + with ``content_type="table"``. """ + from sirchmunk.learnings.tree_indexer import TreeNode + if node is None: return + + if _counter is None: + _counter = [0] + + # Depth-first: process existing children first + for child in list(node.children): + self._integrate_tables_into_tree( + child, tables, content, + total_pages=total_pages, _counter=_counter, + ) + + # Match valid tables to this node's page_range + matched_tables: List[Dict[str, Any]] = [] if node.page_range: ps, pe = node.page_range - count = sum( - 1 for t in tables - if t.get("page_number") is not None and ps <= t["page_number"] <= pe + for t in tables: + pn = t.get("page_number") + if pn is None or not (ps <= pn <= pe): + continue + # Skip pseudo-tables + if self._is_pseudo_table(t): + continue + matched_tables.append(t) + + node.table_count = len(matched_tables) + + # Create table child nodes only for leaf nodes with matched tables + if not node.children and matched_tables: + try: + self._spawn_table_children( + node, matched_tables, content, _counter, + ) + except Exception: + pass # Never break compile for table node creation + + @staticmethod + def 
_is_pseudo_table(table: Dict[str, Any]) -> bool: + """Return True if the table lacks meaningful structure (col_count <= 1).""" + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + return True + col_count = 0 + if cells: + col_count = max( + (len(row) for row in cells if isinstance(row, (list, tuple))), + default=0, ) - node.table_count = count - for child in node.children: - self._annotate_tree_with_table_counts(child, tables) + elif markdown: + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + col_count = (lines[0].count("|") - 1) if lines else 0 + return col_count <= 1 + + def _spawn_table_children( + self, + node: "TreeNode", + matched_tables: List[Dict[str, Any]], + content: str, + counter: List[int], + ) -> None: + """Create TreeNode children for each matched table under a leaf node. + + Also inserts a text-content sibling preserving the original leaf content. + """ + from sirchmunk.learnings.tree_indexer import TreeNode + + child_level = node.level + 1 + + # Preserve original text content as first child + text_child_id = f"T{counter[0]:06d}" + counter[0] += 1 + node.children.append( + TreeNode( + node_id=text_child_id, + title=node.title, + summary=node.summary[:300] if node.summary else "", + char_range=node.char_range, + level=child_level, + page_range=node.page_range, + children=[], + table_count=0, + content_type="text", + ) + ) + + # Create one child per table + for table in matched_tables: + tid = f"T{counter[0]:06d}" + counter[0] += 1 + + markdown = table.get("markdown", "") + title = self._extract_table_title(table) + page_number = table.get("page_number") + + # Attempt to locate table markdown in content + char_range = node.char_range + if markdown and content: + pos = content.find(markdown[:120]) + if pos >= 0: + char_range = (pos, pos + len(markdown)) + + page_range = ( + (page_number, page_number) if page_number is not None + else node.page_range + ) + + 
node.children.append( + TreeNode( + node_id=tid, + title=title, + summary=markdown[:300] if markdown else "", + char_range=char_range, + level=child_level, + page_range=page_range, + children=[], + table_count=0, + content_type="table", + ) + ) + + @staticmethod + def _extract_table_title(table: Dict[str, Any]) -> str: + """Extract a concise title from table markdown header row. + + Parses the first meaningful line of the markdown table (skipping + separator rows like ``|---|---|``), strips ``|`` delimiters, and + returns the first 80 characters as the title. + """ + markdown = table.get("markdown", "") + if not markdown: + pn = table.get("page_number", "?") + return f"Table (p.{pn})" + + for line in markdown.strip().split("\n"): + stripped = line.strip() + if not stripped: + continue + # Skip separator rows (e.g. |---|---| or +---+---+) + content_chars = stripped.replace("|", "").replace("-", "").replace(":", "").replace("+", "").strip() + if not content_chars: + continue + # Extract cell contents + title = " | ".join( + seg.strip() for seg in stripped.split("|") if seg.strip() + ) + return title[:80] if title else f"Table (p.{table.get('page_number', '?')})" + + pn = table.get("page_number", "?") + return f"Table (p.{pn})" @staticmethod def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 6895745..060c9b8 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -66,6 +66,7 @@ class TreeNode: page_range: Optional[Tuple[int, int]] = None children: List["TreeNode"] = field(default_factory=list) table_count: int = 0 # Number of tables associated with this node's page range + content_type: str = "text" # "text" | "table" def to_dict(self) -> Dict[str, Any]: return { @@ -77,6 +78,7 @@ def to_dict(self) -> Dict[str, Any]: "page_range": list(self.page_range) if self.page_range else None, "children": [c.to_dict() for c in 
self.children], "table_count": self.table_count, + "content_type": self.content_type, } @classmethod @@ -92,6 +94,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": page_range=tuple(pr) if pr else None, children=children, table_count=data.get("table_count", 0), + content_type=data.get("content_type", "text"), ) @property @@ -214,6 +217,8 @@ async def build_tree( toc_entries, content, total_pages=total_pages, ) if root is not None: + await self._deepen_large_leaves(root, content, max_depth=effective_depth) + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -233,6 +238,9 @@ async def build_tree( if root is None: return None + await self._deepen_large_leaves(root, content, max_depth=effective_depth) + await self._enrich_node_summaries(root, content) + tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -254,36 +262,48 @@ async def navigate( query: str, *, max_results: int = 3, + max_nav_depth: int = 4, ) -> List[TreeNode]: """Reasoning-based tree navigation: LLM selects the most relevant branches. + Iteratively descends through the tree until leaf nodes are reached or + *max_nav_depth* selection rounds are exhausted. + Returns up to *max_results* leaf nodes with their char_range for precise evidence extraction. 
""" if tree.root is None: return [] - candidates = tree.root.children if tree.root.children else [tree.root] - if not candidates: + current = tree.root.children if tree.root.children else [tree.root] + if not current: return [tree.root] - selected = await self._select_children(candidates, query) - if not selected: - return [] - - result_leaves: List[TreeNode] = [] - for node in selected: - if node.leaf: - result_leaves.append(node) - else: - deeper = await self._select_children(node.children, query) - for d in (deeper or node.children[:1]): - result_leaves.extend(d.all_leaves()[:max_results]) + selected: List[TreeNode] = current + for _ in range(max_nav_depth): + selected = await self._select_children(current, query) + if not selected: + break + # All leaves — stop descending + if all(n.leaf for n in selected): + break + # Expand non-leaf children, keep leaves as-is + next_level: List[TreeNode] = [] + for n in selected: + if n.leaf: + next_level.append(n) + else: + next_level.extend(n.children) + if not next_level: + break + current = next_level + else: + selected = current # Deduplicate and cap - seen_ids = set() + seen_ids: set = set() unique: List[TreeNode] = [] - for n in result_leaves: + for n in (selected or current): if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) @@ -604,12 +624,17 @@ async def _select_children( listing = "\n".join( f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f" [{n.content_type.upper()}]" f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" + f": {n.summary[:200]}" for i, n in enumerate(nodes) ) prompt = ( f"Given the query: \"{query}\"\n\n" + "Guidelines:\n" + "- For numerical/financial data queries, prefer TABLE nodes and consolidated statements\n" + "- Prefer company-wide/consolidated data over segment-level unless query specifies a segment\n" + "- When multiple tables exist, select the one most directly answering the query\n\n" f"Select the 1-2 most 
relevant sections (by index number):\n{listing}\n\n" f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" ) @@ -681,6 +706,229 @@ def _format_page_range( ps, pe = page_range return f" [pages {ps}-{pe}]" if ps != pe else f" [page {ps}]" + # ------------------------------------------------------------------ # + # Leaf deepening & summary enrichment # + # ------------------------------------------------------------------ # + + async def _deepen_large_leaves( + self, + node: TreeNode, + content: str, + *, + max_leaf_chars: int = 5000, + max_depth: int = 4, + _seen_ids: Optional[set] = None, + ) -> None: + """Recursively deepen leaf nodes whose char_range exceeds *max_leaf_chars* using LLM decomposition.""" + if _seen_ids is None: + _seen_ids = self._collect_node_ids(node) + + if not node.leaf: + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + return + + start, end = node.char_range + span = end - start + if span <= max_leaf_chars or node.level >= max_depth: + return + + snippet = self._truncate_snippet(content[start:end]) + + prompt = ( + "Analyze this document section and identify 3-8 logical sub-sections.\n" + "For each sub-section, provide:\n" + '- "title": descriptive heading (concise)\n' + '- "start_text": the first 8-15 words that mark where this sub-section ' + "begins (must be exact text from the content)\n" + '- "content_type": "text" or "table"\n\n' + f'Section: "{node.title}"\n---\n{snippet}\n---\n\n' + 'Return ONLY a JSON array, e.g. 
' + '[{"title": "...", "start_text": "...", "content_type": "text"}, ...]' + ) + + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sub_sections = self._parse_json_array(resp.content) + if not sub_sections or len(sub_sections) < 2: + return + except Exception: + return + + sub_nodes = self._build_sub_nodes_from_llm( + sub_sections, node, content, _seen_ids, + ) + if not sub_nodes: + return + + node.children = sub_nodes + await self._log.info( + f"[TreeIndexer] Deepened '{node.title}' into {len(sub_nodes)} sub-nodes" + ) + + # Recurse into newly created children + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + + def _build_sub_nodes_from_llm( + self, + sub_sections: List[Dict[str, Any]], + parent: TreeNode, + content: str, + seen_ids: set, + ) -> List[TreeNode]: + """Create child TreeNodes from LLM-decomposed sub-sections.""" + parent_start, parent_end = parent.char_range + parent_span = max(parent_end - parent_start, 1) + parent_ps, parent_pe = parent.page_range if parent.page_range else (0, 0) + page_span = parent_pe - parent_ps + child_level = parent.level + 1 + + # Resolve char_start for each sub-section + positions: List[int] = [] + search_from = parent_start + for sec in sub_sections: + start_text = sec.get("start_text", "") + pos = content.find(start_text, search_from) if start_text else -1 + if pos < 0 or pos >= parent_end: + pos = search_from + positions.append(pos) + search_from = pos + 1 + + nodes: List[TreeNode] = [] + for i, sec in enumerate(sub_sections): + char_start = positions[i] + char_end = positions[i + 1] if i + 1 < len(positions) else parent_end + + # Estimate page_range proportionally from parent + page_range = None + if parent.page_range and parent_span > 0: + p_start = parent_ps + (char_start - parent_start) / parent_span * page_span + p_end = parent_ps + (char_end - parent_start) / parent_span * 
page_span + page_range = (int(p_start), max(int(p_start), int(p_end))) + + content_type = sec.get("content_type", "text") + if content_type not in ("text", "table"): + content_type = "text" + + nodes.append(TreeNode( + node_id=self._unique_node_id(char_start, seen_ids), + title=sec.get("title", f"Sub-section {i + 1}"), + summary="", + char_range=(char_start, char_end), + level=child_level, + page_range=page_range, + content_type=content_type, + )) + return nodes + + async def _enrich_node_summaries( + self, + node: TreeNode, + content: str, + *, + max_summary_len: int = 200, + ) -> None: + """Post-order traversal to enrich empty summaries: leaf from content, non-leaf via LLM.""" + # Post-order: process children first + for child in node.children: + await self._enrich_node_summaries( + child, content, max_summary_len=max_summary_len, + ) + + if self._summary_needs_enrichment(node.summary): + if node.leaf: + node.summary = self._extract_leaf_summary( + content, node.char_range, max_summary_len, + ) + else: + node.summary = await self._generate_nonleaf_summary( + node, max_summary_len, + ) + + @staticmethod + def _summary_needs_enrichment(summary: str) -> bool: + """Check whether a summary is empty or too short to be useful.""" + return not summary or len(summary.strip()) < 10 + + @staticmethod + def _extract_leaf_summary( + content: str, + char_range: Tuple[int, int], + max_len: int, + ) -> str: + """Extract a concise summary for a leaf node from its content slice.""" + start, end = char_range + raw = content[start:end][:500] + # Clean to single line + return " ".join(raw.split())[:max_len] + + async def _generate_nonleaf_summary( + self, + node: TreeNode, + max_summary_len: int, + ) -> str: + """Generate a summary for a non-leaf node via LLM, with fallback.""" + children_listing = "\n".join( + f"- {c.title}: {c.summary[:100]}" for c in node.children + ) + prompt = ( + "Summarize this document section in 1-2 concise sentences.\n" + f'Section: "{node.title}"\n' + 
@staticmethod
def _parse_json_array(raw: str) -> List[Dict[str, Any]]:
    """Extract the first JSON array of objects embedded in LLM output.

    The text is scanned for ``[`` characters and a JSON value is decoded at
    each one until a list containing at least one object is found.  This
    tolerates Markdown code fences and bracketed prose around the payload,
    which a single greedy ``[...]`` match would swallow into invalid JSON.

    Args:
        raw: Raw model response text.

    Returns:
        The dict items of the first decodable array that contains any;
        non-dict items in that array are discarded.  An empty list when no
        such array exists (including empty or malformed input).
    """
    if not raw:
        return []

    decoder = json.JSONDecoder()
    pos = raw.find("[")
    while pos != -1:
        try:
            parsed, _ = decoder.raw_decode(raw, pos)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            objects = [item for item in parsed if isinstance(item, dict)]
            if objects:
                return objects
        # Not a usable array at this bracket; try the next one.
        pos = raw.find("[", pos + 1)
    return []
and full_text: - segment = full_text[start:end] - elif leaf.summary: - _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) + # Table nodes: prefer summary (contains table markdown) + if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: segment = leaf.summary else: - continue + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and full_text: + segment = full_text[start:end] + elif leaf.summary: + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary + else: + continue segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue @@ -3836,7 +3840,8 @@ async def _tree_guided_sample( if leaf.page_range: ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - header = f"[{fname} → {leaf.title}{page_info}]" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > max_chars: remaining = max_chars - total_chars @@ -4006,23 +4011,28 @@ async def _navigate_tree_for_evidence( full_text = "" for leaf in leaves: - start, end = leaf.char_range - if self._is_valid_char_range(start, end, len(full_text)) and full_text: - segment = full_text[start:end] - elif leaf.summary: - _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) + # Table nodes: prefer summary (contains table markdown) + if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: segment = leaf.summary else: - continue + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and 
full_text: + segment = full_text[start:end] + elif leaf.summary: + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary + else: + continue if segment.strip(): page_info = "" if leaf.page_range: ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - header = f"[{fname} → {leaf.title}{page_info}]" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" parts.append(f"{header}\n{segment[:3000]}") if not parts: From 579f8d6c64514d6f0b21d29a1ac304299111c275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 02:33:44 +0800 Subject: [PATCH 34/70] fix robust issue --- src/sirchmunk/learnings/compiler.py | 18 ++++++++++-------- src/sirchmunk/learnings/tree_indexer.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 531cb28..df2ca12 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -1265,14 +1265,16 @@ def _integrate_tables_into_tree( node.table_count = len(matched_tables) - # Create table child nodes only for leaf nodes with matched tables - if not node.children and matched_tables: - try: - self._spawn_table_children( - node, matched_tables, content, _counter, - ) - except Exception: - pass # Never break compile for table node creation + # NOTE: _spawn_table_children disabled - converting leaf to non-leaf breaks + # search navigation which expects leaves for char_range extraction. + # TODO: Re-enable when search can properly handle mixed text+table children. 
+ # if not node.children and matched_tables: + # try: + # self._spawn_table_children( + # node, matched_tables, content, _counter, + # ) + # except Exception: + # pass @staticmethod def _is_pseudo_table(table: Dict[str, Any]) -> bool: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 060c9b8..eb56f6e 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -217,7 +217,10 @@ async def build_tree( toc_entries, content, total_pages=total_pages, ) if root is not None: - await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. + # TODO: Re-enable when robust char_range calculation is implemented. + # await self._deepen_large_leaves(root, content, max_depth=effective_depth) await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -238,7 +241,10 @@ async def build_tree( if root is None: return None - await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. + # TODO: Re-enable when robust char_range calculation is implemented. 
+ # await self._deepen_large_leaves(root, content, max_depth=effective_depth) await self._enrich_node_summaries(root, content) tree = DocumentTree( From 78c11170ba6ed96b361529baba1f89495905aa13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 15:56:50 +0800 Subject: [PATCH 35/70] fix pure tree search env --- benchmarks/financebench/run_benchmark.py | 5 ++ src/sirchmunk/learnings/tree_indexer.py | 64 +++++++++--------------- src/sirchmunk/search.py | 2 + 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index cf7b30a..65af87d 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -27,6 +27,8 @@ from pathlib import Path from typing import List +from dotenv import load_dotenv + from config import FinanceBenchConfig from data_loader import FinanceBenchLoader from evaluate import compute_metrics @@ -158,6 +160,9 @@ def main() -> None: ) args = parser.parse_args() + # Load .env into os.environ so SIRCHMUNK_* variables are visible globally + load_dotenv(args.env, override=True) + # 1. Load config cfg = FinanceBenchConfig.from_env(args.env) if args.limit is not None: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index eb56f6e..e1a652f 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -221,7 +221,9 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - await self._enrich_node_summaries(root, content) + # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. + # The summaries may inadvertently bias _select_children() navigation. 
+ # await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -245,7 +247,9 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - await self._enrich_node_summaries(root, content) + # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. + # The summaries may inadvertently bias _select_children() navigation. + # await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -268,48 +272,32 @@ async def navigate( query: str, *, max_results: int = 3, - max_nav_depth: int = 4, ) -> List[TreeNode]: - """Reasoning-based tree navigation: LLM selects the most relevant branches. - - Iteratively descends through the tree until leaf nodes are reached or - *max_nav_depth* selection rounds are exhausted. - - Returns up to *max_results* leaf nodes with their char_range for - precise evidence extraction. 
- """ + """LLM-driven branch selection using _select_children().""" if tree.root is None: return [] - current = tree.root.children if tree.root.children else [tree.root] - if not current: + candidates = tree.root.children if tree.root.children else [tree.root] + if not candidates: return [tree.root] - selected: List[TreeNode] = current - for _ in range(max_nav_depth): - selected = await self._select_children(current, query) - if not selected: - break - # All leaves — stop descending - if all(n.leaf for n in selected): - break - # Expand non-leaf children, keep leaves as-is - next_level: List[TreeNode] = [] - for n in selected: - if n.leaf: - next_level.append(n) - else: - next_level.extend(n.children) - if not next_level: - break - current = next_level - else: - selected = current + selected = await self._select_children(candidates, query) + if not selected: + return [] + + result_leaves: List[TreeNode] = [] + for node in selected: + if node.leaf: + result_leaves.append(node) + else: + deeper = await self._select_children(node.children, query) + for d in (deeper or node.children[:1]): + result_leaves.extend(d.all_leaves()[:max_results]) # Deduplicate and cap seen_ids: set = set() unique: List[TreeNode] = [] - for n in (selected or current): + for n in result_leaves: if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) @@ -630,17 +618,13 @@ async def _select_children( listing = "\n".join( f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f" [{n.content_type.upper()}]" f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:200]}" + f": {n.summary[:150]}" for i, n in enumerate(nodes) ) + prompt = ( f"Given the query: \"{query}\"\n\n" - "Guidelines:\n" - "- For numerical/financial data queries, prefer TABLE nodes and consolidated statements\n" - "- Prefer company-wide/consolidated data over segment-level unless query specifies a segment\n" - "- When multiple tables exist, select the one most directly 
answering the query\n\n" f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" ) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 42424ab..5e497f2 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1453,6 +1453,8 @@ async def search( return ctx return msg + await self._logger.info(f"[SearchConfig] PURE_TREE_SEARCH={'enabled' if _PURE_TREE_SEARCH else 'disabled'}") + # ---- Chat intent short-circuit (rule-based, no LLM cost) ---- if mode != "FILENAME_ONLY" and self._is_chat_query(query): answer, cluster, ctx = await self._respond_chat(query, chat_history=chat_history) From 86d528ed208290501d730d3059c6a0ea4df94768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 20:32:02 +0800 Subject: [PATCH 36/70] improve tree index --- src/sirchmunk/learnings/compiler.py | 290 +++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 480 +++++++++++++++++++++- src/sirchmunk/search.py | 248 +++++++++-- src/sirchmunk/utils/document_extractor.py | 91 ++++ 4 files changed, 1048 insertions(+), 61 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index df2ca12..92dba7f 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -12,6 +12,7 @@ import math import os import random +import re import hashlib from dataclasses import dataclass, field from datetime import datetime, timezone @@ -51,6 +52,17 @@ _SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs _SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section +# Targeted table extraction: max chars per table region +_TARGETED_TABLE_MAX_CHARS = 5000 + +# Targeted table extraction: only process nodes spanning <= N pages +_TABLE_PAGE_SPAN_LIMIT = 5 + +# Numeric density threshold – fraction of numeric/symbol chars ($, %, digits, +# parenthesised numbers) relative to total 
non-whitespace chars. Pages below +# this threshold are skipped during targeted extraction. +_TABLE_NUMERIC_DENSITY_THRESHOLD = 0.15 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -645,6 +657,51 @@ async def _compile_single_file( content=content, total_pages=extraction.page_count, ) + # Phase 2.5: Targeted table extraction via generic structural signals + if result.tree and result.tree.root and ext == ".pdf": + targeted_tables = await self._targeted_table_extraction( + entry.path, result.tree, + ) + if targeted_tables: + # Load existing table digest (if any) and merge + digest_dir = self._compile_dir / "table_digests" + file_hash = get_fast_hash(entry.path) or "" + existing_digest: list[dict] = [] + if file_hash and result.has_table_digest: + digest_path = digest_dir / f"{file_hash}.json" + if digest_path.exists(): + try: + raw = json.loads( + digest_path.read_text(encoding="utf-8") + ) + existing_digest = raw.get("tables", []) + except Exception: + pass + merged = self._merge_table_digests( + existing_digest, targeted_tables, + ) + if merged and file_hash: + digest_dir.mkdir(parents=True, exist_ok=True) + digest_path = digest_dir / f"{file_hash}.json" + digest_path.write_text( + json.dumps( + { + "version": 1, + "table_count": len(merged), + "tables": merged, + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(merged) + await self._log.info( + f"[Compile] Targeted table extraction added " + f"{len(targeted_tables)} tables for " + f"{Path(entry.path).name}" + ) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1409,6 +1466,239 @@ def _count(node: Any) -> int: return _count(tree.root) + # ------------------------------------------------------------------ # + # Targeted table extraction # + # 
------------------------------------------------------------------ # + + async def _targeted_table_extraction( + self, file_path: str, tree: DocumentTree, + ) -> list[dict]: + """Extract tables from tree nodes likely containing tabular data. + + Uses generic structural signals (metadata, page span, numeric + density) instead of domain-specific title keywords. For each + candidate with a valid ``page_range``, extracts per-page text + via :meth:`DocumentExtractor.extract_page_range` and applies + heuristic table-region detection. Pages whose numeric density + falls below ``_TABLE_NUMERIC_DENSITY_THRESHOLD`` are skipped. + + Returns: + List of table dicts compatible with the table-digest format:: + + {"page": int, "content": str, "source": str} + """ + if tree is None or tree.root is None: + return [] + + candidates = self._find_table_candidate_nodes(tree.root) + if not candidates: + return [] + + await self._log.info( + f"[Compile] Targeted extraction: {len(candidates)} candidate " + f"nodes in {Path(file_path).name}" + ) + + results: list[dict] = [] + seen_pages: set[int] = set() + + for node in candidates: + if node.page_range is None: + continue + start_page, end_page = node.page_range + # Skip pages already processed by another candidate + page_nums = [p for p in range(start_page, end_page + 1) + if p not in seen_pages] + if not page_nums: + continue + + try: + pages = DocumentExtractor.extract_page_range( + file_path, start_page, end_page, + ) + except Exception as exc: + await self._log.warning( + f"[Compile] Targeted extraction page read failed " + f"({start_page}-{end_page}): {exc}" + ) + continue + + for pc in pages: + if pc.page_number in seen_pages: + continue + seen_pages.add(pc.page_number) + # Numeric density gate – skip pages unlikely to contain tables + if not self._page_has_table_density(pc.content): + continue + regions = self._identify_table_regions(pc.content) + for region in regions: + truncated = region[:_TARGETED_TABLE_MAX_CHARS] + 
results.append({ + "page": pc.page_number, + "content": truncated, + "source": f"targeted:{node.title[:80]}", + }) + + return results + + def _find_table_candidate_nodes( + self, root: "TreeNode", + ) -> list["TreeNode"]: + """Collect leaf nodes that likely contain tables. + + Uses generic, domain-agnostic structural signals (any match + suffices): + + - ``node.content_type == "table"`` – already tagged during compile. + - ``node.table_count > 0`` – known to contain tables. + - Has a valid ``page_range`` with span ≤ ``_TABLE_PAGE_SPAN_LIMIT``. + """ + candidates: list = [] + + def _walk(node: "TreeNode") -> None: + if node.leaf: + # Signal 1: content_type marked as table + if getattr(node, "content_type", None) == "table": + candidates.append(node) + return + # Signal 2: known to contain tables + if getattr(node, "table_count", 0) > 0: + candidates.append(node) + return + # Signal 3: moderate page span (tables rarely span many pages) + page_range = getattr(node, "page_range", None) + if page_range and len(page_range) == 2: + span = page_range[1] - page_range[0] + 1 + if 1 <= span <= _TABLE_PAGE_SPAN_LIMIT: + candidates.append(node) + else: + for child in node.children: + _walk(child) + + _walk(root) + return candidates + + @staticmethod + def _page_has_table_density(page_text: str) -> bool: + """Return True if *page_text* has numeric density above the threshold. + + Counts digits and common table symbols (``$``, ``%``, ``(``, ``)``) + relative to total non-whitespace characters. + """ + if not page_text: + return False + non_ws = sum(1 for ch in page_text if not ch.isspace()) + if non_ws == 0: + return False + numeric_chars = sum( + 1 for ch in page_text + if ch.isdigit() or ch in "$%(),.+-" + ) + return (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD + + @staticmethod + def _identify_table_regions(page_text: str) -> list[str]: + """Identify contiguous table-like regions in *page_text*. 
+ + Heuristic rules: + - Lines containing multiple numeric tokens (dollar amounts, %, + parenthesised negatives) are considered *numeric rows*. + - A run of >= 3 consecutive numeric rows forms a table region. + - Leading/trailing whitespace rows are trimmed. + + Returns: + List of extracted region strings (may be empty). + """ + if not page_text: + return [] + + # Pattern: line has at least 2 numeric-looking tokens + _NUM_TOKEN = re.compile( + r"(?:" + r"[\$€£¥]\s*[\d,.]+|" + r"\([\d,.]+\)|" + r"[\d,.]+%|" + r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" + r"[\d,]{2,}" + r")" + ) + _MIN_NUMS_PER_LINE = 2 + _MIN_CONSECUTIVE = 3 + + lines = page_text.split("\n") + is_numeric = [ + len(_NUM_TOKEN.findall(line)) >= _MIN_NUMS_PER_LINE + for line in lines + ] + + regions: list[str] = [] + run_start: int | None = None + + for i, flag in enumerate(is_numeric): + if flag: + if run_start is None: + run_start = i + else: + if run_start is not None: + run_len = i - run_start + if run_len >= _MIN_CONSECUTIVE: + # Include one context line above/below + start = max(0, run_start - 1) + end = min(len(lines), i + 1) + regions.append( + "\n".join(lines[start:end]).strip() + ) + run_start = None + + # Flush trailing run + if run_start is not None: + run_len = len(lines) - run_start + if run_len >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + regions.append( + "\n".join(lines[start:]).strip() + ) + + return regions + + @staticmethod + def _get_table_page(entry: dict) -> int | None: + """统一获取表格条目的页码,兼容 page_number 和 page 两种字段名。""" + p = entry.get("page_number") or entry.get("page") + return int(p) if p is not None else None + + @classmethod + def _merge_table_digests( + cls, existing: list[dict], new_tables: list[dict], + ) -> list[dict]: + """Merge *new_tables* into *existing* digest, deduplicating by page. 
+ + If an existing entry and a new entry share the same page number, + the new entry is skipped (existing kreuzberg-detected table takes + precedence because it has richer structure like cells/markdown). + + Returns: + Merged list suitable for storage in the table-digest JSON. + """ + existing_pages = {cls._get_table_page(e) for e in existing} + existing_pages.discard(None) + + merged = list(existing) + for tbl in new_tables: + page = cls._get_table_page(tbl) + if page is not None and page in existing_pages: + continue + # Normalise to digest table format for consistency + merged.append({ + "page_number": page, + "markdown": tbl.get("content", ""), + "row_count": None, + "col_count": None, + "cells": [], + "source": tbl.get("source", "targeted"), + }) + return merged + # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index e1a652f..96c44b9 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -8,7 +8,9 @@ """ import json +import math import re +from collections import Counter from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -153,6 +155,13 @@ def from_json(cls, json_str: str) -> "DocumentTree": class DocumentTreeIndexer: """Build and cache PageIndex-style hierarchical tree indices for documents.""" + # Maximum child nodes before switching to paginated LLM selection. + # Balance: lower = more LLM calls, higher = more tokens per call. + _PAGE_SIZE_THRESHOLD: int = 15 + + # Number of nodes per group in paginated selection. 
+ _GROUP_PAGE_SIZE: int = 15 + def __init__( self, llm: OpenAIChat, @@ -272,8 +281,23 @@ async def navigate( query: str, *, max_results: int = 3, + max_depth: int = 4, ) -> List[TreeNode]: - """LLM-driven branch selection using _select_children().""" + """Adaptive-depth LLM-driven tree navigation. + + Iteratively descends the tree using _select_children() at each level, + collecting leaf nodes until *max_results* are found or *max_depth* is + reached. + + Args: + tree: DocumentTree with a root node. + query: Search query for relevance selection. + max_results: Maximum number of leaf nodes to return. + max_depth: Maximum descent depth (default 4). + + Returns: + List of the most relevant leaf TreeNodes. + """ if tree.root is None: return [] @@ -281,18 +305,41 @@ async def navigate( if not candidates: return [tree.root] - selected = await self._select_children(candidates, query) - if not selected: - return [] - result_leaves: List[TreeNode] = [] - for node in selected: - if node.leaf: - result_leaves.append(node) - else: - deeper = await self._select_children(node.children, query) - for d in (deeper or node.children[:1]): - result_leaves.extend(d.all_leaves()[:max_results]) + visited: set = set() # prevent cycles + frontier = candidates + selected: List[TreeNode] = [] + + depth = 0 + while depth < max_depth and frontier: + selected = await self._select_children( + frontier, query, max_selections=max_results, + ) + if not selected: + break + + next_frontier: List[TreeNode] = [] + for node in selected: + node_id = id(node) + if node_id in visited: + continue + visited.add(node_id) + + if node.leaf or not node.children: + result_leaves.append(node) + else: + next_frontier.extend(node.children) + + if len(result_leaves) >= max_results: + break + + frontier = next_frontier + depth += 1 + + # Fallback: if no leaves found, expand last selected nodes + if not result_leaves and selected: + for node in selected: + result_leaves.extend(node.all_leaves()[:max_results]) # 
async def _select_children_paginated(
    self,
    nodes: "List[TreeNode]",
    query: str,
    *,
    page_size: int = 15,
    max_selections: int = 3,
) -> "List[TreeNode]":
    """Two-phase paginated selection for large node sets.

    Phase 1 partitions *nodes* into sequential groups, shows the LLM one
    line per group (first title ... last title) and asks it to pick groups.
    Phase 2 runs ``_select_from_group`` inside every picked group.

    Any failure of the group-level LLM call or of parsing its reply
    (including errors raised by the client itself) falls back to the first
    group, so navigation never aborts here.

    Args:
        nodes: Candidate nodes, in document order.
        query: Search query shown to the LLM.
        page_size: Requested nodes per group; values below
            ``_GROUP_PAGE_SIZE`` are raised to that floor.
        max_selections: Upper bound on the number of nodes returned.

    Returns:
        Up to *max_selections* nodes, de-duplicated by ``node_id``.  Falls
        back to the first *max_selections* nodes if nothing was selected.
    """
    page_size = max(page_size, self._GROUP_PAGE_SIZE)

    # --- Phase 0: build groups ---
    groups = [
        nodes[start:start + page_size]
        for start in range(0, len(nodes), page_size)
    ]

    if len(groups) <= 1:
        # Only one group: skip directly to fine-grained selection.
        return await self._select_from_group(nodes, query, max_selections)

    # --- Phase 1: group-level selection ---
    group_listing = "\n".join(
        f"[{i}] {g[0].title} ... {g[-1].title} ({len(g)} sections)"
        for i, g in enumerate(groups)
    )
    group_prompt = (
        f"Given the query: \"{query}\"\n\n"
        f"The document has {len(nodes)} sections organized into "
        f"{len(groups)} groups.\n"
        f"Select the 1-2 most relevant groups (by index number):\n"
        f"{group_listing}\n\n"
        f"Return ONLY a JSON array of group index numbers, e.g. [0, 2]"
    )

    chosen = []  # distinct, in-range group indices in the order returned
    try:
        resp = await self._llm.achat(
            [{"role": "user", "content": group_prompt}],
        )
        m = re.search(r"\[[\d\s,]+\]", resp.content.strip())
        if m:
            for idx in json.loads(m.group()):
                if isinstance(idx, int) and 0 <= idx < len(groups) and idx not in chosen:
                    chosen.append(idx)
    except Exception as exc:
        # Best-effort by contract: a failed call degrades to the fallback.
        chosen = []
        await self._log.info(
            f"[TreeIndexer] Group selection failed, using first group: {exc}"
        )

    # Fallback: take the first group.
    selected_groups = [groups[i] for i in chosen] or [groups[0]]

    # --- Phase 2: fine-grained selection within chosen groups ---
    results = []
    for group in selected_groups:
        picked = await self._select_from_group(group, query, max_selections)
        results.extend(picked)

    # Deduplicate by node_id and cap.
    seen = set()
    unique = []
    for n in results:
        if n.node_id not in seen:
            seen.add(n.node_id)
            unique.append(n)
    return unique[:max_selections] if unique else nodes[:max_selections]
group: List[TreeNode], + query: str, + max_selections: int, + ) -> List[TreeNode]: + """Select the most relevant nodes within a single group via LLM.""" + if len(group) <= 2: + return group + + listing = "\n".join( + f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" + f": {n.summary[:150]}" + for i, n in enumerate(group) + ) + prompt = ( + f"Given the query: \"{query}\"\n\n" + f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + selected = [group[i] for i in indices if 0 <= i < len(group)] + if selected: + return selected[:max_selections] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return group[:max_selections] # ------------------------------------------------------------------ # # Cache I/O # @@ -924,3 +1093,282 @@ def should_build_tree(file_path: str, content_length: int) -> bool: """Determine whether a file is eligible for tree indexing.""" ext = Path(file_path).suffix.lower() return ext in _TREE_EXTENSIONS and content_length >= _TREE_MIN_CHARS + + # ------------------------------------------------------------------ # + # Hierarchy inference for flat TOC entries # + # ------------------------------------------------------------------ # + + # Minimum number of TOC entries to trigger hierarchy inference. + # Documents with fewer entries are typically already well-structured. + _FLAT_ENTRY_THRESHOLD = 20 + + # If this fraction of entries share the same level, consider it "flat" + # and apply hierarchy inference. Real hierarchies typically have + # varied level distribution. + _FLAT_LEVEL_RATIO = 0.9 + + # Number of entries per virtual group when using uniform grouping fallback. 
+ _GROUP_SIZE = 15 + + @staticmethod + def _infer_hierarchy(entries: List[Any]) -> List[Any]: + """When all entries share the same level, infer hierarchy from title patterns. + + Applies three strategies in priority order: + A. Keyword groups — detect repeated structural prefixes (generic) + B. Generic numbering patterns (1., 1.1, I., A., etc.) + C. Uniform grouping fallback (virtual parent nodes) + + Only activates when >90% of entries share the same level and + the total count exceeds ``_FLAT_ENTRY_THRESHOLD``. + + Args: + entries: List of TOCEntry (may be nested). + + Returns: + Possibly restructured list of TOCEntry with updated levels + and rebuilt hierarchy. + """ + if not entries: + return entries or [] + + try: + from sirchmunk.learnings.toc_extractor import TOCExtractor + flat: List[Any] = [] + TOCExtractor._flatten_entries(entries, flat) + except Exception: + return entries # Cannot flatten; return original entries + + if not flat: + return entries + + if len(flat) <= DocumentTreeIndexer._FLAT_ENTRY_THRESHOLD: + return entries + + # Validate level field: skip entries with invalid levels + valid_flat = [e for e in flat if hasattr(e, 'level') and isinstance(e.level, (int, float))] + if not valid_flat: + return entries + + # Check if >90% share the same level + level_counts = Counter(e.level for e in valid_flat) + dominant_level, dominant_count = level_counts.most_common(1)[0] + if dominant_count / len(flat) <= DocumentTreeIndexer._FLAT_LEVEL_RATIO: + return entries # Already has meaningful hierarchy + + # Try strategies in priority order + modified = DocumentTreeIndexer._strategy_keyword_groups(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_numbering(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_uniform_grouping( + flat, dominant_level, + ) + if modified is None: + return entries + + # Rebuild hierarchy from the re-leveled flat list + return 
TOCExtractor._build_hierarchy(modified) + + # -- Strategy A: keyword groups (generic structural prefix detection) # + + # Pattern: title starts with a capitalized word optionally followed by + # a Roman numeral or Arabic number (e.g. "PART IV", "Item 1A", + # "Section 3", "Chapter 12", "Article II"). + _RE_STRUCTURAL_PREFIX = re.compile( + r'^([A-Z][A-Za-z]*(?:\s+[IVXLCDM\d]+[A-Za-z]?)?)\b', + ) + + @staticmethod + def _extract_structural_prefix(title: str) -> Optional[str]: + """Extract a structural prefix from a title. + + Matches leading capitalized words optionally followed by a number + or Roman numeral (e.g. "PART IV", "Item 1A", "Section 3"). + Returns the normalized (uppercased) prefix, or None. + """ + if not title or not title.strip(): + return None + m = DocumentTreeIndexer._RE_STRUCTURAL_PREFIX.match(title.strip()) + if m: + prefix = m.group(1).strip() + # Prefix must not be too long (avoid capturing entire title) + if len(prefix) <= 20: + return prefix.upper() + return None + + @staticmethod + def _strategy_keyword_groups( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy A — detect repeated structural prefixes and infer levels. + + Works for any document with repetitive heading patterns (SEC filings, + legal contracts, technical specs, etc.). Automatically discovers + prefix groups and assigns hierarchical levels based on frequency: + lower-frequency prefixes become higher-level parents. + + Returns re-leveled flat list, or None if coverage is insufficient. + """ + # 1. Extract prefix for each entry + prefix_map: Dict[str, List[int]] = {} # prefix -> [entry indices] + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix: + prefix_map.setdefault(prefix, []).append(i) + + # 2. Keep only prefixes appearing >= 2 times + repeated_prefixes = {k: v for k, v in prefix_map.items() if len(v) >= 2} + if not repeated_prefixes: + return None + + # 3. 
Check coverage: at least 30% of entries must be covered + covered = sum(len(indices) for indices in repeated_prefixes.values()) + if covered < len(flat) * 0.3: + return None + + # 4. Sort prefixes by frequency (ascending) then by first appearance + # Low frequency = higher level (parent), high frequency = lower level + sorted_prefixes = sorted( + repeated_prefixes.items(), + key=lambda x: (len(x[1]), min(x[1])), + ) + + # 5. Assign level per prefix group + prefix_to_level: Dict[str, int] = {} + for level_idx, (prefix, _) in enumerate(sorted_prefixes): + prefix_to_level[prefix] = level_idx + 1 + + # 6. Determine the "other" level for entries without a known prefix + max_level = max(prefix_to_level.values()) + 1 + + # 7. Apply levels + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix and prefix in prefix_to_level: + e.level = prefix_to_level[prefix] + else: + e.level = max_level + e.children = [] + + return flat + + # -- Strategy B: generic numbering --------------------------------- # + + # Three-level numbering: 1.1.1, (a), (i), (1) + _RE_NUM_LEVEL3 = re.compile( + r"^\s*(?:\d+\.\d+\.\d+|\([a-z]\)|\([ivx]+\)|\(\d+\))\s", + re.IGNORECASE, + ) + # Two-level numbering: 1.1, A., B., a., b. + _RE_NUM_LEVEL2 = re.compile( + r"^\s*(?:\d+\.\d+(?!\.)\b|[A-Z]\.\s|[a-z]\.\s)", + ) + # Top-level numbering: 1., 2., I., II. + _RE_NUM_LEVEL1 = re.compile( + r"^\s*(?:\d+\.\s|[IVXLC]+\.\s)", + ) + + @staticmethod + def _strategy_numbering( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy B — detect generic numbering patterns. + + Returns re-leveled flat list, or None if fewer than 30% of + entries match any numbering pattern. 
+ """ + matched = 0 + assignments: List[Optional[int]] = [] + + for e in flat: + title = e.title + if DocumentTreeIndexer._RE_NUM_LEVEL3.match(title): + assignments.append(3) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL2.match(title): + assignments.append(2) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL1.match(title): + assignments.append(1) + matched += 1 + else: + assignments.append(None) + + if matched < len(flat) * 0.3: + return None + + # Apply assignments; entries without a pattern get the level of + # the previous entry + 1 (capped at 3) + prev_level = 1 + for i, e in enumerate(flat): + if assignments[i] is not None: + e.level = assignments[i] + else: + e.level = min(prev_level + 1, 3) + prev_level = e.level + e.children = [] + return flat + + # -- Strategy C: uniform grouping fallback ------------------------- # + + @staticmethod + def _strategy_uniform_grouping( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy C — group entries into fixed-size buckets with virtual parents. + + Creates synthetic parent TOCEntry nodes whose char_start/char_end + and page_start/page_end are derived from the first and last child + in each group. + + Returns the re-leveled flat list including virtual parents, or None + on error. 
+ """ + from sirchmunk.learnings.toc_extractor import TOCEntry + + group_size = DocumentTreeIndexer._GROUP_SIZE + num_groups = math.ceil(len(flat) / group_size) + if num_groups <= 1: + return None # Grouping would not improve anything + + parent_level = max(1, dominant_level - 1) if dominant_level > 1 else 1 + child_level = parent_level + 1 + + result: List[Any] = [] + for g in range(num_groups): + start_idx = g * group_size + end_idx = min((g + 1) * group_size, len(flat)) + group = flat[start_idx:end_idx] + + first = group[0] + last = group[-1] + + # Derive positions from children + char_start = first.char_start + char_end = last.char_end if last.char_end else None + page_start = first.page_start + page_end = last.page_start # Best available estimate + + virtual_parent = TOCEntry( + title=f"{first.title} \u2013 {last.title}", + level=parent_level, + char_start=char_start, + char_end=char_end, + page_start=page_start, + page_end=page_end, + children=[], + source="inferred", + ) + result.append(virtual_parent) + + # Set child level + for e in group: + e.level = child_level + e.children = [] + result.extend(group) + + return result diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 5e497f2..52c0db3 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -15,6 +15,7 @@ from sirchmunk.base import BaseSearch from sirchmunk.learnings.knowledge_base import KnowledgeBase +from sirchmunk.utils.document_extractor import DocumentExtractor from sirchmunk.llm.openai_chat import OpenAIChat from sirchmunk.llm.prompts import ( KEYWORD_QUERY_PLACEHOLDER, @@ -3808,42 +3809,100 @@ async def _tree_guided_sample( if not leaves: return None - # --- Read full text once for char_range slicing --- - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" + # --- Classify leaves by extraction method --- + trimmed = leaves[: 
self._TREE_SAMPLE_MAX_SECTIONS] + page_leaves, char_leaves, table_and_summary = self._classify_leaves(trimmed) - # --- Extract tree sections --- - parts: List[str] = [] - total_chars = 0 - for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: - # Table nodes: prefer summary (contains table markdown) - if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: - segment = leaf.summary + # Collect (leaf, segment) pairs preserving original leaf order + leaf_segments: List[tuple] = [] # (leaf, segment_text) + + # -- Phase A: table / summary-only leaves -- + for leaf in table_and_summary: + leaf_segments.append((leaf, leaf.summary)) + + # -- Phase B: batch page-level extraction (single IO) -- + page_segment_map: dict = {} # id(leaf) -> segment + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range(sp, ep + 1)) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + seg_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + seg_parts.append(text) + if seg_parts: + page_segment_map[id(leaf)] = "\n".join(seg_parts) + elif getattr(leaf, 'summary', None): + page_segment_map[id(leaf)] = leaf.summary + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeSample] Page extraction failed for {fname}: {e}, " + f"falling back to char_range for {len(page_leaves)} leaves" + ) + # Demote page_leaves → char_leaves + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + page_leaves_ok = False else: + page_leaves_ok = True + + if page_leaves_ok: + for leaf, _ in page_leaves: + seg = page_segment_map.get(id(leaf)) + if seg: + 
leaf_segments.append((leaf, seg)) + # If page extraction failed, demoted leaves are now in char_leaves + + # -- Phase C: char_range fallback (lazy full-text extraction) -- + if char_leaves: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in char_leaves: start, end = leaf.char_range if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] - elif leaf.summary: + if segment.strip(): + leaf_segments.append((leaf, segment)) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + elif getattr(leaf, 'summary', None): _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " + f"[TreeSample] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) - segment = leaf.summary - else: - continue + leaf_segments.append((leaf, leaf.summary)) + + # --- Build parts with budget control --- + parts: List[str] = [] + total_chars = 0 + for leaf, segment in leaf_segments: segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue page_info = "" - if leaf.page_range: + if getattr(leaf, 'page_range', None): ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" - header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > max_chars: remaining = max_chars - total_chars @@ -3889,6 +3948,36 @@ async def _tree_guided_sample( ) return evidence + @staticmethod + def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: + """将叶节点按提取策略分类。 + + Returns: + (page_leaves, char_leaves, summary_leaves) 三元组: + - 
page_leaves: list of (leaf, page_range) tuples — 有有效 page_range 的 + - char_leaves: list of leaf — 需要 char_range fallback 的 + - summary_leaves: list of leaf — 只有 summary 可用的 + """ + page_leaves: List[tuple] = [] + char_leaves: List = [] + summary_leaves: List = [] + + for leaf in leaves: + # 表格类型节点优先使用 summary(结构化摘要) + if getattr(leaf, 'content_type', 'text') == 'table' and getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + continue + + page_range = getattr(leaf, 'page_range', None) + if page_range and len(page_range) == 2 and page_range[0] is not None and page_range[0] > 0: + page_leaves.append((leaf, page_range)) + elif hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + + return page_leaves, char_leaves, summary_leaves + def _is_valid_char_range( self, start: int, end: int, text_len: int, ) -> bool: @@ -3978,6 +4067,23 @@ def _format_table_evidence( return "\n\n".join(parts) + @staticmethod + def _append_evidence_part( + parts: List[str], fname: str, leaf, segment: str, + *, max_chars: int = 3000, + ) -> None: + """Format and append one leaf's evidence to *parts* (in-place).""" + text = segment[:max_chars] + if not text.strip(): + return + page_info = "" + if getattr(leaf, 'page_range', None): + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" + parts.append(f"{header}\n{text}") + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3986,6 +4092,11 @@ async def _navigate_tree_for_evidence( Uses 1 LLM call to drill into the compiled tree index for *file_path*, returning concatenated leaf content as evidence. Returns None when no tree cache is available. + + Extraction priority (highest first): + 1. 
page_range – page-level extraction via DocumentExtractor + 2. char_range – full-text extraction + slice (fallback) + 3. leaf.summary – last resort """ indexer = self._get_tree_indexer() if indexer is None: @@ -4003,39 +4114,86 @@ async def _navigate_tree_for_evidence( return None fname = Path(file_path).name - # Read leaf content from the original document via char_range parts: List[str] = [] - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" - for leaf in leaves: - # Table nodes: prefer summary (contains table markdown) - if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: - segment = leaf.summary - else: + # ── Phase 1: classify leaves by available extraction method ── + page_leaves, char_leaves, summary_only = self._classify_leaves(leaves) + + for leaf in summary_only: + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # ── Phase 2: batch page-level extraction (single IO) ── + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range(sp, ep + 1)) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + segment_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + segment_parts.append(text) + if segment_parts: + self._append_evidence_part( + parts, fname, leaf, "\n".join(segment_parts), + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeNav] Page extraction failed for {fname}: {e}, " + f"falling back to char_range for {len(page_leaves)} leaves" + ) + # Demote page_leaves → 
char_leaves for char_range fallback + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # ── Phase 3: char_range fallback (lazy full-text extraction) ── + if char_leaves: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in char_leaves: start, end = leaf.char_range if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] - elif leaf.summary: + if segment.strip(): + self._append_evidence_part( + parts, fname, leaf, segment, + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + elif getattr(leaf, 'summary', None): _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) - segment = leaf.summary - else: - continue - if segment.strip(): - page_info = "" - if leaf.page_range: - ps, pe = leaf.page_range - page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" - header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" - parts.append(f"{header}\n{segment[:3000]}") + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) if not parts: return None diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index d114b7d..a022a8d 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -110,6 +110,25 @@ class ExtractionOutput: """Number of pages in the source document (if available).""" +# --------------------------------------------------------------------------- +# 
Page-level extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class PageContent: + """Single page extraction result. + + Returned by :meth:`DocumentExtractor.extract_pages` to represent the + text content of one PDF page. + """ + + page_number: int + """1-indexed page number.""" + + content: str + """Extracted text content (may be empty string).""" + + # --------------------------------------------------------------------------- # Document extractor facade # --------------------------------------------------------------------------- @@ -276,6 +295,78 @@ async def batch_extract( logger.error("Batch extraction failed for {} files", len(file_paths)) raise + # Page-level extraction ------------------------------------------------- + + @staticmethod + def extract_pages( + file_path: Union[str, Path], + pages: list[int], + ) -> list[PageContent]: + """Extract text content from specific PDF pages. + + Uses pypdf to read individual pages by 1-indexed page number. + Invalid page numbers (< 1 or > total pages) are silently skipped. + + Args: + file_path: Path to a PDF file. + pages: List of 1-indexed page numbers to extract. + + Returns: + List of :class:`PageContent` for each valid requested page, + in the order given by *pages*. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: On PDF parsing failure (logged before re-raise). 
+ """ + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"PDF file not found: {path}") + + try: + from pypdf import PdfReader + + reader = PdfReader(str(path)) + total = len(reader.pages) + valid_pages = [p for p in pages if 1 <= p <= total] + return [ + PageContent( + page_number=p, + content=reader.pages[p - 1].extract_text() or "", + ) + for p in valid_pages + ] + except FileNotFoundError: + raise + except Exception as exc: + logger.error( + "Page-level extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + def extract_page_range( + file_path: Union[str, Path], + start_page: int, + end_page: int, + ) -> list[PageContent]: + """Extract text content from a contiguous range of PDF pages. + + Convenience wrapper around :meth:`extract_pages`. + + Args: + file_path: Path to a PDF file. + start_page: First page (1-indexed, inclusive). + end_page: Last page (1-indexed, inclusive). + + Returns: + List of :class:`PageContent` for the requested range. 
+ """ + pages = list(range(start_page, end_page + 1)) + return DocumentExtractor.extract_pages(file_path, pages) + # Internal helpers ----------------------------------------------------- @staticmethod From 2f3a25777f212e5d43c2f2a3648f508d75545c0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 17:53:10 +0800 Subject: [PATCH 37/70] improve search tree index --- src/sirchmunk/learnings/tree_indexer.py | 15 ++++--- src/sirchmunk/search.py | 59 ++++++++++++++++++------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 96c44b9..a720bf3 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -9,6 +9,7 @@ import json import math +import os import re from collections import Counter from dataclasses import dataclass, field @@ -230,9 +231,10 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. - # The summaries may inadvertently bias _select_children() navigation. - # await self._enrich_node_summaries(root, content) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. + if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -256,9 +258,10 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. 
# await self._deepen_large_leaves(root, content, max_depth=effective_depth) - # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. - # The summaries may inadvertently bias _select_children() navigation. - # await self._enrich_node_summaries(root, content) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. + if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 52c0db3..1c47a77 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2543,10 +2543,21 @@ async def _search_fast( f"{len(best_files)} compile-hint files" ) else: - await self._logger.warning( - "[FAST:PureTree] No tree probes available, returning empty" + # Graceful degradation: fall back to keyword search when no tree is available + await self._logger.info( + "[FAST:PureTree] No tree probes available, falling back to keyword search" ) - return _NO_RESULTS_MESSAGE, None, context + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files and fallback: + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files: + return _NO_RESULTS_MESSAGE, None, context else: # --- Original rga-based retrieval logic --- # High-confidence catalog routing: skip rga, use catalog directly @@ -2675,16 +2686,32 @@ async def _rga_evidence() -> str: pass # 0.5 Table digest priority (pre-compiled PDF table evidence) - if ev is None and artifacts and artifacts.manifest_map: - _me = artifacts.manifest_map.get(fp) - if _me and getattr(_me, 'has_table_digest', False): - 
_all_tables = self._load_table_digest( - self.work_path, _me.file_hash, - ) - if _all_tables: - _table_ev = self._format_table_evidence(_all_tables) - if _table_ev: - ev = f"[{fn} - Table Evidence]\n{_table_ev}" + _all_tables = None + if ev is None and artifacts: + # Primary: manifest-based lookup + if artifacts.manifest_map: + _me = artifacts.manifest_map.get(fp) + if _me and getattr(_me, 'has_table_digest', False): + _all_tables = self._load_table_digest( + self.work_path, _me.file_hash, + ) + + # Fallback: direct hash-based lookup when manifest misses + if not _all_tables: + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(fp) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + except Exception: + pass + + if _all_tables: + _table_ev = self._format_table_evidence(_all_tables) + if _table_ev: + ev = f"[{fn} - Table Evidence]\n{_table_ev}" # 1. Tree-guided sampling FIRST for tree-indexed files if ( @@ -3492,8 +3519,8 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: # Prefer manifest-based detection (fast, O(1) per file) if manifest_map: tree_paths = {fp for fp, entry in manifest_map.items() if entry.has_tree} - # Fallback: scan tree cache directory (legacy path) - elif indexer is not None: + # Always try directory fallback if manifest-based detection found nothing + if not tree_paths and indexer is not None: tree_cache = self.work_path / ".cache" / "compile" / "trees" if tree_cache.exists(): try: @@ -4904,7 +4931,7 @@ async def _probe_tree_for_fast( Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. 
""" - if not artifacts or len(artifacts.tree_available_paths) <= 2: + if not artifacts or not artifacts.tree_available_paths: return [] try: From 9dd47bed8cb12ff96b49ad0526cdc77e1e7dd88c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 19:53:14 +0800 Subject: [PATCH 38/70] update log --- benchmarks/financebench/run_benchmark.py | 75 ++++++++++++++++++++++-- src/sirchmunk/learnings/compiler.py | 6 ++ src/sirchmunk/learnings/tree_indexer.py | 17 +++++- src/sirchmunk/search.py | 45 ++++++++++++++ 4 files changed, 134 insertions(+), 9 deletions(-) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index 65af87d..183a6d3 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -34,24 +34,61 @@ from evaluate import compute_metrics from runner import run_batch +# --------------------------------------------------------------------------- +# Tee stdout to log file +# --------------------------------------------------------------------------- + + +class _TeeWriter: + """Duplicate stdout to both terminal and a log file.""" + + def __init__(self, log_path: str) -> None: + self._terminal = sys.stdout + self._log = open(log_path, "w", encoding="utf-8") # noqa: SIM115 + + def write(self, msg: str) -> int: + self._terminal.write(msg) + self._log.write(msg) + return len(msg) + + def flush(self) -> None: + self._terminal.flush() + self._log.flush() + + def close(self) -> None: + self._log.close() + + # Let logging / other code check the stream capabilities + @property + def encoding(self) -> str: + return getattr(self._terminal, "encoding", "utf-8") + + def isatty(self) -> bool: + return False + + def fileno(self) -> int: + return self._terminal.fileno() + + # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- -def setup_logging(output_dir: 
str) -> str: +def setup_logging(output_dir: str, ts: str | None = None) -> tuple[str, str]: """Configure logging to file + console. Creates a timestamped log file under ``logs/`` (relative to *output_dir*'s parent, i.e. the benchmark root directory). Returns: - Absolute path to the log file. + Tuple of (absolute path to the log file, timestamp string). """ log_dir = Path("logs") log_dir.mkdir(parents=True, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") + if ts is None: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") log_path = log_dir / f"benchmark_{ts}.log" root_logger = logging.getLogger("financebench") @@ -77,7 +114,7 @@ def setup_logging(output_dir: str) -> str: root_logger.addHandler(fh) root_logger.addHandler(ch) - return str(log_path.resolve()) + return str(log_path.resolve()), ts # --------------------------------------------------------------------------- @@ -169,9 +206,17 @@ def main() -> None: cfg.limit = args.limit # 2. Setup logging - log_path = setup_logging(cfg.output_dir) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path, ts = setup_logging(cfg.output_dir, ts=ts) logger = logging.getLogger("financebench") + # 2b. Tee stdout → debug log so SEARCH_WIKI_DEBUG prints are captured + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + debug_log_path = log_dir / f"benchmark_{ts}_debug.log" + tee = _TeeWriter(str(debug_log_path)) + sys.stdout = tee + # Print config source info work_env = Path(cfg.work_path) / ".env" logger.info("=" * 50) @@ -250,7 +295,25 @@ def main() -> None: # 10. Print summary _print_summary(results, metrics, total_time, results_path, metrics_path, log_path) + print(f" Debug log: {debug_log_path.resolve()}") + + # 11. 
Restore stdout + sys.stdout = tee._terminal + tee.close() + + +def _main_safe() -> None: + """Wrapper that guarantees stdout is restored even on exceptions.""" + try: + main() + except (KeyboardInterrupt, Exception): + # Restore stdout if tee was installed + if hasattr(sys.stdout, "_terminal"): + terminal = sys.stdout._terminal + sys.stdout.close() + sys.stdout = terminal + raise if __name__ == "__main__": - main() + _main_safe() diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 92dba7f..62f3e19 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -425,6 +425,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_table_digest=result.has_table_digest, table_count=result.table_count, ) + _mentry = manifest.files[result.path] + print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) # Phase 3: aggregate results into knowledge network await self._log.info("[Compile] Phase 3: Knowledge aggregation") @@ -559,6 +561,7 @@ async def _compile_single_file( the pipeline skips tree building and summarises via a direct LLM call. 
""" result = FileCompileResult(path=entry.path) + print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") @@ -599,6 +602,7 @@ async def _compile_single_file( # Record TOC / tree metrics on the result for manifest persistence result.has_explicit_toc = toc_entries is not None and len(toc_entries) > 0 result.tree_node_count = self._count_tree_nodes(result.tree) + print(f"SEARCH_WIKI_DEBUG [C2] tree_build: success={result.tree is not None}, nodes={result.tree_node_count}, tree.file_path={result.tree.file_path if result.tree else 'N/A'}", flush=True) # Enrich content with structural metadata for non-text types ext = Path(entry.path).suffix.lower() @@ -650,6 +654,8 @@ async def _compile_single_file( except Exception: pass + print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) + # Integrate tables into tree: annotate counts + create table child nodes if result.tree and result.tree.root and extraction.tables: self._integrate_tables_into_tree( diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index a720bf3..2e2f909 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -233,7 +233,9 @@ async def build_tree( # await self._deepen_large_leaves(root, content, max_depth=effective_depth) # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. # Set to "true" to skip during debugging / performance testing. 
- if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (TOC path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -260,7 +262,9 @@ async def build_tree( # await self._deepen_large_leaves(root, content, max_depth=effective_depth) # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. # Set to "true" to skip during debugging / performance testing. - if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (recursive path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: await self._enrich_node_summaries(root, content) tree = DocumentTree( @@ -304,6 +308,8 @@ async def navigate( if tree.root is None: return [] + print(f"SEARCH_WIKI_DEBUG [T2] navigate: query={query[:80]}, total_nodes={self._count_nodes(tree.root)}", flush=True) + candidates = tree.root.children if tree.root.children else [tree.root] if not candidates: return [tree.root] @@ -318,6 +324,7 @@ async def navigate( selected = await self._select_children( frontier, query, max_selections=max_results, ) + print(f"SEARCH_WIKI_DEBUG [T3] navigate layer: depth={depth}, selected={len(selected)}, names={[n.title[:30] for n in selected][:5]}", flush=True) if not selected: break @@ -351,7 +358,10 @@ async def navigate( if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) - return unique[:max_results] + leaves = unique[:max_results] + _page_valid = sum(1 for l in leaves if 
getattr(l, 'page_range', None) and len(l.page_range) == 2 and l.page_range[0]) + print(f"SEARCH_WIKI_DEBUG [T4] navigate result: leaves={len(leaves)}, page_range_valid={_page_valid}", flush=True) + return leaves def load_tree(self, file_path: str) -> Optional[DocumentTree]: """Load a cached tree index for the given file (sync).""" @@ -821,6 +831,7 @@ def _cache_path(self, file_hash: str) -> Path: def _save_cache(self, file_hash: str, tree: DocumentTree) -> None: path = self._cache_path(file_hash) path.write_text(tree.to_json(), encoding="utf-8") + print(f"SEARCH_WIKI_DEBUG [C5] tree_json_saved: path={path}", flush=True) def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: path = self._cache_path(file_hash) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 1c47a77..738db15 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2527,6 +2527,8 @@ async def _search_fast( {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in _tree_probed_files[:top_k_files] ] + print(f"SEARCH_WIKI_DEBUG [D7] _tree_probed_files={_tree_probed_files}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D8] best_files={[bf['path'] for bf in best_files]}", flush=True) await self._logger.info( f"[FAST:PureTree] Using {len(best_files)} tree-probed files: " f"{[Path(p).name for p in _tree_probed_files[:top_k_files]]}" @@ -2651,6 +2653,11 @@ async def _search_fast( tree_nav_done: Set[str] = set() tree_nav_target = best_files[0]["path"] + print(f"SEARCH_WIKI_DEBUG [D9] tree_nav_target={tree_nav_target}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D10] tree_nav_match={tree_nav_target in (artifacts.tree_available_paths if artifacts else set())}", flush=True) + if artifacts and tree_nav_target not in artifacts.tree_available_paths: + print(f"SEARCH_WIKI_DEBUG [D11] MISMATCH! 
tree_available_paths={artifacts.tree_available_paths}", flush=True) + if artifacts and tree_nav_target in artifacts.tree_available_paths: tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) tree_nav_done.add(tree_nav_target) @@ -2669,6 +2676,8 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None + print(f"SEARCH_WIKI_DEBUG [D12] _rga_evidence: fp={fp}", flush=True) + # 0. Excel digest priority (pre-compiled evidence) if artifacts and artifacts.manifest_map: manifest_entry = artifacts.manifest_map.get(fp) @@ -2708,12 +2717,16 @@ async def _rga_evidence() -> str: except Exception: pass + print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) + if _all_tables: _table_ev = self._format_table_evidence(_all_tables) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" # 1. 
Tree-guided sampling FIRST for tree-indexed files + _tree_cond = artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done + print(f"SEARCH_WIKI_DEBUG [D14] tree_sample: cond={_tree_cond}, in_tree_paths={fp in (artifacts.tree_available_paths if artifacts else set())}, in_nav_done={fp in tree_nav_done}", flush=True) if ( artifacts and fp in artifacts.tree_available_paths @@ -2755,6 +2768,14 @@ async def _rga_evidence() -> str: parts.append(ev[:remaining]) chars += len(parts[-1]) context.mark_file_read(fp) + + _ev_source = "none" + if ev: + if "Table Evidence" in ev: _ev_source = "table_digest" + elif "Pre-compiled" in ev: _ev_source = "excel_digest" + elif "TreeSample" in str(ev)[:50] or "TreeNav" in str(ev)[:50]: _ev_source = "tree" + else: _ev_source = "rga_or_other" + print(f"SEARCH_WIKI_DEBUG [D15] ev_source={_ev_source}, ev_len={len(ev) if ev else 0}", flush=True) return "\n\n---\n\n".join(parts) # Launch tree navigation for the primary file alongside rga @@ -2770,6 +2791,10 @@ async def _rga_evidence() -> str: evidence_parts_final.append(rga_ev) evidence = "\n\n---\n\n".join(evidence_parts_final) + print(f"SEARCH_WIKI_DEBUG [D16] tree_ev: {'yes' if tree_ev else 'no'}, len={len(tree_ev) if tree_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D17] rga_ev: {'yes' if rga_ev else 'no'}, len={len(rga_ev) if rga_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D18] final_evidence_len={len(evidence)}", flush=True) + if not evidence or len(evidence.strip()) < 20: if llm_fallback: await self._logger.info( @@ -3549,6 +3574,9 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: except Exception: pass + print(f"SEARCH_WIKI_DEBUG [D1] manifest_map: {len(manifest_map)} entries, keys={list(manifest_map.keys())[:3]}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D2] tree_available_paths: {tree_paths}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D3] manifest_fallback_executed: {manifest_map and not tree_paths}", flush=True) return 
CompileArtifacts( catalog=catalog, catalog_map=catalog_map, @@ -3804,6 +3832,8 @@ async def _tree_guided_sample( if max_chars <= 0: max_chars = self._FAST_MAX_EVIDENCE_CHARS + print(f"SEARCH_WIKI_DEBUG [S1] _tree_guided_sample: file_path={file_path}", flush=True) + # --- Guard: tree availability --- if artifacts is not None: if file_path not in artifacts.tree_available_paths: @@ -3839,6 +3869,7 @@ async def _tree_guided_sample( # --- Classify leaves by extraction method --- trimmed = leaves[: self._TREE_SAMPLE_MAX_SECTIONS] page_leaves, char_leaves, table_and_summary = self._classify_leaves(trimmed) + print(f"SEARCH_WIKI_DEBUG [S2] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, table_summary={len(table_and_summary)}", flush=True) # Collect (leaf, segment) pairs preserving original leaf order leaf_segments: List[tuple] = [] # (leaf, segment_text) @@ -3968,6 +3999,7 @@ async def _tree_guided_sample( return None evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [S3] _tree_guided_sample result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[TreeSample] {fname}: " f"{len(parts)} sections, {total_chars} chars " @@ -4126,6 +4158,7 @@ async def _navigate_tree_for_evidence( 3. 
leaf.summary – last resort """ indexer = self._get_tree_indexer() + print(f"SEARCH_WIKI_DEBUG [N1] _navigate_tree_for_evidence: file_path={file_path}", flush=True) if indexer is None: return None tree = indexer.load_tree(file_path) @@ -4137,6 +4170,8 @@ async def _navigate_tree_for_evidence( except Exception: return None + print(f"SEARCH_WIKI_DEBUG [N2] navigate_result: {len(leaves) if leaves else 0} leaves", flush=True) + if not leaves: return None @@ -4145,6 +4180,7 @@ async def _navigate_tree_for_evidence( # ── Phase 1: classify leaves by available extraction method ── page_leaves, char_leaves, summary_only = self._classify_leaves(leaves) + print(f"SEARCH_WIKI_DEBUG [N3] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, summary={len(summary_only)}", flush=True) for leaf in summary_only: self._append_evidence_part( @@ -4191,6 +4227,9 @@ async def _navigate_tree_for_evidence( self._append_evidence_part( parts, fname, leaf, leaf.summary, ) + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=False", flush=True) + else: + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=True", flush=True) # ── Phase 3: char_range fallback (lazy full-text extraction) ── if char_leaves: @@ -4256,7 +4295,10 @@ async def _navigate_tree_for_evidence( except Exception: pass + print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if '_all_tables' in dir() and _all_tables else 0}", flush=True) + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " f"{len(evidence)} chars from {fname}" @@ -4931,16 +4973,19 @@ async def _probe_tree_for_fast( Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. 
""" + print(f"SEARCH_WIKI_DEBUG [D4] _probe_tree_for_fast: tree_available_paths={len(artifacts.tree_available_paths) if artifacts else 0}", flush=True) if not artifacts or not artifacts.tree_available_paths: return [] try: trees = self._load_cached_trees() + print(f"SEARCH_WIKI_DEBUG [D5] loaded_trees: {len(trees)} trees, paths={[t.file_path for t in trees][:3]}", flush=True) if not trees: return [] result = await self._llm_select_from_trees( query, trees, max_select=self._FAST_TREE_PROBE_MAX_FILES, ) + print(f"SEARCH_WIKI_DEBUG [D6] llm_select_result: {result}", flush=True) if result: await self._logger.info( f"[FAST:TreeProbe] Selected {len(result)} files " From 8ff1f98c207e3c998bfaf10155cd72c9b0bfb871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 21:32:01 +0800 Subject: [PATCH 39/70] enhance search fast for compile --- src/sirchmunk/learnings/tree_indexer.py | 283 +++++++++++++++++++++--- src/sirchmunk/search.py | 17 +- 2 files changed, 269 insertions(+), 31 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 2e2f909..2d7e277 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -163,6 +163,9 @@ class DocumentTreeIndexer: # Number of nodes per group in paginated selection. _GROUP_PAGE_SIZE: int = 15 + # Minimum navigation depth before allowing early termination. + _NAV_MIN_DEPTH: int = 2 + def __init__( self, llm: OpenAIChat, @@ -289,18 +292,21 @@ async def navigate( *, max_results: int = 3, max_depth: int = 4, + min_depth: int = 2, ) -> List[TreeNode]: """Adaptive-depth LLM-driven tree navigation. Iteratively descends the tree using _select_children() at each level, collecting leaf nodes until *max_results* are found or *max_depth* is - reached. + reached. Enforces *min_depth* descent before allowing early + termination to avoid shallow results. Args: tree: DocumentTree with a root node. 
query: Search query for relevance selection. max_results: Maximum number of leaf nodes to return. max_depth: Maximum descent depth (default 4). + min_depth: Minimum depth before early termination (default 2). Returns: List of the most relevant leaf TreeNodes. @@ -314,6 +320,10 @@ async def navigate( if not candidates: return [tree.root] + # Adaptive min-depth: clamp to tree's actual depth + tree_max_depth = self._max_node_depth(tree.root) + effective_min_depth = min(min_depth, max(tree_max_depth - 1, 1)) + result_leaves: List[TreeNode] = [] visited: set = set() # prevent cycles frontier = candidates @@ -325,7 +335,21 @@ async def navigate( frontier, query, max_selections=max_results, ) print(f"SEARCH_WIKI_DEBUG [T3] navigate layer: depth={depth}, selected={len(selected)}, names={[n.title[:30] for n in selected][:5]}", flush=True) + if not selected: + # Fix A.1: when depth < effective_min_depth, expand all frontier children + if depth < effective_min_depth: + next_frontier: List[TreeNode] = [] + for node in frontier: + if node.children: + next_frontier.extend(node.children) + else: + result_leaves.append(node) + if not next_frontier: + break + frontier = next_frontier + depth += 1 + continue break next_frontier: List[TreeNode] = [] @@ -335,14 +359,25 @@ async def navigate( continue visited.add(node_id) + # Fix A.2: leaf determination with depth constraint if node.leaf or not node.children: - result_leaves.append(node) + if depth >= effective_min_depth: + result_leaves.append(node) + elif node.children: + next_frontier.extend(node.children) + else: + # True leaf (no children), cannot descend further + result_leaves.append(node) else: next_frontier.extend(node.children) - if len(result_leaves) >= max_results: + # Fix A.3: early termination requires depth >= effective_min_depth + if len(result_leaves) >= max_results and depth >= effective_min_depth: break + # Fix A.4: check for empty next_frontier + if not next_frontier: + break frontier = next_frontier depth += 1 @@ 
-404,6 +439,9 @@ async def _build_tree_from_toc( # Infer hierarchy when TOC entries are flat (all same level) toc_entries = self._infer_hierarchy(toc_entries) + # Merge consecutive fragment entries into virtual parents + toc_entries = self._merge_fragment_entries(toc_entries) + seen_ids: set = set() children = self._toc_entries_to_nodes( toc_entries, content, len(content), seen_ids, @@ -425,6 +463,78 @@ async def _build_tree_from_toc( children=children, ) + @staticmethod + def _merge_fragment_entries(entries: List[Any]) -> List[Any]: + """Merge consecutive fragment TOC entries into virtual parent nodes. + + Detects runs of >=3 consecutive entries that have tiny char_range + spans (<500) and no children, then collapses them into a single + virtual 'Preamble' entry. Uses only structural signals (char spans, + children counts) — no domain-specific keywords. + + Safety valve: returns original *entries* if result has < 2 entries. + """ + if len(entries) <= 5: + return entries + + # Phase 1: Detect fragment runs + def _is_fragment(e: Any) -> bool: + span = 0 + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + span = e.char_end - e.char_start + has_children = bool(getattr(e, 'children', None)) + return span < 500 and not has_children + + # Find runs of consecutive fragments + runs: List[List[int]] = [] # list of [start_idx, end_idx] inclusive + i = 0 + while i < len(entries): + if _is_fragment(entries[i]): + run_start = i + while i < len(entries) and _is_fragment(entries[i]): + i += 1 + if (i - run_start) >= 3: # Only merge runs of 3+ + runs.append([run_start, i - 1]) + else: + i += 1 + + if not runs: + return entries + + # Phase 2: Merge each run into a virtual parent + from copy import deepcopy + + result: List[Any] = [] + prev_end = -1 + for run_start, run_end in runs: + # Add non-fragment entries before this run + for j in range(prev_end + 1, run_start): + result.append(entries[j]) + + # Create virtual parent from 
the run + first_entry = entries[run_start] + last_entry = entries[run_end] + + merged = deepcopy(first_entry) + merged.title = f"Preamble ({run_end - run_start + 1} sections)" + if hasattr(last_entry, 'char_end') and last_entry.char_end: + merged.char_end = last_entry.char_end + # Set children to the original entries + merged.children = list(entries[run_start:run_end + 1]) + result.append(merged) + prev_end = run_end + + # Add remaining entries after last run + for j in range(prev_end + 1, len(entries)): + result.append(entries[j]) + + # Safety valve + if len(result) < 2: + return entries + + return result + @staticmethod def _toc_entries_to_nodes( entries: List[Any], @@ -672,34 +782,161 @@ def _resolve_positions( and (s["end"] - s["start"]) / max(text_len, 1) < _MAX_SPAN_RATIO ] + @staticmethod + def _filter_low_value_nodes( + nodes: List["TreeNode"], + *, + min_remaining: int = 3, + ) -> List["TreeNode"]: + """Filter out low-value fragment nodes using structural signals. + + Applies three generic heuristics (no domain-specific keywords): + 1. Short-page leaf: page_range spans <= 2 pages AND no children AND + summary length < 100 chars. + 2. Tiny fragment: title < 10 chars AND no children AND + char_range span < 200 chars. + 3. Duplicate page_range: among nodes sharing the same page_range, + keep only the one with the largest char_range span. + + Safety valve: returns original *nodes* if fewer than *min_remaining* + survive filtering. 
+ """ + if len(nodes) <= min_remaining: + return nodes + + # Pass 1: identify fragment nodes + keep: List[bool] = [True] * len(nodes) + + for i, n in enumerate(nodes): + pr = getattr(n, 'page_range', None) + has_children = bool(n.children) + summary_len = len(n.summary) if n.summary else 0 + title_len = len(n.title.strip()) if n.title else 0 + cr = getattr(n, 'char_range', (0, 0)) + span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 + + # Heuristic 1: short-page leaf + if ( + pr and len(pr) == 2 + and pr[0] is not None and pr[1] is not None + and (pr[1] - pr[0]) <= 1 + and not has_children + and summary_len < 100 + ): + keep[i] = False + continue + + # Heuristic 2: tiny fragment + if title_len < 10 and not has_children and span < 200: + keep[i] = False + continue + + # Pass 2: deduplicate by page_range + page_range_groups: dict = {} # page_range -> list of (index, span) + for i, n in enumerate(nodes): + if not keep[i]: + continue + pr = getattr(n, 'page_range', None) + if pr and len(pr) == 2: + key = (pr[0], pr[1]) + cr = getattr(n, 'char_range', (0, 0)) + span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 + page_range_groups.setdefault(key, []).append((i, span)) + + for key, group in page_range_groups.items(): + if len(group) > 1: + # Keep only the node with largest char_range span + best_idx = max(group, key=lambda x: x[1])[0] + for idx, _ in group: + if idx != best_idx: + keep[idx] = False + + filtered = [n for i, n in enumerate(nodes) if keep[i]] + return filtered if len(filtered) >= min_remaining else nodes + + @staticmethod + def _build_node_descriptor(node: "TreeNode", index: int) -> str: + """Build a rich descriptor string for a single tree node. + + Includes structural signals: page span, table count, subsection + count, and depth information to help LLM make informed selections. 
+ """ + parts = [f"[{index}] {node.title}"] + + # Page range with span + pr = getattr(node, 'page_range', None) + if pr and len(pr) == 2 and pr[0] is not None: + span_pages = pr[1] - pr[0] + 1 if pr[1] else 1 + parts.append(f"[pages {pr[0]}-{pr[1]}, {span_pages}p]") + + # Table count + if node.table_count > 0: + parts.append(f"[{node.table_count} tables]") + + # Subsections + child_count = len(node.children) + if child_count > 0: + parts.append(f"[{child_count} subsections]") + + # Summary + summary = (node.summary or "")[:200] + if summary: + parts.append(f": {summary}") + + return " ".join(parts) + + @staticmethod + def _build_selection_prompt( + nodes: List["TreeNode"], + query: str, + max_selections: int, + ) -> str: + """Build unified LLM prompt for branch selection. + + Uses structural signals to guide LLM toward high-value sections: + tables, subsection depth, page span. No domain-specific keywords. + """ + listing = "\n".join( + DocumentTreeIndexer._build_node_descriptor(n, i) + for i, n in enumerate(nodes) + ) + + sel_hint = f"1-{min(max_selections, len(nodes))}" + + return ( + f"Given the query: \"{query}\"\n\n" + f"Select the {sel_hint} most relevant sections (by index number):\n" + f"{listing}\n\n" + f"Selection criteria:\n" + f"- Prioritize sections containing tables and data\n" + f"- Prefer sections with many subsections over small leaf fragments\n" + f"- Avoid sections covering only 1-2 pages with no subsections\n" + f"- When uncertain, prefer larger sections that can be narrowed later\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + async def _select_children( self, nodes: List[TreeNode], query: str, *, max_selections: int = 3, ) -> List[TreeNode]: """LLM-driven branch selection: pick the most relevant children. - Dispatches to paginated selection when *nodes* exceeds - ``_PAGE_SIZE_THRESHOLD`` to avoid overwhelming the LLM. 
+ Pre-filters low-value fragments, then dispatches to paginated + selection when *nodes* exceeds ``_PAGE_SIZE_THRESHOLD``. """ if len(nodes) <= 2: return nodes + # Pre-filter low-value fragment nodes + nodes = self._filter_low_value_nodes(nodes) + if len(nodes) <= 2: + return nodes + if len(nodes) > self._PAGE_SIZE_THRESHOLD: return await self._select_children_paginated( nodes, query, max_selections=max_selections, ) - listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" - for i, n in enumerate(nodes) - ) - - prompt = ( - f"Given the query: \"{query}\"\n\n" - f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" - ) + prompt = self._build_selection_prompt(nodes, query, max_selections) resp = await self._llm.achat([{"role": "user", "content": prompt}]) try: raw = resp.content.strip() @@ -797,17 +1034,7 @@ async def _select_from_group( if len(group) <= 2: return group - listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" - for i, n in enumerate(group) - ) - prompt = ( - f"Given the query: \"{query}\"\n\n" - f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" - ) + prompt = self._build_selection_prompt(group, query, max_selections) try: resp = await self._llm.achat([{"role": "user", "content": prompt}]) raw = resp.content.strip() diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 738db15..e12fd39 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -4022,9 +4022,20 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: summary_leaves: List = [] for leaf in leaves: - # 表格类型节点优先使用 summary(结构化摘要) - if getattr(leaf, 'content_type', 'text') == 'table' and getattr(leaf, 'summary', None): - summary_leaves.append(leaf) + # 表格类型节点:优先 page-level 提取获取完整原始内容 + if getattr(leaf, 'content_type', 'text') == 'table': + page_range = getattr(leaf, 'page_range', None) + if ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ): + page_leaves.append((leaf, page_range)) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + else: + char_leaves.append(leaf) continue page_range = getattr(leaf, 'page_range', None) From 464d8d511093f07d388313549df7f1fab8dfbb43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 22:49:59 +0800 Subject: [PATCH 40/70] enhance tree index --- src/sirchmunk/learnings/tree_indexer.py | 130 +++++++++++------------- src/sirchmunk/llm/prompts.py | 12 ++- src/sirchmunk/search.py | 114 ++++++++++++++++----- 3 files changed, 158 insertions(+), 98 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 2d7e277..9cf450e 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -292,21 +292,21 @@ async def navigate( *, max_results: int = 3, max_depth: int = 4, - min_depth: int = 2, + min_depth: int = 1, ) -> List[TreeNode]: """Adaptive-depth LLM-driven tree navigation. 
Iteratively descends the tree using _select_children() at each level, collecting leaf nodes until *max_results* are found or *max_depth* is reached. Enforces *min_depth* descent before allowing early - termination to avoid shallow results. + termination to avoid overly shallow results. Args: tree: DocumentTree with a root node. query: Search query for relevance selection. max_results: Maximum number of leaf nodes to return. max_depth: Maximum descent depth (default 4). - min_depth: Minimum depth before early termination (default 2). + min_depth: Minimum depth before early termination (default 1). Returns: List of the most relevant leaf TreeNodes. @@ -320,6 +320,16 @@ async def navigate( if not candidates: return [tree.root] + # Skip single-child container chains (e.g. SEC boilerplate wrappers + # like "UNITED STATES SECURITIES AND EXCHANGE COMMISSION" → "FORM 10-K") + # to avoid wasting navigation depth on structural-only nodes. + while ( + len(candidates) == 1 + and candidates[0].children + and not candidates[0].leaf + ): + candidates = candidates[0].children + # Adaptive min-depth: clamp to tree's actual depth tree_max_depth = self._max_node_depth(tree.root) effective_min_depth = min(min_depth, max(tree_max_depth - 1, 1)) @@ -359,17 +369,10 @@ async def navigate( continue visited.add(node_id) - # Fix A.2: leaf determination with depth constraint - if node.leaf or not node.children: - if depth >= effective_min_depth: - result_leaves.append(node) - elif node.children: - next_frontier.extend(node.children) - else: - # True leaf (no children), cannot descend further - result_leaves.append(node) - else: + if node.children: next_frontier.extend(node.children) + else: + result_leaves.append(node) # Fix A.3: early termination requires depth >= effective_min_depth if len(result_leaves) >= max_results and depth >= effective_min_depth: @@ -788,68 +791,58 @@ def _filter_low_value_nodes( *, min_remaining: int = 3, ) -> List["TreeNode"]: - """Filter out low-value fragment 
nodes using structural signals. - - Applies three generic heuristics (no domain-specific keywords): - 1. Short-page leaf: page_range spans <= 2 pages AND no children AND - summary length < 100 chars. - 2. Tiny fragment: title < 10 chars AND no children AND - char_range span < 200 chars. - 3. Duplicate page_range: among nodes sharing the same page_range, - keep only the one with the largest char_range span. - - Safety valve: returns original *nodes* if fewer than *min_remaining* - survive filtering. + """Remove only structurally empty or exact-duplicate nodes. + + Intentionally conservative: the LLM selection step receives rich + structural descriptors (page span, table count, subsection count) + and is trusted to judge relevance. This filter removes only + definitive noise that would waste LLM context: + + 1. Empty placeholders — no title, no children, zero char span, + and no summary. + 2. Exact duplicates — identical (title, page_range) pairs; among + duplicates the node with the richest structure is kept. + + Safety: returns original *nodes* when fewer than *min_remaining* + would survive. 
""" if len(nodes) <= min_remaining: return nodes - # Pass 1: identify fragment nodes keep: List[bool] = [True] * len(nodes) - for i, n in enumerate(nodes): - pr = getattr(n, 'page_range', None) - has_children = bool(n.children) - summary_len = len(n.summary) if n.summary else 0 - title_len = len(n.title.strip()) if n.title else 0 - cr = getattr(n, 'char_range', (0, 0)) - span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - - # Heuristic 1: short-page leaf - if ( - pr and len(pr) == 2 - and pr[0] is not None and pr[1] is not None - and (pr[1] - pr[0]) <= 1 - and not has_children - and summary_len < 100 - ): - keep[i] = False - continue + def _char_span(n: "TreeNode") -> int: + cr = getattr(n, "char_range", (0, 0)) + return (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - # Heuristic 2: tiny fragment - if title_len < 10 and not has_children and span < 200: + # Pass 1: remove structurally empty placeholder nodes + for i, n in enumerate(nodes): + title = (n.title or "").strip() + if not title and not n.children and _char_span(n) == 0 and not n.summary: keep[i] = False - continue - # Pass 2: deduplicate by page_range - page_range_groups: dict = {} # page_range -> list of (index, span) + # Pass 2: deduplicate exact (title, page_range) pairs — + # keep the node with more structural richness. 
+ seen: dict = {} # (title, page_range_key) → index for i, n in enumerate(nodes): if not keep[i]: continue - pr = getattr(n, 'page_range', None) - if pr and len(pr) == 2: - key = (pr[0], pr[1]) - cr = getattr(n, 'char_range', (0, 0)) - span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - page_range_groups.setdefault(key, []).append((i, span)) - - for key, group in page_range_groups.items(): - if len(group) > 1: - # Keep only the node with largest char_range span - best_idx = max(group, key=lambda x: x[1])[0] - for idx, _ in group: - if idx != best_idx: - keep[idx] = False + title = (n.title or "").strip() + pr = getattr(n, "page_range", None) + pr_key = (pr[0], pr[1]) if pr and len(pr) == 2 else None + dup_key = (title, pr_key) + if dup_key in seen: + prev_i = seen[dup_key] + prev = nodes[prev_i] + richness = (len(n.children), getattr(n, "table_count", 0), _char_span(n)) + prev_richness = (len(prev.children), getattr(prev, "table_count", 0), _char_span(prev)) + if richness > prev_richness: + keep[prev_i] = False + seen[dup_key] = i + else: + keep[i] = False + else: + seen[dup_key] = i filtered = [n for i, n in enumerate(nodes) if keep[i]] return filtered if len(filtered) >= min_remaining else nodes @@ -908,9 +901,9 @@ def _build_selection_prompt( f"Select the {sel_hint} most relevant sections (by index number):\n" f"{listing}\n\n" f"Selection criteria:\n" - f"- Prioritize sections containing tables and data\n" - f"- Prefer sections with many subsections over small leaf fragments\n" - f"- Avoid sections covering only 1-2 pages with no subsections\n" + f"- Prioritize sections most likely to answer the query\n" + f"- Sections with tables, data, or subsections are often high-value\n" + f"- Short sections containing relevant data should not be dismissed\n" f"- When uncertain, prefer larger sections that can be narrowed later\n\n" f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" ) @@ -920,8 +913,9 @@ async def _select_children( ) -> List[TreeNode]: """LLM-driven branch selection: pick the most relevant children. - Pre-filters low-value fragments, then dispatches to paginated - selection when *nodes* exceeds ``_PAGE_SIZE_THRESHOLD``. + Removes only definitive noise (empty / duplicate nodes), then + dispatches to paginated selection when *nodes* exceeds + ``_PAGE_SIZE_THRESHOLD``. Relevance judgment is delegated to the LLM. """ if len(nodes) <= 2: return nodes diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 27338a2..074847e 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -422,6 +422,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. ### Input Data - **User Input**: {user_input} @@ -442,8 +443,11 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format + +[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] + -[Generate the Markdown Briefing here] +[Generate the Markdown Briefing here with detailed analysis and supporting evidence] true/false true/false @@ -458,6 +462,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 1. 
**Language Continuity**: The output must be in the SAME language as the User Input. 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. ### Document Context {document_context} @@ -481,8 +486,11 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format + +[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] + -[Generate the Markdown Briefing here] +[Generate the Markdown Briefing here with detailed analysis and supporting evidence] true/false true/false diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index e12fd39..921c383 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -934,20 +934,21 @@ async def _search_by_filename( @staticmethod def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: - """ - Parse LLM response to extract summary and quality decisions. + """Parse LLM response to extract summary, precise answer, and quality decisions. - Args: - llm_response: Raw LLM response containing SUMMARY, SHOULD_ANSWER and SHOULD_SAVE tags + When a ```` tag is present, its content is prepended to + the summary so downstream consumers (evaluation judges, UIs) see the + direct answer prominently without needing separate tag awareness. 
Returns: Tuple of (summary_text, should_save_flag, should_answer_flag) """ summary_fields = extract_fields( content=llm_response, - tags=["SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], + tags=["PRECISE_ANSWER", "SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], ) + precise = str(summary_fields.get("precise_answer") or "").strip() summary = str(summary_fields.get("summary") or "").strip() should_answer_str = str(summary_fields.get("should_answer") or "false").strip().lower() should_save_str = str(summary_fields.get("should_save") or "false").strip().lower() @@ -955,8 +956,11 @@ def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: should_answer = should_answer_str in ["true", "yes", "1"] should_save = should_save_str in ["true", "yes", "1"] - # If extraction failed, use entire response as summary and default to conservative: - # not answerable and not saveable. + if precise and summary: + summary = f"**Answer: {precise}**\n\n{summary}" + elif precise: + summary = precise + if not summary: summary = llm_response.strip() should_answer = False @@ -2198,7 +2202,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- - _TREE_SAMPLE_MAX_SECTIONS = 3 + _TREE_SAMPLE_MAX_SECTIONS = 5 """Max tree sections to include per file in tree-guided sampling.""" _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 """Max chars per tree section.""" @@ -2218,10 +2222,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """char_range spanning more than this ratio of the document is treated as invalid.""" # --- Self-correction expanded sampling --- - _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 - """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 3).""" - _SELF_CORRECT_EXPANDED_SECTIONS: int = 5 - """Expanded tree sample sections for same-file re-sampling (default uses 3).""" + 
_SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 + """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 5).""" + _SELF_CORRECT_EXPANDED_SECTIONS: int = 8 + """Expanded tree sample sections for same-file re-sampling (default uses 5).""" # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 @@ -2720,7 +2724,9 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) if _all_tables: - _table_ev = self._format_table_evidence(_all_tables) + _table_ev = self._format_table_evidence( + _all_tables, query=query, + ) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" @@ -4009,20 +4015,26 @@ async def _tree_guided_sample( @staticmethod def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: - """将叶节点按提取策略分类。 + """Classify leaf nodes by preferred extraction strategy. + + For non-table leaves, **char_range** (kreuzberg markdown) is preferred + over page_range (pypdf raw text) because compile-time extraction + preserves table layout and column structure far better than pypdf's + ``extract_text()``. page_range remains available on each leaf for + table-supplement filtering even when the leaf is routed to char_leaves. 
Returns: - (page_leaves, char_leaves, summary_leaves) 三元组: - - page_leaves: list of (leaf, page_range) tuples — 有有效 page_range 的 - - char_leaves: list of leaf — 需要 char_range fallback 的 - - summary_leaves: list of leaf — 只有 summary 可用的 + (page_leaves, char_leaves, summary_leaves) triple: + - page_leaves: list of (leaf, page_range) — page-level extraction + - char_leaves: list of leaf — kreuzberg char_range extraction + - summary_leaves: list of leaf — only summary available """ page_leaves: List[tuple] = [] char_leaves: List = [] summary_leaves: List = [] for leaf in leaves: - # 表格类型节点:优先 page-level 提取获取完整原始内容 + # Table nodes: prefer page-level extraction for raw original content if getattr(leaf, 'content_type', 'text') == 'table': page_range = getattr(leaf, 'page_range', None) if ( @@ -4038,11 +4050,21 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: char_leaves.append(leaf) continue + # Non-table leaves: prefer char_range (kreuzberg markdown) over + # page_range (pypdf raw text) for higher-fidelity table rendering. + has_char = hasattr(leaf, 'char_range') and leaf.char_range page_range = getattr(leaf, 'page_range', None) - if page_range and len(page_range) == 2 and page_range[0] is not None and page_range[0] > 0: - page_leaves.append((leaf, page_range)) - elif hasattr(leaf, 'char_range') and leaf.char_range: + has_page = ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ) + + if has_char: char_leaves.append(leaf) + elif has_page: + page_leaves.append((leaf, page_range)) elif getattr(leaf, 'summary', None): summary_leaves.append(leaf) @@ -4095,27 +4117,62 @@ def _filter_tables_by_page_range( and page_start <= t["page_number"] <= page_end ] + @staticmethod + def _score_table_relevance( + markdown: str, query_tokens: frozenset, + ) -> float: + """Score a table's relevance to the query via token overlap. 
+ + Returns a value in [0, 1] representing the fraction of *query_tokens* + found in the table's markdown text (case-insensitive). + """ + if not markdown or not query_tokens: + return 0.0 + md_lower = markdown.lower() + hits = sum(1 for tok in query_tokens if tok in md_lower) + return hits / len(query_tokens) + @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 3000, + max_chars: int = 6000, + query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. + When *query* is provided, tables are **sorted by relevance** to the + query before budget truncation, ensuring critical tables are included + even when they appear late in page order. + Strategy: - - Small tables (<1000 chars): preserve full Markdown - - Large tables: truncate to max_chars with "(truncated)" note + - Query-relevant tables are prioritised via keyword overlap scoring - Each table prefixed with "[Table from page N]" + - Large tables truncated with "(truncated)" note Returns concatenated formatted table evidence string. """ if not tables: return "" + ordered = tables + if query: + query_tokens = frozenset( + tok for tok in query.lower().split() if len(tok) > 2 + ) + if query_tokens: + scored = [ + (AgenticSearch._score_table_relevance( + t.get("markdown", ""), query_tokens, + ), idx, t) + for idx, t in enumerate(tables) + ] + scored.sort(key=lambda x: (-x[0], x[1])) + ordered = [t for _, _, t in scored] + parts: List[str] = [] remaining = max_chars - for table in tables: + for table in ordered: if remaining <= 0: break @@ -4155,7 +4212,7 @@ def _append_evidence_part( parts.append(f"{header}\n{text}") async def _navigate_tree_for_evidence( - self, file_path: str, query: str, *, max_results: int = 3, + self, file_path: str, query: str, *, max_results: int = 5, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections and read leaf content. 
@@ -4297,7 +4354,8 @@ async def _navigate_tree_for_evidence( ) if leaf_tables: table_text = self._format_table_evidence( - leaf_tables, max_chars=2000, + leaf_tables, max_chars=4000, + query=query, ) if table_text: parts.append( From bdd8bdc666628de6120075bcadfffd3d4f37b9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 23:05:03 +0800 Subject: [PATCH 41/70] fix review --- src/sirchmunk/learnings/compiler.py | 13 +++ src/sirchmunk/search.py | 124 ++++++++++++++++++++++------ 2 files changed, 114 insertions(+), 23 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 62f3e19..ef901aa 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -635,6 +635,19 @@ async def _compile_single_file( except Exception: pass + # Cache compile-time ENHANCED content so search can slice + # char_range from the same text the tree was built from. + try: + file_hash_content = get_fast_hash(entry.path) or "" + if file_hash_content and content: + content_dir = self._compile_dir / "content" + content_dir.mkdir(parents=True, exist_ok=True) + (content_dir / f"{file_hash_content}.txt").write_text( + content, encoding="utf-8", + ) + except Exception: + pass + # Persist table digest for documents with extracted tables if extraction.tables: try: diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 921c383..c38dfee 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -3930,14 +3930,16 @@ async def _tree_guided_sample( leaf_segments.append((leaf, seg)) # If page extraction failed, demoted leaves are now in char_leaves - # -- Phase C: char_range fallback (lazy full-text extraction) -- + # -- Phase C: char_range extraction (compile-consistent content) -- if char_leaves: - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: 
- full_text = "" + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" for leaf in char_leaves: start, end = leaf.char_range @@ -4084,6 +4086,31 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _load_compile_content( + work_path: Path, file_path: str, + ) -> Optional[str]: + """Load the ENHANCED content cached at compile time. + + Compile stores the kreuzberg ENHANCED-profile content alongside the + tree index so that search-time ``char_range`` slicing operates on + the *same* text the ranges were computed from. Returns ``None`` + when the cache file is missing (e.g. pre-cache compile run). + """ + try: + from sirchmunk.utils.file_utils import get_fast_hash + file_hash = get_fast_hash(file_path) + if not file_hash: + return None + cache_path = ( + work_path / ".cache" / "compile" / "content" / f"{file_hash}.txt" + ) + if cache_path.exists(): + return cache_path.read_text(encoding="utf-8") + except Exception: + pass + return None + @staticmethod def _load_table_digest( work_path: Path, file_hash: str, @@ -4221,8 +4248,8 @@ async def _navigate_tree_for_evidence( Returns None when no tree cache is available. Extraction priority (highest first): - 1. page_range – page-level extraction via DocumentExtractor - 2. char_range – full-text extraction + slice (fallback) + 1. char_range – compile-time ENHANCED content slice (preserves tables) + 2. page_range – page-level extraction via DocumentExtractor (fallback) 3. 
leaf.summary – last resort """ indexer = self._get_tree_indexer() @@ -4299,14 +4326,22 @@ async def _navigate_tree_for_evidence( else: print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=True", flush=True) - # ── Phase 3: char_range fallback (lazy full-text extraction) ── + # ── Phase 3: char_range extraction (compile-consistent content) ── if char_leaves: - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" + # Prefer compile-time ENHANCED content (matches char_range offsets + # exactly). Fall back to fast_extract only when cache is absent. + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + # Leaves whose char_range is invalid but have a valid page_range + # are demoted to page extraction instead of discarding to summary. + page_fallback_leaves: List[tuple] = [] for leaf in char_leaves: start, end = leaf.char_range @@ -4320,14 +4355,57 @@ async def _navigate_tree_for_evidence( self._append_evidence_part( parts, fname, leaf, leaf.summary, ) - elif getattr(leaf, 'summary', None): - _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) - self._append_evidence_part( - parts, fname, leaf, leaf.summary, + else: + # char_range covers too much of the document (or text is + # empty). Try page_range extraction before falling back + # to summary. 
+ pr = getattr(leaf, 'page_range', None) + if ( + pr + and len(pr) == 2 + and pr[0] is not None + and pr[0] > 0 + ): + page_fallback_leaves.append((leaf, pr)) + elif getattr(leaf, 'summary', None): + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), " + f"using summary" + ) + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # Batch page extraction for demoted leaves (same pattern as Phase 2) + if page_fallback_leaves: + all_fb_pages: set = set() + for _lf, (sp, ep) in page_fallback_leaves: + all_fb_pages.update(range(sp, ep + 1)) + try: + fb_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_fb_pages), ) + fb_map = {pc.page_number: pc.content for pc in fb_contents} + for lf, (sp, ep) in page_fallback_leaves: + seg_parts = [ + fb_map[p] for p in range(sp, ep + 1) + if fb_map.get(p, "").strip() + ] + if seg_parts: + self._append_evidence_part( + parts, fname, lf, "\n".join(seg_parts), + ) + elif getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) + except Exception: + for lf, _ in page_fallback_leaves: + if getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) if not parts: return None From d3b91d679e17ae2e697ae5913cec702eaa41ce52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 02:09:00 +0800 Subject: [PATCH 42/70] improve kreuzberg table extraction --- src/sirchmunk/learnings/compiler.py | 295 ++++++++++++++++++---- src/sirchmunk/utils/document_extractor.py | 20 +- 2 files changed, 265 insertions(+), 50 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ef901aa..ddd509b 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -63,6 +63,9 @@ # this threshold are skipped during targeted extraction. 
_TABLE_NUMERIC_DENSITY_THRESHOLD = 0.15 +# Selective force-OCR: max pages to re-extract with forced OCR per document +_FORCE_OCR_MAX_PAGES = 30 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -676,50 +679,43 @@ async def _compile_single_file( content=content, total_pages=extraction.page_count, ) - # Phase 2.5: Targeted table extraction via generic structural signals + # Phase 2.5: Targeted table extraction via tree-node structural signals if result.tree and result.tree.root and ext == ".pdf": targeted_tables = await self._targeted_table_extraction( entry.path, result.tree, ) - if targeted_tables: - # Load existing table digest (if any) and merge - digest_dir = self._compile_dir / "table_digests" - file_hash = get_fast_hash(entry.path) or "" - existing_digest: list[dict] = [] - if file_hash and result.has_table_digest: - digest_path = digest_dir / f"{file_hash}.json" - if digest_path.exists(): - try: - raw = json.loads( - digest_path.read_text(encoding="utf-8") - ) - existing_digest = raw.get("tables", []) - except Exception: - pass - merged = self._merge_table_digests( - existing_digest, targeted_tables, + await self._supplement_table_digest( + entry.path, targeted_tables, result, + source_label="Targeted extraction", + ) + + # Phase 2.6: Content-based full-page table scan (tree-independent) + if ext == ".pdf" and extraction.page_count: + covered_pages = self._get_covered_table_pages(entry.path) + content_tables = await self._content_based_table_scan( + entry.path, + extraction.page_count, + covered_pages, + ) + await self._supplement_table_digest( + entry.path, content_tables, result, + source_label="Content-based scan", + ) + + # Phase 2.7: Selective force-OCR for high-density gap pages + if ext == ".pdf" and extraction.page_count: + covered_after_scan = self._get_covered_table_pages(entry.path) + gap_pages = 
self._find_force_ocr_candidates( + entry.path, extraction.page_count, covered_after_scan, + ) + if gap_pages: + ocr_tables = await self._selective_force_ocr_tables( + entry.path, gap_pages, + ) + await self._supplement_table_digest( + entry.path, ocr_tables, result, + source_label="Selective force-OCR", ) - if merged and file_hash: - digest_dir.mkdir(parents=True, exist_ok=True) - digest_path = digest_dir / f"{file_hash}.json" - digest_path.write_text( - json.dumps( - { - "version": 1, - "table_count": len(merged), - "tables": merged, - }, - ensure_ascii=False, - ), - encoding="utf-8", - ) - result.has_table_digest = True - result.table_count = len(merged) - await self._log.info( - f"[Compile] Targeted table extraction added " - f"{len(targeted_tables)} tables for " - f"{Path(entry.path).name}" - ) except Exception as exc: result.error = str(exc) @@ -1707,17 +1703,226 @@ def _merge_table_digests( page = cls._get_table_page(tbl) if page is not None and page in existing_pages: continue - # Normalise to digest table format for consistency merged.append({ "page_number": page, - "markdown": tbl.get("content", ""), - "row_count": None, - "col_count": None, - "cells": [], - "source": tbl.get("source", "targeted"), + "markdown": tbl.get("markdown", "") or tbl.get("content", ""), + "row_count": tbl.get("row_count"), + "col_count": tbl.get("col_count"), + "cells": tbl.get("cells", []), + "source": tbl.get("source", "supplementary"), }) return merged + async def _supplement_table_digest( + self, + file_path: str, + new_tables: list[dict], + result: "FileCompileResult", + *, + source_label: str, + ) -> None: + """Merge supplementary tables into the persisted table digest. + + Loads the existing digest (if any), merges *new_tables* with + page-level deduplication, and writes the updated digest back. + Updates *result* metadata in place. 
+ """ + if not new_tables: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_dir = self._compile_dir / "table_digests" + digest_path = digest_dir / f"{file_hash}.json" + + existing: list[dict] = [] + if result.has_table_digest and digest_path.exists(): + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + existing = raw.get("tables", []) + except Exception: + pass + + merged = self._merge_table_digests(existing, new_tables) + if not merged: + return + + digest_dir.mkdir(parents=True, exist_ok=True) + digest_path.write_text( + json.dumps( + {"version": 1, "table_count": len(merged), "tables": merged}, + ensure_ascii=False, + ), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(merged) + await self._log.info( + f"[Compile] {source_label}: +{len(new_tables)} tables for " + f"{Path(file_path).name} (total={len(merged)})" + ) + + def _get_covered_table_pages(self, file_path: str) -> Set[int]: + """Return the set of page numbers already present in the table digest.""" + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return set() + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return set() + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + pages: Set[int] = set() + for t in raw.get("tables", []): + p = self._get_table_page(t) + if p is not None: + pages.add(p) + return pages + except Exception: + return set() + + # ------------------------------------------------------------------ # + # Tree-independent content-based table scanning (P1) # + # ------------------------------------------------------------------ # + + async def _content_based_table_scan( + self, + file_path: str, + total_pages: Optional[int], + kreuzberg_table_pages: Set[int], + ) -> list[dict]: + """Scan *all* PDF pages for table-like regions via numeric density. 
+ + Unlike :meth:`_targeted_table_extraction` this method does **not** + depend on tree node metadata (``page_range``, ``table_count``). + It reads every page through pypdf and applies the same density + + region-detection heuristics, skipping pages that already have a + kreuzberg-detected table. + + Args: + file_path: Path to the PDF file. + total_pages: Total page count (from extraction metadata). + kreuzberg_table_pages: Page numbers already covered by kreuzberg + layout-detected tables. + + Returns: + List of table dicts compatible with the digest format:: + + {"page": int, "content": str, "source": "content_scan"} + """ + if not total_pages or total_pages <= 0: + return [] + + all_page_nums = list(range(1, total_pages + 1)) + try: + pages = DocumentExtractor.extract_pages(file_path, all_page_nums) + except Exception as exc: + await self._log.warning( + f"[Compile] Content-based scan: page read failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + results: list[dict] = [] + for pc in pages: + if pc.page_number in kreuzberg_table_pages: + continue + if not self._page_has_table_density(pc.content): + continue + regions = self._identify_table_regions(pc.content) + for region in regions: + results.append({ + "page": pc.page_number, + "content": region[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan", + }) + return results + + def _find_force_ocr_candidates( + self, + file_path: str, + total_pages: Optional[int], + covered_pages: Set[int], + ) -> List[int]: + """Identify pages worth re-extracting with forced OCR. + + Returns 0-indexed page numbers for pages that have high numeric + density (suggesting tabular content) but are NOT already covered + by any table in the digest. The result is capped at + :data:`_FORCE_OCR_MAX_PAGES`. 
+ """ + if not total_pages or total_pages <= 0: + return [] + + all_page_nums = list(range(1, total_pages + 1)) + try: + pages = DocumentExtractor.extract_pages(file_path, all_page_nums) + except Exception: + return [] + + candidates: List[int] = [] + for pc in pages: + if pc.page_number in covered_pages: + continue + if self._page_has_table_density(pc.content): + candidates.append(pc.page_number - 1) # 0-indexed for kreuzberg + + return sorted(candidates)[:_FORCE_OCR_MAX_PAGES] + + # ------------------------------------------------------------------ # + # Selective force-OCR re-extraction (P2) # + # ------------------------------------------------------------------ # + + async def _selective_force_ocr_tables( + self, + file_path: str, + gap_pages: List[int], + ) -> list[dict[str, Any]]: + """Re-extract specific pages with forced OCR + layout detection. + + For pages where the native text layer was not recognized as tables + by kreuzberg's RT-DETR model, re-rendering as images may yield + better layout detection results. Uses ``force_ocr_pages`` so only + the targeted pages are OCR'd (no full-document penalty). + + Args: + file_path: Path to the PDF. + gap_pages: 0-indexed page numbers to force OCR on. Capped at + :data:`_FORCE_OCR_MAX_PAGES` to bound compile time. + + Returns: + List of kreuzberg-format table dicts (with ``markdown``, + ``cells``, ``page_number``). 
+ """ + from sirchmunk.utils.document_extractor import ExtractionProfile + + if not gap_pages: + return [] + + capped = sorted(gap_pages)[:_FORCE_OCR_MAX_PAGES] + + profile = ExtractionProfile( + output_format="markdown", + extract_tables=True, + force_ocr_pages=tuple(capped), + ) + try: + extraction = await DocumentExtractor.extract(file_path, profile) + except Exception as exc: + await self._log.warning( + f"[Compile] Selective force-OCR failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + return extraction.tables + # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index a022a8d..b2835f5 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -72,6 +72,14 @@ class ExtractionProfile: when set, OCR is always applied regardless of text layer presence. """ + force_ocr_pages: Optional[tuple[int, ...]] = None + """Force OCR on specific pages only (0-indexed). + + Maps to kreuzberg's ``ExtractionConfig.force_ocr_pages``. + Mutually exclusive with :attr:`force_ocr` — when both are set, + ``force_ocr`` takes precedence. + """ + pdf_password: Optional[str] = None """Password for encrypted PDFs.""" @@ -459,12 +467,12 @@ def _build_config(profile: ExtractionProfile): if profile.extract_tables: try: from kreuzberg import LayoutDetectionConfig - # kreuzberg >= 4.5.0: model-based table detection (RT-DETR v2) - # Default: table_model="tatr", apply_heuristics=True - layout_config = LayoutDetectionConfig() + layout_config = LayoutDetectionConfig( + confidence_threshold=0.3, + apply_heuristics=True, + table_model="slanet_auto", + ) except ImportError: - # kreuzberg < 4.5.0: tables extracted via heuristics only; - # filtering is handled in _convert_result(). 
pass # --- Assemble ExtractionConfig --- @@ -475,6 +483,8 @@ def _build_config(profile: ExtractionProfile): kwargs["ocr"] = ocr_config if profile.force_ocr: kwargs["force_ocr"] = True + elif profile.force_ocr_pages: + kwargs["force_ocr_pages"] = list(profile.force_ocr_pages) if page_config is not None: kwargs["pages"] = page_config if pdf_config is not None: From 63ed04795a704da5e6214b652871d0882df776ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 02:49:02 +0800 Subject: [PATCH 43/70] enhance compiler table extraction --- src/sirchmunk/learnings/compiler.py | 252 ++++++++++++++++++++++------ 1 file changed, 198 insertions(+), 54 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ddd509b..3316e54 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -8,6 +8,7 @@ """ import asyncio +import bisect import json import math import os @@ -66,6 +67,23 @@ # Selective force-OCR: max pages to re-extract with forced OCR per document _FORCE_OCR_MAX_PAGES = 30 +# Shared numeric-token regex for table detection heuristics. +# Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 +_NUM_TOKEN_RE = re.compile( + r"(?:" + r"[\$€£¥]\s*[\d,.]+|" + r"\([\d,.]+\)|" + r"[\d,.]+%|" + r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" + r"[\d,]{2,}" + r")" +) + +# A single line with >= this many numeric tokens is treated as a dense +# table row (or multiple rows concatenated), enabling detection even when +# pypdf flattens the entire page to one or two lines. 
+_DENSE_LINE_MIN_TOKENS = 15 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -692,10 +710,16 @@ async def _compile_single_file( # Phase 2.6: Content-based full-page table scan (tree-independent) if ext == ".pdf" and extraction.page_count: covered_pages = self._get_covered_table_pages(entry.path) + tree_root = ( + result.tree.root + if result.tree and result.tree.root else None + ) content_tables = await self._content_based_table_scan( entry.path, extraction.page_count, covered_pages, + enhanced_content=content, + tree_root=tree_root, ) await self._supplement_table_digest( entry.path, content_tables, result, @@ -1595,10 +1619,15 @@ def _walk(node: "TreeNode") -> None: @staticmethod def _page_has_table_density(page_text: str) -> bool: - """Return True if *page_text* has numeric density above the threshold. + """Return True if *page_text* likely contains tabular numeric data. - Counts digits and common table symbols (``$``, ``%``, ``(``, ``)``) - relative to total non-whitespace characters. + Two independent signals (either suffices): + + 1. **Character-level density** — fraction of digit/symbol chars + relative to total non-whitespace exceeds the threshold. + 2. **Token-dense line** — any single line contains + ``_DENSE_LINE_MIN_TOKENS`` or more numeric tokens, which + catches pages where pypdf flattens all content into ≤ 2 lines. 
""" if not page_text: return False @@ -1609,17 +1638,26 @@ def _page_has_table_density(page_text: str) -> bool: 1 for ch in page_text if ch.isdigit() or ch in "$%(),.+-" ) - return (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD + if (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD: + return True + return any( + len(_NUM_TOKEN_RE.findall(line)) >= _DENSE_LINE_MIN_TOKENS + for line in page_text.split("\n") + ) @staticmethod def _identify_table_regions(page_text: str) -> list[str]: """Identify contiguous table-like regions in *page_text*. - Heuristic rules: - - Lines containing multiple numeric tokens (dollar amounts, %, - parenthesised negatives) are considered *numeric rows*. - - A run of >= 3 consecutive numeric rows forms a table region. - - Leading/trailing whitespace rows are trimmed. + Two complementary strategies: + + 1. **Consecutive-line detection** — a run of ≥ 3 lines each + containing ≥ 2 numeric tokens forms a table region. Works + well when pypdf preserves per-row line breaks. + 2. **Dense-line detection** — a *single* line with ≥ + ``_DENSE_LINE_MIN_TOKENS`` numeric tokens is treated as a + table region. This handles PDFs where pypdf collapses + the entire page into one or two very long lines. Returns: List of extracted region strings (may be empty). 
@@ -1627,52 +1665,44 @@ def _identify_table_regions(page_text: str) -> list[str]: if not page_text: return [] - # Pattern: line has at least 2 numeric-looking tokens - _NUM_TOKEN = re.compile( - r"(?:" - r"[\$€£¥]\s*[\d,.]+|" - r"\([\d,.]+\)|" - r"[\d,.]+%|" - r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" - r"[\d,]{2,}" - r")" - ) _MIN_NUMS_PER_LINE = 2 _MIN_CONSECUTIVE = 3 lines = page_text.split("\n") - is_numeric = [ - len(_NUM_TOKEN.findall(line)) >= _MIN_NUMS_PER_LINE - for line in lines + token_counts = [ + len(_NUM_TOKEN_RE.findall(line)) for line in lines ] regions: list[str] = [] - run_start: int | None = None + captured_lines: set[int] = set() - for i, flag in enumerate(is_numeric): - if flag: + # --- Strategy 1: consecutive-line runs --- + run_start: int | None = None + for i, cnt in enumerate(token_counts): + if cnt >= _MIN_NUMS_PER_LINE: if run_start is None: run_start = i else: if run_start is not None: - run_len = i - run_start - if run_len >= _MIN_CONSECUTIVE: - # Include one context line above/below + if i - run_start >= _MIN_CONSECUTIVE: start = max(0, run_start - 1) end = min(len(lines), i + 1) regions.append( "\n".join(lines[start:end]).strip() ) + captured_lines.update(range(start, end)) run_start = None - - # Flush trailing run - if run_start is not None: - run_len = len(lines) - run_start - if run_len >= _MIN_CONSECUTIVE: - start = max(0, run_start - 1) - regions.append( - "\n".join(lines[start:]).strip() - ) + if run_start is not None and len(lines) - run_start >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + regions.append("\n".join(lines[start:]).strip()) + captured_lines.update(range(start, len(lines))) + + # --- Strategy 2: dense-line detection --- + for i, cnt in enumerate(token_counts): + if cnt >= _DENSE_LINE_MIN_TOKENS and i not in captured_lines: + start = max(0, i - 1) + end = min(len(lines), i + 2) + regions.append("\n".join(lines[start:end]).strip()) return regions @@ -1795,30 +1825,54 @@ async def _content_based_table_scan( self, 
file_path: str, total_pages: Optional[int], - kreuzberg_table_pages: Set[int], + covered_pages: Set[int], + *, + enhanced_content: Optional[str] = None, + tree_root: Optional[Any] = None, ) -> list[dict]: - """Scan *all* PDF pages for table-like regions via numeric density. + """Scan PDF pages for table-like regions via numeric density. - Unlike :meth:`_targeted_table_extraction` this method does **not** - depend on tree node metadata (``page_range``, ``table_count``). - It reads every page through pypdf and applies the same density + - region-detection heuristics, skipping pages that already have a - kreuzberg-detected table. + Uses a two-tier strategy: + + 1. **pypdf page scan** — reads every page individually. Works well + when pypdf preserves per-row line breaks. + 2. **ENHANCED content fallback** — if pypdf yields poor line + structure (> 50 % of pages have ≤ 3 lines), falls back to + scanning the kreuzberg ENHANCED markdown content, which often + has better formatting. Page numbers are recovered via the + tree's ``char_range → page_range`` mapping. Args: - file_path: Path to the PDF file. - total_pages: Total page count (from extraction metadata). - kreuzberg_table_pages: Page numbers already covered by kreuzberg - layout-detected tables. + file_path: Path to the PDF file. + total_pages: Total page count. + covered_pages: Page numbers already in the table digest. + enhanced_content: Cached kreuzberg ENHANCED text (optional). + tree_root: Tree root node for char → page mapping (optional). Returns: - List of table dicts compatible with the digest format:: - - {"page": int, "content": str, "source": "content_scan"} + List of table dicts compatible with the digest format. 
""" if not total_pages or total_pages <= 0: return [] + results = await self._pypdf_page_scan( + file_path, total_pages, covered_pages, + ) + + if results or not enhanced_content or not tree_root: + return results + + return self._enhanced_content_scan( + enhanced_content, total_pages, covered_pages, tree_root, + ) + + async def _pypdf_page_scan( + self, + file_path: str, + total_pages: int, + covered_pages: Set[int], + ) -> list[dict]: + """Primary scan: per-page pypdf extraction with density heuristics.""" all_page_nums = list(range(1, total_pages + 1)) try: pages = DocumentExtractor.extract_pages(file_path, all_page_nums) @@ -1830,20 +1884,110 @@ async def _content_based_table_scan( return [] results: list[dict] = [] + poor_line_count = 0 for pc in pages: - if pc.page_number in kreuzberg_table_pages: + if len(pc.content.split("\n")) <= 3: + poor_line_count += 1 + if pc.page_number in covered_pages: continue if not self._page_has_table_density(pc.content): continue - regions = self._identify_table_regions(pc.content) - for region in regions: + for region in self._identify_table_regions(pc.content): results.append({ "page": pc.page_number, "content": region[:_TARGETED_TABLE_MAX_CHARS], "source": "content_scan", }) + + if results: + return results + + # Signal that pypdf line quality is poor — caller should try fallback + if poor_line_count > total_pages * 0.5: + return [] + + return results + + @staticmethod + def _enhanced_content_scan( + enhanced_content: str, + total_pages: int, + covered_pages: Set[int], + tree_root: Any, + ) -> list[dict]: + """Fallback scan: use ENHANCED (kreuzberg markdown) content. + + Scans the full ENHANCED text line-by-line for dense-token lines, + then maps each detected region back to a page number using the + tree's ``char_range → page_range`` mapping. 
+ """ + char_page_map = KnowledgeCompiler._build_char_to_page_map( + tree_root, total_pages, + ) + if not char_page_map: + return [] + + breakpoints = [cp[0] for cp in char_page_map] + + results: list[dict] = [] + offset = 0 + for line in enhanced_content.split("\n"): + token_count = len(_NUM_TOKEN_RE.findall(line)) + if token_count >= _DENSE_LINE_MIN_TOKENS: + idx = bisect.bisect_right(breakpoints, offset) - 1 + page = char_page_map[max(0, idx)][1] if idx >= 0 else 1 + if page not in covered_pages: + results.append({ + "page": page, + "content": line[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan:enhanced", + }) + covered_pages.add(page) + offset += len(line) + 1 # +1 for '\n' + return results + @staticmethod + def _build_char_to_page_map( + tree_root: Any, + total_pages: int, + ) -> list[tuple[int, int]]: + """Build a sorted (char_start, page_number) list from tree leaves. + + Enables efficient binary-search lookup from any character offset + in the ENHANCED content to the corresponding page number. 
+ """ + entries: list[tuple[int, int]] = [] + + def _collect(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if not children and cr and pr: + page = pr[0] if isinstance(pr, (list, tuple)) else pr + char_start = cr[0] if isinstance(cr, (list, tuple)) else cr + if page and char_start is not None: + entries.append((int(char_start), int(page))) + for ch in children: + _collect(ch) + + _collect(tree_root) + + if not entries: + return [(0, 1)] + entries.sort() + return entries + def _find_force_ocr_candidates( self, file_path: str, From b760119624bc22c7d4eb6e3b03cff588002e12b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 14:37:55 +0800 Subject: [PATCH 44/70] fix table extraction --- src/sirchmunk/learnings/compiler.py | 379 +++++++++++++++++++++++++++- src/sirchmunk/llm/prompts.py | 18 ++ src/sirchmunk/search.py | 39 ++- 3 files changed, 432 insertions(+), 4 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 3316e54..8b08357 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -84,6 +84,28 @@ # pypdf flattens the entire page to one or two lines. _DENSE_LINE_MIN_TOKENS = 15 +# --------------------------------------------------------------------------- +# Heading normalisation: candidate extraction patterns +# --------------------------------------------------------------------------- +# kreuzberg sometimes renders section titles as ``**bold text**`` or bare +# short standalone lines instead of ``## heading``. The tree indexer can +# only split on markdown headings, so these "invisible" titles get absorbed +# into parent nodes. 
+# +# We extract *candidates* via lightweight regexes and let the LLM classify +# which ones are genuine section headings (language/domain-agnostic). + +_BOLD_LINE_RE = re.compile( + r"^\*\*((?:(?!\*\*).)+)\*\*\s*$", + re.MULTILINE, +) + +_STANDALONE_LINE_RE = re.compile( + r"(?:^|\n\n)([^\n]{5,100})\n\n", +) + +_HEADING_CANDIDATE_CAP = 40 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -590,6 +612,7 @@ async def _compile_single_file( entry.path, DocumentExtractor.ENHANCED, ) content = extraction.content + content = await self._normalize_bold_headings(content) if not content or len(content.strip()) < 100: result.error = "Insufficient text content" return result @@ -741,6 +764,12 @@ async def _compile_single_file( source_label="Selective force-OCR", ) + # Phase 2.8: Enrich targeted-extraction tables with ENHANCED content + if ext == ".pdf" and result.has_table_digest: + self._enrich_table_digest_content( + entry.path, content, tree_root=None, + ) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1617,6 +1646,177 @@ def _walk(node: "TreeNode") -> None: _walk(root) return candidates + # ------------------------------------------------------------------ # + # LLM-based heading normalisation # + # ------------------------------------------------------------------ # + + @staticmethod + def _extract_heading_candidates( + content: str, + ) -> list[tuple[re.Match, str, str]]: + """Extract candidate lines that *might* be section headings. + + Returns a list of ``(match, title_text, source_tag)`` triples + where *source_tag* is ``"bold"`` or ``"standalone"``. + + Bold lines (``**Title**``) are always candidates. 
Short + standalone lines (surrounded by blank lines, 10-100 chars) are + included only when they pass structural heuristics that filter + out data rows, sentences, and existing headings. + """ + occupied: list[tuple[int, int]] = [] + candidates: list[tuple[re.Match, str, str]] = [] + + def _overlaps(start: int, end: int) -> bool: + return any(s < end and start < e for s, e in occupied) + + for m in _BOLD_LINE_RE.finditer(content): + title = m.group(1).strip() + if title and not _overlaps(m.start(), m.end()): + occupied.append((m.start(), m.end())) + candidates.append((m, title, "bold")) + + for m in _STANDALONE_LINE_RE.finditer(content): + text = m.group(1).strip() + if len(text) < 10: + continue + text_offset = m.start() + m.group(0).index(m.group(1)) + if _overlaps(text_offset, text_offset + len(m.group(1))): + continue + if text.startswith(("#", "**")): + continue + if _NUM_TOKEN_RE.search(text): + continue + if text.endswith((".", "。", "!", "?", "!", "?")): + continue + if len(text.split()) > 12: + continue + occupied.append((text_offset, text_offset + len(m.group(1)))) + candidates.append((m, text, "standalone")) + + candidates.sort(key=lambda t: t[0].start()) + return candidates[:_HEADING_CANDIDATE_CAP] + + async def _normalize_bold_headings(self, content: str) -> str: + """Detect and promote bold/standalone section titles to headings. + + Three-phase pipeline: + 1. **Extract** candidate lines via regex (deterministic). + 2. **Classify** candidates with a single LLM call — the LLM + returns which indices are section headings and their level. + 3. **Replace** confirmed headings deterministically. + + Short-circuits when no candidates are found (zero LLM calls). + On any LLM / parse failure, returns the original content unchanged + (graceful degradation — equivalent to no-op). + + The transformation is idempotent: existing ``#`` headings never + enter the candidate set. 
+ """ + if not content: + return content + + candidates = self._extract_heading_candidates(content) + if not candidates: + return content + + listing = "\n".join( + f"{i}: \"{title}\"" for i, (_, title, _tag) in enumerate(candidates) + ) + + from sirchmunk.llm.prompts import COMPILE_CLASSIFY_HEADINGS + prompt = COMPILE_CLASSIFY_HEADINGS.format(candidates=listing) + + try: + resp = await self._llm.achat( + [{"role": "user", "content": prompt}], + ) + raw = resp.content.strip() + headings = self._parse_heading_classifications(raw, len(candidates)) + except Exception: + return content + + if not headings: + return content + + return self._apply_heading_promotions(content, candidates, headings) + + @staticmethod + def _parse_heading_classifications( + raw: str, + num_candidates: int, + ) -> list[tuple[int, int]]: + """Parse LLM JSON response into a list of ``(idx, level)`` pairs. + + Robustly handles markdown code fences, trailing commas, and + out-of-range indices. Returns an empty list on any parse failure. 
+ """ + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.splitlines() + lines = [ln for ln in lines if not ln.strip().startswith("```")] + cleaned = "\n".join(lines).strip() + + try: + items = json.loads(cleaned) + except json.JSONDecodeError: + m = re.search(r"\[.*\]", cleaned, re.DOTALL) + if not m: + return [] + try: + items = json.loads(m.group()) + except json.JSONDecodeError: + return [] + + if not isinstance(items, list): + return [] + + result: list[tuple[int, int]] = [] + for item in items: + if isinstance(item, dict): + idx = item.get("idx") + level = item.get("level", 2) + elif isinstance(item, int): + idx, level = item, 2 + else: + continue + if not isinstance(idx, int) or not (0 <= idx < num_candidates): + continue + level = max(2, min(4, int(level))) + result.append((idx, level)) + return result + + @staticmethod + def _apply_heading_promotions( + content: str, + candidates: list[tuple[re.Match, str, str]], + headings: list[tuple[int, int]], + ) -> str: + """Apply heading promotions to *content* in reverse-offset order. + + Processes replacements from end-to-start so that earlier offsets + remain valid after each substitution. 
+ """ + heading_map: dict[int, int] = dict(headings) + + replacements: list[tuple[int, int, str]] = [] + for idx, (match, title, tag) in enumerate(candidates): + if idx not in heading_map: + continue + level = heading_map[idx] + prefix = "#" * level + if tag == "bold": + replacements.append((match.start(), match.end(), f"{prefix} {title}")) + else: + text_start = match.start() + match.group(0).index(match.group(1)) + text_end = text_start + len(match.group(1)) + replacements.append((text_start, text_end, f"{prefix} {title}")) + + replacements.sort(key=lambda r: r[0], reverse=True) + for start, end, replacement in replacements: + content = content[:start] + replacement + content[end:] + return content + @staticmethod def _page_has_table_density(page_text: str) -> bool: """Return True if *page_text* likely contains tabular numeric data. @@ -1818,7 +2018,184 @@ def _get_covered_table_pages(self, file_path: str) -> Set[int]: return set() # ------------------------------------------------------------------ # - # Tree-independent content-based table scanning (P1) # + # P1: Enrich table digest with ENHANCED content # + # ------------------------------------------------------------------ # + + @staticmethod + def _build_page_char_map( + tree_root: Any, + max_page_span: int = _TABLE_PAGE_SPAN_LIMIT, + ) -> Dict[int, Tuple[int, int]]: + """Map page numbers to ``(start_char, end_char)`` in ENHANCED content. + + Aggregates ``char_range`` bounds from leaf nodes whose + ``page_range`` intersects a given page. To avoid inflated + ranges from wide-spanning nodes (e.g. a cover-page node + spanning pages 1–85), only nodes with a page span ≤ + *max_page_span* are used when available; wider nodes serve + as a fallback. 
+ """ + # (char_start, char_end, page_span) per page + entries: Dict[int, List[Tuple[int, int, int]]] = {} + + def _walk(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + if not children: + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if ( + pr + and cr + and len(pr) >= 2 + and len(cr) >= 2 + ): + span = int(pr[1]) - int(pr[0]) + 1 + for p in range(int(pr[0]), int(pr[1]) + 1): + entries.setdefault(p, []).append( + (int(cr[0]), int(cr[1]), span) + ) + for ch in children: + _walk(ch) + + _walk(tree_root) + + result: Dict[int, Tuple[int, int]] = {} + for page, elist in entries.items(): + narrow = [e for e in elist if e[2] <= max_page_span] + chosen = narrow if narrow else elist + result[page] = ( + min(e[0] for e in chosen), + max(e[1] for e in chosen), + ) + return result + + @staticmethod + def _find_enhanced_region( + enhanced_content: str, + pypdf_text: str, + budget: int = _TARGETED_TABLE_MAX_CHARS, + ) -> Optional[str]: + """Locate the ENHANCED content region matching *pypdf_text*. + + Uses progressively shorter text anchors extracted from the + pypdf content to find the corresponding position in the + ENHANCED (kreuzberg markdown) text. Whitespace is normalised + in the anchor to handle formatting differences (pypdf line + breaks vs kreuzberg markdown spacing). This avoids reliance + on page-number alignment, which may differ between the two + extractors. + + Returns the ENHANCED slice (up to *budget* chars) or ``None``. 
+ """ + text = pypdf_text.strip() + for prefix in ("Table of Contents\n", "Table of Contents "): + if text.startswith(prefix): + text = text[len(prefix):] + text = text.strip() + + for anchor_len in (80, 50, 30): + raw = text[:anchor_len].strip() + if len(raw) < 15: + continue + anchor = " ".join(raw.split()) + pos = enhanced_content.find(anchor) + if pos < 0: + continue + start = max( + 0, + enhanced_content.rfind("\n", max(0, pos - 300), pos) + 1, + ) + end = min(len(enhanced_content), start + budget) + return enhanced_content[start:end].strip() + + return None + + def _enrich_table_digest_content( + self, + file_path: str, + enhanced_content: str, + tree_root: Optional[Any], + ) -> None: + """Replace pypdf-sourced table text with ENHANCED content slices. + + Targeted extraction tables use pypdf, which often produces dense + single-line text (the "2-line page" problem). This method + locates each table's content in the ENHANCED (kreuzberg markdown) + text via anchor matching and replaces the ``markdown`` field when + the ENHANCED version has substantially better structure. + + Only tables whose ``source`` indicates pypdf origin are + candidates; kreuzberg-detected tables already have high-quality + markdown and are left untouched. 
+ """ + if not enhanced_content: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + tables = raw.get("tables", []) + except Exception: + return + + if not tables: + return + + modified = False + for table in tables: + source = table.get("source", "") + if not ( + source.startswith("targeted:") + or source == "content_scan" + ): + continue + + current = table.get("markdown", "") + if not current: + continue + + enhanced_region = self._find_enhanced_region( + enhanced_content, current, + ) + if not enhanced_region: + continue + + current_lines = len(current.strip().split("\n")) + enhanced_lines = len(enhanced_region.split("\n")) + + if enhanced_lines > max(current_lines, 3): + table["markdown"] = enhanced_region[ + :_TARGETED_TABLE_MAX_CHARS + ] + modified = True + + if modified: + digest_path.write_text( + json.dumps(raw, ensure_ascii=False), + encoding="utf-8", + ) + + # ------------------------------------------------------------------ # + # Tree-independent content-based table scanning # # ------------------------------------------------------------------ # async def _content_based_table_scan( diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 074847e..71d5836 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -565,6 +565,24 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - Use the same language as the summary""" +COMPILE_CLASSIFY_HEADINGS = """Classify each bold text line as either a **section heading** or **non-heading**. + +A line is a *section heading* if it serves as the title of a major structural division of the document (chapter, section, subsection, exhibit, schedule, financial statement, note, etc.). 
+A line is *non-heading* if it is emphasis text, a label, a caption, a total/subtotal row, or any inline bold phrase that does not introduce a new document section. + +For each heading, also assign a Markdown heading level (2–4): +- Level 2: top-level sections (e.g. financial statements, major chapters) +- Level 3: sub-sections (e.g. notes to financial statements, sub-chapters) +- Level 4: sub-sub-sections + +Return ONLY a JSON array of objects for the lines that ARE headings. +Each object: {{"idx": <0-based index>, "level": <2|3|4>}} +If none are headings, return an empty array: [] + +Bold lines: +{candidates}""" + + COMPILE_MERGE_KNOWLEDGE = """You are merging new information into an existing knowledge cluster. ### Existing Knowledge diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index c38dfee..2c1c9b5 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -4144,19 +4144,52 @@ def _filter_tables_by_page_range( and page_start <= t["page_number"] <= page_end ] + _TABLE_RELEVANCE_MIN_PREFIX = 5 + @staticmethod def _score_table_relevance( markdown: str, query_tokens: frozenset, ) -> float: """Score a table's relevance to the query via token overlap. - Returns a value in [0, 1] representing the fraction of *query_tokens* - found in the table's markdown text (case-insensitive). + Uses two matching strategies per token: + + 1. **Exact substring** — fast check whether the token appears + anywhere in the table text (original behaviour). + 2. **Prefix match** — handles morphological variation such as + plural/singular (*inventory* ↔ *inventories*) by comparing + word prefixes of at least ``_TABLE_RELEVANCE_MIN_PREFIX`` + characters. Only attempted when the exact match misses. + + Returns a value in [0, 1] representing the fraction of + *query_tokens* matched. 
""" if not markdown or not query_tokens: return 0.0 + + min_pfx = AgenticSearch._TABLE_RELEVANCE_MIN_PREFIX md_lower = markdown.lower() - hits = sum(1 for tok in query_tokens if tok in md_lower) + md_words = None # lazily built on first prefix-match attempt + + hits = 0 + for tok in query_tokens: + if tok in md_lower: + hits += 1 + continue + # Prefix-match fallback + pfx_len = min(len(tok), min_pfx) + if pfx_len < 4: + continue + if md_words is None: + md_words = frozenset(md_lower.split()) + prefix = tok[:pfx_len] + if any( + w[:pfx_len] == prefix + for w in md_words + if len(w) >= pfx_len + ): + hits += 1 + return hits / len(query_tokens) @staticmethod From e55ada78b709420a36951d89e34e967b39cfbecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sat, 9 May 2026 15:23:51 +0800 Subject: [PATCH 45/70] improve compile for summary and table --- src/sirchmunk/learnings/compiler.py | 28 +++- src/sirchmunk/learnings/tree_indexer.py | 189 +++++++++++++++++++++++- src/sirchmunk/search.py | 110 +++++++++++++- 3 files changed, 317 insertions(+), 10 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 8b08357..25868d6 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -776,6 +776,24 @@ async def _compile_single_file( return result + @staticmethod + def _is_generic_summary(summary: str, min_specificity_len: int = 80) -> bool: + """Check whether a summary is too generic to be useful for retrieval. + + A generic summary typically contains only structural descriptions + (e.g., "This document contains several sections") without specific + content indicators. Detection uses summary length and information + density as domain-agnostic proxies. 
+ """ + if not summary: + return True + stripped = summary.strip() + if len(stripped) < min_specificity_len: + return True + # Count unique substantive words (>4 chars) as a proxy for specificity + words = set(w.lower() for w in stripped.split() if len(w) > 4) + return len(words) < 8 + async def _extract_summary( self, file_path: str, @@ -786,13 +804,19 @@ async def _extract_summary( When a tree is available its root already contains an LLM-synthesized summary (produced by ``_synthesize_root_summary`` during tree build), - so we reuse it directly — no redundant LLM call. + so we reuse it directly — unless the summary is too generic (Plan 2), + in which case we fall back to multi-section LLM summarization. For large documents without a tree, uses multi-section sampling (beginning, middle, end) to capture the full scope of the document. """ if tree and tree.root and tree.root.summary: - return tree.root.summary + if not self._is_generic_summary(tree.root.summary): + return tree.root.summary + await self._log.info( + f"[Compile] Root summary too generic for {Path(file_path).name}, " + f"falling back to LLM summarization" + ) preview = self._build_summary_preview(content) from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 9cf450e..10cab2b 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -46,6 +46,10 @@ _TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) _TREE_PREVIEW_RATIO = 0.15 # Fraction of document to preview +# Structured content detection thresholds (Plan 1: generic table recognition) +_STRUCT_MD_TABLE_MIN_ROWS = 3 # Min markdown table rows to classify as structured +_STRUCT_NUMERIC_DENSITY_THRESHOLD = 0.20 # Fraction of numeric tokens in a text segment + # Extensions eligible for tree indexing _TREE_EXTENSIONS = { ".pdf", ".docx", ".doc", ".md", ".markdown", @@ -445,6 +449,9 @@ async def 
_build_tree_from_toc( # Merge consecutive fragment entries into virtual parents toc_entries = self._merge_fragment_entries(toc_entries) + # Plan 4: Group disproportionately large tail entries (exhibits/appendices) + toc_entries = self._merge_supplementary_entries(toc_entries) + seen_ids: set = set() children = self._toc_entries_to_nodes( toc_entries, content, len(content), seen_ids, @@ -466,6 +473,74 @@ async def _build_tree_from_toc( children=children, ) + @staticmethod + def _merge_supplementary_entries(entries: List[Any]) -> List[Any]: + """Merge tail entries with disproportionately large spans into a virtual parent. + + Detects when the last few entries collectively span much more content + than the preceding entries — a generic structural signal for exhibits, + appendices, or attachment sections. Groups them under a single + navigable node to prevent them from dominating tree navigation. + + Uses only structural signals (char span ratios, position in document) + — no domain-specific keywords. Returns original entries when the + structural pattern is not detected or when too few entries remain. + """ + if len(entries) < 4: + return entries + + def _span(e: Any) -> int: + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + return max(0, e.char_end - e.char_start) + return 0 + + spans = [_span(e) for e in entries] + total_span = sum(spans) + if total_span == 0: + return entries + + # Scan backwards to find tail entries whose cumulative span is + # disproportionately large while individually being much larger + # than the body-section baseline. Uses 25th percentile instead of + # median so that many large tail entries cannot inflate the baseline. 
+ non_zero_spans = [s for s in spans if s > 0] + if len(non_zero_spans) < 4: + return entries + sorted_spans = sorted(non_zero_spans) + q25_idx = max(0, len(sorted_spans) // 4) + baseline_span = sorted_spans[q25_idx] + + tail_start = len(entries) + cumulative = 0 + for i in range(len(entries) - 1, 0, -1): + if spans[i] > baseline_span * 3: + cumulative += spans[i] + tail_start = i + else: + break + + tail_count = len(entries) - tail_start + # Require at least 2 tail entries spanning > 40% of total content + if tail_count < 2 or cumulative / total_span < 0.40: + return entries + + # Also ensure enough primary entries remain + if tail_start < 2: + return entries + + from copy import deepcopy + first_tail = entries[tail_start] + last_tail = entries[-1] + merged = deepcopy(first_tail) + merged.title = f"Supplementary Material ({tail_count} sections)" + if hasattr(last_tail, 'char_end') and last_tail.char_end: + merged.char_end = last_tail.char_end + merged.children = list(entries[tail_start:]) + + result = list(entries[:tail_start]) + [merged] + return result if len(result) >= 2 else entries + @staticmethod def _merge_fragment_entries(entries: List[Any]) -> List[Any]: """Merge consecutive fragment TOC entries into virtual parent nodes. @@ -591,10 +666,19 @@ def _toc_entries_to_nodes( total_pages=total_pages, ) + # Plan 1: Detect structured/tabular content and add navigation hint + # to help LLM-driven navigation prioritize data-rich sections. + # Deliberately keeps content_type="text" so _classify_leaves + # routes to kreuzberg char_range (higher fidelity than pypdf). 
+ summary_text = section_text.strip() + section_sample = content[start:min(start + 2000, end)] + if DocumentTreeIndexer._detect_structured_content(section_sample): + summary_text = f"[Data/Tables] {summary_text}" + node = TreeNode( node_id=nid, title=entry.title, - summary=section_text.strip(), + summary=summary_text, char_range=(start, end), level=level, page_range=page_range, @@ -636,6 +720,44 @@ def _compute_adaptive_depth(content_length: int) -> int: return depth return 2 # minimum depth + @staticmethod + def _detect_structured_content(text: str, sample_size: int = 2000) -> bool: + """Detect whether text contains structured/tabular data using generic signals. + + Uses two high-precision, domain-agnostic heuristics (any triggers True): + 1. Markdown table syntax (pipe-delimited rows with separator line) + 2. High numeric token density (currency, percentages, large numbers) + + Intentionally omits lower-precision signals (multi-space alignment, + tab counts) because PDF-extracted text frequently has irregular + spacing that causes false positives. + + Args: + text: Content segment to analyze. + sample_size: Max chars to analyze (avoids scanning huge sections). 
+ """ + sample = text[:sample_size] + if not sample.strip(): + return False + + # Signal 1: Markdown table syntax — pipe-separated rows with header separator + pipe_lines = [ln for ln in sample.split("\n") if ln.strip().startswith("|")] + separator_lines = [ln for ln in pipe_lines if re.match(r"\|\s*[-:]+", ln)] + data_rows = len(pipe_lines) - len(separator_lines) + if data_rows >= _STRUCT_MD_TABLE_MIN_ROWS and separator_lines: + return True + + # Signal 2: Numeric token density — high ratio of numeric-pattern tokens + non_ws = re.sub(r"\s+", "", sample) + if len(non_ws) > 50: + from sirchmunk.learnings.compiler import _NUM_TOKEN_RE + num_tokens = _NUM_TOKEN_RE.findall(sample) + total_chars = sum(len(t) for t in num_tokens) + if total_chars / len(non_ws) >= _STRUCT_NUMERIC_DENSITY_THRESHOLD: + return True + + return False + async def _build_node( self, text: str, level: int, max_depth: int, offset: int = 0, @@ -691,13 +813,74 @@ async def _build_node( children=children, ) + @staticmethod + def _collect_representative_nodes( + children: List[TreeNode], + max_nodes: int = 15, + ) -> List[TreeNode]: + """Collect representative nodes from multiple tree depths. + + Gathers direct children plus a sample of deeper descendants to + ensure the summary captures actual content topics — not just + top-level structural wrappers that may be uninformative. + + Strategy: + - Layer 1: all direct children (structural overview). + - Layer 2+: BFS preferring **leaf nodes** (actual content topics) + over intermediate nodes (whose summaries overlap children). + """ + reps: List[TreeNode] = [] + seen: set = set() + + # Layer 1: all direct children (even wrappers — they provide structure) + for c in children: + if c.node_id not in seen and len(reps) < max_nodes: + reps.append(c) + seen.add(c.node_id) + + # Layer 2+: BFS collecting leaf nodes with substantive summaries. 
+ # Leaf nodes represent actual content sections; intermediate nodes + # often have summaries that redundantly overlap their children. + queue = [] + for c in children: + for gc in c.children: + queue.append(gc) + + while queue and len(reps) < max_nodes: + node = queue.pop(0) + if node.node_id in seen: + continue + + is_leaf = not node.children + has_substance = ( + (node.summary and len(node.summary.strip()) > 20) + or node.table_count > 0 + ) + + if is_leaf and has_substance: + reps.append(node) + seen.add(node.node_id) + elif not is_leaf: + # Expand intermediate nodes without adding them — + # their content is represented by their leaf descendants. + for ch in node.children: + queue.append(ch) + + return reps + async def _synthesize_root_summary(self, children: List[TreeNode]) -> str: - """Synthesize a document-level summary from children's section summaries.""" + """Synthesize a document-level summary from multi-depth section info. + + Gathers representative nodes from multiple tree depths to produce + a summary that reflects actual document content, not just top-level + wrapper headings like "SEC Filing" or "Table of Contents". 
+ """ if not children: return "" from sirchmunk.llm.prompts import COMPILE_SYNTHESIZE_SUMMARY + representatives = self._collect_representative_nodes(children) sections_text = "\n".join( - f"- {c.title}: {c.summary}" for c in children + f"- {n.title}: {n.summary}" for n in representatives ) prompt = COMPILE_SYNTHESIZE_SUMMARY.format(sections=sections_text) resp = await self._llm.achat([{"role": "user", "content": prompt}]) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2c1c9b5..b296b8f 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2166,7 +2166,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", } _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit - _FAST_MAX_EVIDENCE_CHARS = 15_000 + _FAST_MAX_EVIDENCE_CHARS = 20_000 # Plan 5: expanded from 15K to accommodate richer table evidence _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- @@ -2221,6 +2221,20 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Tree navigation retry (Plan 3) --- + _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 + """Evidence below this length triggers a retry with expanded results.""" + _NAV_RETRY_EXPANDED_RESULTS: int = 8 + """Expanded max_results for retry navigation pass.""" + + # --- Table evidence budgets (Plan 5) --- + _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 + """Default max_chars for _format_table_evidence (was 6000).""" + _TABLE_EVIDENCE_PER_RANGE_CHARS: int = 8_000 + """Max chars for per-page-range table supplement in tree nav (was 4000).""" + _TABLE_EVIDENCE_STANDALONE_CHARS: int = 12_000 + """Max chars for standalone table digest fallback when tree nav evidence is thin.""" + # --- Self-correction expanded 
sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 5).""" @@ -4086,6 +4100,19 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _is_evidence_sufficient(evidence: str, min_chars: int = 0) -> bool: + """Check whether collected evidence has enough substance to answer a query. + + Uses a length threshold as a lightweight, domain-agnostic proxy. + Empty or near-empty evidence (e.g., only headers with no data) + fails the check, triggering a retry with expanded parameters. + """ + if not evidence: + return False + stripped = evidence.strip() + return len(stripped) >= min_chars + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4195,7 +4222,7 @@ def _score_table_relevance( @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 6000, + max_chars: int = 10_000, query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. @@ -4440,10 +4467,63 @@ async def _navigate_tree_for_evidence( parts, fname, lf, lf.summary, ) + # ── Plan 3: Retry with expanded results if evidence is insufficient ── + # Triggers on: (a) zero evidence parts, OR (b) evidence too thin. 
+ _current_ev_text = "\n\n".join(parts) + _needs_retry = ( + max_results < self._NAV_RETRY_EXPANDED_RESULTS + and not self._is_evidence_sufficient( + _current_ev_text, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + ) + if _needs_retry: + try: + retry_leaves = await indexer.navigate( + tree, query, + max_results=self._NAV_RETRY_EXPANDED_RESULTS, + ) + if retry_leaves: + r_page, r_char, r_summary = self._classify_leaves(retry_leaves) + for rl in r_summary: + self._append_evidence_part(parts, fname, rl, rl.summary) + + # Page-level extraction for retry (mirrors Phase 2) + if r_page: + r_all_pages: set = set() + for _rl, (rsp, rep) in r_page: + r_all_pages.update(range(rsp, rep + 1)) + try: + r_page_contents = DocumentExtractor.extract_pages( + file_path, sorted(r_all_pages), + ) + r_page_map = {pc.page_number: pc.content for pc in r_page_contents} + for rl, (rsp, rep) in r_page: + r_seg = [r_page_map[p] for p in range(rsp, rep + 1) if r_page_map.get(p, "").strip()] + if r_seg: + self._append_evidence_part(parts, fname, rl, "\n".join(r_seg)) + except Exception: + pass + + # Char-range extraction for retry (mirrors Phase 3) + if r_char: + r_text = self._load_compile_content(self.work_path, file_path) or "" + for rl in r_char: + s, e = rl.char_range + if self._is_valid_char_range(s, e, len(r_text)) and r_text: + seg = r_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, rl, seg) + + leaves = retry_leaves + print(f"SEARCH_WIKI_DEBUG [N3.1] retry_nav: {len(retry_leaves)} leaves", flush=True) + except Exception: + pass + if not parts: return None # Supplement with table evidence if available + _all_tables = None try: from sirchmunk.utils.file_utils import get_fast_hash _file_hash = get_fast_hash(file_path) @@ -4465,7 +4545,8 @@ async def _navigate_tree_for_evidence( ) if leaf_tables: table_text = self._format_table_evidence( - leaf_tables, max_chars=4000, + leaf_tables, + max_chars=self._TABLE_EVIDENCE_PER_RANGE_CHARS, query=query, ) if table_text: @@ -4475,9 
+4556,28 @@ async def _navigate_tree_for_evidence( except Exception: pass - print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if '_all_tables' in dir() and _all_tables else 0}", flush=True) - + # Plan 3: If evidence is still too thin, add full table digest as standalone evidence = "\n\n".join(parts) + if ( + not self._is_evidence_sufficient( + evidence, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + and _all_tables + ): + standalone_table_ev = self._format_table_evidence( + _all_tables, + max_chars=self._TABLE_EVIDENCE_STANDALONE_CHARS, + query=query, + ) + if standalone_table_ev: + parts.append( + f"[{fname} - Standalone Table Evidence]\n{standalone_table_ev}" + ) + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N5.1] standalone_table_fallback: len={len(standalone_table_ev)}", flush=True) + + print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " From fe351a1fa3d422568625ba549e09ce245ab80e7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sat, 9 May 2026 16:19:28 +0800 Subject: [PATCH 46/70] fix tree index --- src/sirchmunk/llm/prompts.py | 18 +++-- src/sirchmunk/search.py | 152 +++++++++++++++++++++++++++++++++-- 2 files changed, 157 insertions(+), 13 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 71d5836..909402d 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -423,6 +423,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. 
**Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. ### Input Data - **User Input**: {user_input} @@ -443,12 +444,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format - -[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] - -[Generate the Markdown Briefing here with detailed analysis and supporting evidence] +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + +[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] + true/false true/false """ @@ -463,6 +464,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. 
**Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. ### Document Context {document_context} @@ -486,12 +488,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format - -[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] - -[Generate the Markdown Briefing here with detailed analysis and supporting evidence] +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + +[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] 
+ true/false true/false """ diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index b296b8f..4b11a09 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -932,14 +932,24 @@ async def _search_by_filename( await self._logger.error(f"Traceback: {traceback.format_exc()}") return [] - @staticmethod - def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: + _SELF_CORRECTION_PATTERN = re.compile( + r'(?:correction|re-?verif|wait,?\s|let me re|actually|self-correction|recalcul)', + re.IGNORECASE, + ) + + @classmethod + def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: """Parse LLM response to extract summary, precise answer, and quality decisions. When a ```` tag is present, its content is prepended to the summary so downstream consumers (evaluation judges, UIs) see the direct answer prominently without needing separate tag awareness. + The method also detects self-correction patterns in the summary text: + when the LLM revised its calculation mid-stream, the last numeric + conclusion is used if PRECISE_ANSWER is absent or matches the + pre-correction value. + Returns: Tuple of (summary_text, should_save_flag, should_answer_flag) """ @@ -2227,6 +2237,17 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_RETRY_EXPANDED_RESULTS: int = 8 """Expanded max_results for retry navigation pass.""" + _CHAR_RANGE_MIN_SPAN: int = 200 + """Minimum char_range span to trust as substantive content. + + Nodes whose char_range covers fewer characters than this threshold + (e.g. a TOC entry that only records the section title) are demoted + to page-level extraction when a valid page_range is available. 
+ """ + + _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 + """Minimum query decomposition components to trigger complementary navigation.""" + # --- Table evidence budgets (Plan 5) --- _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 """Default max_chars for _format_table_evidence (was 6000).""" @@ -4029,8 +4050,8 @@ async def _tree_guided_sample( ) return evidence - @staticmethod - def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: + @classmethod + def _classify_leaves(cls, leaves: list) -> Tuple[List[tuple], List, List]: """Classify leaf nodes by preferred extraction strategy. For non-table leaves, **char_range** (kreuzberg markdown) is preferred @@ -4039,6 +4060,11 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: ``extract_text()``. page_range remains available on each leaf for table-supplement filtering even when the leaf is routed to char_leaves. + Thin char_range nodes (span < ``_CHAR_RANGE_MIN_SPAN``) are demoted + to page-level extraction when a valid page_range exists, as they + typically represent TOC entries whose char offsets only cover the + section title rather than the actual content. 
+ Returns: (page_leaves, char_leaves, summary_leaves) triple: - page_leaves: list of (leaf, page_range) — page-level extraction @@ -4048,6 +4074,7 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: page_leaves: List[tuple] = [] char_leaves: List = [] summary_leaves: List = [] + min_span = cls._CHAR_RANGE_MIN_SPAN for leaf in leaves: # Table nodes: prefer page-level extraction for raw original content @@ -4078,7 +4105,12 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: ) if has_char: - char_leaves.append(leaf) + start, end = leaf.char_range + span = end - start if end > start else 0 + if span < min_span and has_page: + page_leaves.append((leaf, page_range)) + else: + char_leaves.append(leaf) elif has_page: page_leaves.append((leaf, page_range)) elif getattr(leaf, 'summary', None): @@ -4113,6 +4145,64 @@ def _is_evidence_sufficient(evidence: str, min_chars: int = 0) -> bool: stripped = evidence.strip() return len(stripped) >= min_chars + _MULTI_COMPONENT_PATTERNS: Tuple[Tuple[str, ...], ...] = ( + ("balance sheet", "income statement"), + ("balance sheet", "cash flow"), + ("income statement", "cash flow"), + ("accounts payable", "cost of"), + ("accounts payable", "inventory"), + ("current assets", "current liabilities"), + ("revenue", "net income", "earnings"), + ("operating income", "depreciation"), + ) + + @staticmethod + def _decompose_query_components(query: str) -> List[str]: + """Extract distinct data-source components from a multi-part query. + + Scans for known multi-component patterns (e.g. a ratio needing data + from both Balance Sheet and Income Statement) and returns a list of + component phrases that the evidence should cover. 
+ """ + q = query.lower() + components: List[str] = [] + for group in AgenticSearch._MULTI_COMPONENT_PATTERNS: + hits = [phrase for phrase in group if phrase in q] + if len(hits) >= 2: + components.extend(hits) + if not components: + financial_keywords = [ + "balance sheet", "income statement", "cash flow", + "accounts payable", "accounts receivable", "inventory", + "current liabilities", "current assets", "total assets", + "revenue", "cost of", "cogs", "depreciation", "amortization", + "operating income", "net income", "earnings", + ] + for kw in financial_keywords: + if kw in q: + components.append(kw) + seen: set = set() + return [c for c in components if not (c in seen or seen.add(c))] + + @staticmethod + def _check_leaf_coverage( + leaves: list, components: List[str], + ) -> Tuple[List[str], List[str]]: + """Check which query components are covered by the navigated leaves. + + Returns: + (covered, missing) — lists of component phrases. + """ + if not leaves or not components: + return [], list(components) + leaf_text = " ".join( + (getattr(l, 'title', '') or '') + " " + (getattr(l, 'summary', '') or '') + for l in leaves + ).lower() + covered = [c for c in components if c in leaf_text] + missing = [c for c in components if c not in leaf_text] + return covered, missing + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4467,6 +4557,58 @@ async def _navigate_tree_for_evidence( parts, fname, lf, lf.summary, ) + # ── Phase 4: Complementary navigation for multi-component queries ── + # When a query requires data from multiple document sections (e.g. + # Balance Sheet + Income Statement for a ratio), the initial navigate + # may only reach one component. Detect missing components and run a + # focused second navigate pass with a refined query. 
+ _query_components = self._decompose_query_components(query) + if len(_query_components) >= self._NAV_COMPLEMENT_MIN_COMPONENTS: + _covered, _missing = self._check_leaf_coverage(leaves, _query_components) + if _missing: + _complement_query = f"{query} — focus on: {', '.join(_missing)}" + try: + _existing_ids = {id(l) for l in leaves} + comp_leaves = await indexer.navigate( + tree, _complement_query, max_results=max_results, + ) + comp_new = [l for l in (comp_leaves or []) if id(l) not in _existing_ids] + if comp_new: + c_page, c_char, c_summary = self._classify_leaves(comp_new) + for cl in c_summary: + self._append_evidence_part(parts, fname, cl, cl.summary) + if c_page: + c_all_pages: set = set() + for _cl, (csp, cep) in c_page: + c_all_pages.update(range(csp, cep + 1)) + try: + c_contents = DocumentExtractor.extract_pages( + file_path, sorted(c_all_pages), + ) + c_map = {pc.page_number: pc.content for pc in c_contents} + for cl, (csp, cep) in c_page: + c_seg = [c_map[p] for p in range(csp, cep + 1) if c_map.get(p, "").strip()] + if c_seg: + self._append_evidence_part(parts, fname, cl, "\n".join(c_seg)) + except Exception: + pass + if c_char: + c_text = self._load_compile_content(self.work_path, file_path) or "" + for cl in c_char: + s, e = cl.char_range + if self._is_valid_char_range(s, e, len(c_text)) and c_text: + seg = c_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, cl, seg) + leaves = list(leaves) + comp_new + print( + f"SEARCH_WIKI_DEBUG [N3.2] complement_nav: " + f"missing={_missing}, new_leaves={len(comp_new)}", + flush=True, + ) + except Exception: + pass + # ── Plan 3: Retry with expanded results if evidence is insufficient ── # Triggers on: (a) zero evidence parts, OR (b) evidence too thin. 
_current_ev_text = "\n\n".join(parts) From cb1ba9659c7f2f099d762ee9f714e61fe047f494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:03:22 +0800 Subject: [PATCH 47/70] update compiler --- src/sirchmunk/learnings/compiler.py | 31 ++++++++-- src/sirchmunk/search.py | 90 ++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 25868d6..a2a193c 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -67,6 +67,10 @@ # Selective force-OCR: max pages to re-extract with forced OCR per document _FORCE_OCR_MAX_PAGES = 30 +# Incremental manifest flush: persist manifest every N completed files +# to survive interrupted compiles without excessive I/O overhead. +_MANIFEST_FLUSH_INTERVAL = 10 + # Shared numeric-token regex for table detection heuristics. # Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -440,6 +444,7 @@ async def compile( # Phase 2: compile files with bounded concurrency semaphore = asyncio.Semaphore(concurrency) results: List[FileCompileResult] = [] + _files_since_flush = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: async with semaphore: @@ -454,7 +459,6 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: else: if result.tree: report.trees_built += 1 - # Update manifest manifest.files[result.path] = FileManifestEntry( file_hash=get_fast_hash(result.path) or "", compiled_at=datetime.now(timezone.utc).isoformat(), @@ -471,6 +475,17 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: _mentry = manifest.files[result.path] print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) + # Incremental manifest flush to survive interrupted compiles + _files_since_flush += 1 + if _files_since_flush >= 
_MANIFEST_FLUSH_INTERVAL: + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + _files_since_flush = 0 + + # Phase 2 checkpoint: persist manifest before knowledge aggregation + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + # Phase 3: aggregate results into knowledge network await self._log.info("[Compile] Phase 3: Knowledge aggregation") for r in results: @@ -484,15 +499,15 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: await self._log.info("[Compile] Phase 4: Building cross-references") report.cross_refs_built = await self._build_cross_references(results) - # Phase 5: persist manifest + document catalog + # Phase 5: persist final manifest + derived indices + # Catalog and summary index are rebuilt from the manifest, so even + # partial compiles produce usable search-time metadata. manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) self._storage.force_sync() - # Generate document catalog for search-time routing self._build_document_catalog(manifest) - # Phase: Build summary index for embedding+BM25 fallback (optional, non-blocking) await self._build_summary_index(manifest) report.elapsed_seconds = time.monotonic() - t0 @@ -2553,7 +2568,13 @@ def _load_manifest(self) -> CompileManifest: return CompileManifest() def _save_manifest(self, manifest: CompileManifest) -> None: - self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") + """Atomically persist the manifest via write-to-tmp + rename. + + This prevents partial JSON on disk if the process is killed mid-write. 
+ """ + tmp_path = self._manifest_path.with_suffix(".json.tmp") + tmp_path.write_text(manifest.to_json(), encoding="utf-8") + tmp_path.replace(self._manifest_path) # ------------------------------------------------------------------ # # Document catalog for search-time routing # diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 4b11a09..207b837 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2231,6 +2231,14 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Hierarchical file selection for large tree pools --- + _TREE_PREFILTER_THRESHOLD: int = 15 + """Tree pool size above which rule-based pre-filtering is applied.""" + _TREE_PREFILTER_MAX_CANDIDATES: int = 10 + """Maximum candidate trees forwarded to the LLM after pre-filtering.""" + _TREE_PREFILTER_MIN_SCORE: float = 0.5 + """Minimum relevance score for a tree to survive pre-filtering.""" + # --- Tree navigation retry (Plan 3) --- _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 """Evidence below this length triggers a retry with expanded results.""" @@ -5100,32 +5108,82 @@ def _load_cached_trees(self) -> list: except Exception: return [] + @staticmethod + def _prefilter_trees_by_query( + query: str, trees: list, max_candidates: int, min_score: float, + ) -> list: + """Rule-based pre-filter: score trees by query-token overlap with filenames. + + Extracts meaningful tokens from the query (alphanumeric words, 4-digit + years, multi-word entity fragments) and scores each tree's filename by + weighted token overlap. Returns the top-scoring candidates, or the + full list if fewer than *max_candidates* pass the threshold. + + This avoids sending hundreds of root summaries to the LLM. 
+ """ + raw_tokens = re.findall(r"[A-Za-z0-9]+", query.lower()) + tokens = [t for t in raw_tokens if len(t) >= 2 and t not in _STOP_WORDS] + if not tokens: + return trees + + year_tokens = {t for t in tokens if re.fullmatch(r"(?:19|20)\d{2}", t)} + entity_tokens = {t for t in tokens if len(t) >= 3 and t not in year_tokens} + + scored: List[Tuple[float, int]] = [] + for idx, tree in enumerate(trees): + name_lower = Path(tree.file_path).stem.lower() + name_parts = set(re.findall(r"[a-z0-9]+", name_lower)) + + score = 0.0 + for tok in entity_tokens: + if tok in name_lower: + score += 2.0 + elif any(tok[:4] in part for part in name_parts if len(tok) >= 4): + score += 0.5 + for yr in year_tokens: + if yr in name_lower: + score += 3.0 + + scored.append((score, idx)) + + scored.sort(key=lambda x: -x[0]) + + candidates = [trees[idx] for sc, idx in scored if sc >= min_score] + if not candidates: + return [trees[idx] for _, idx in scored[:max_candidates]] + return candidates[:max_candidates] + async def _llm_select_from_trees( self, query: str, trees: list, max_select: int, ) -> List[str]: - """LLM-driven file selection from tree root summaries. - - Presents root summaries to the LLM and returns the selected file - paths. When the number of trees is at most *max_select*, returns - all paths without an LLM call. + """Two-stage LLM-driven file selection from tree root summaries. - Args: - query: User query string. - trees: List of ``DocumentTree`` objects (pre-loaded). - max_select: Maximum number of files to select. + Stage 1 (rule-based): when the pool exceeds ``_TREE_PREFILTER_THRESHOLD``, + narrow candidates by query-token / filename overlap. + Stage 2 (LLM): present root summaries of the narrowed set for precise selection. - Returns: - Selected file paths, or empty list. + When the number of trees is at most *max_select*, returns all paths + without an LLM call. 
""" if not trees: return [] if len(trees) <= max_select: return [t.file_path for t in trees] + pool = trees + if len(pool) > self._TREE_PREFILTER_THRESHOLD: + pool = self._prefilter_trees_by_query( + query, pool, + max_candidates=self._TREE_PREFILTER_MAX_CANDIDATES, + min_score=self._TREE_PREFILTER_MIN_SCORE, + ) + if len(pool) <= max_select: + return [t.file_path for t in pool] + listing = "\n".join( f"[{i}] {Path(t.file_path).name}: " f"{(t.root.summary or '')[:self._CATALOG_SUMMARY_TRUNCATE]}" - for i, t in enumerate(trees) + for i, t in enumerate(pool) ) prompt = ( f'Given the query: "{query}"\n\n' @@ -5143,18 +5201,18 @@ async def _llm_select_from_trees( if m: selected_indices = [ idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) + if isinstance(idx, int) and 0 <= idx < len(pool) ] except (json.JSONDecodeError, TypeError): pass if not selected_indices: - selected_indices = list(range(min(max_select, len(trees)))) + selected_indices = list(range(min(max_select, len(pool)))) return [ - trees[idx].file_path + pool[idx].file_path for idx in selected_indices[:max_select] - if Path(trees[idx].file_path).exists() + if Path(pool[idx].file_path).exists() ] async def _probe_tree_index(self, query: str) -> List[str]: From 93d4a1fc0c3a1b40978317cf727235c29ea96fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:13:11 +0800 Subject: [PATCH 48/70] improve compile efficiency --- src/sirchmunk/learnings/compiler.py | 157 ++++++++++++++++++---------- 1 file changed, 99 insertions(+), 58 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index a2a193c..f21939a 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -71,6 +71,10 @@ # to survive interrupted compiles without excessive I/O overhead. _MANIFEST_FLUSH_INTERVAL = 10 +# Page-level extraction: max pages to load into memory per batch. 
+# Prevents loading all 200-400 pages of a large PDF at once. +_PAGE_SCAN_BATCH_SIZE = 50 + # Shared numeric-token regex for table detection heuristics. # Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -374,6 +378,29 @@ def __init__( self._compile_dir.mkdir(parents=True, exist_ok=True) self._manifest_path = self._compile_dir / "manifest.json" + # ------------------------------------------------------------------ # + # Resource management # + # ------------------------------------------------------------------ # + + @staticmethod + def _configure_thread_limits() -> None: + """Cap PyTorch / OpenMP / MKL thread count to avoid runaway CPU and memory. + + Only sets defaults when the user has not already configured them via + environment variables, so explicit overrides are always respected. + The cap is half the available CPU cores, clamped to [1, 4]. + """ + cpu_count = os.cpu_count() or 4 + cap = str(max(1, min(cpu_count // 2, 4))) + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"): + if var not in os.environ: + os.environ[var] = cap + try: + import torch + torch.set_num_threads(int(cap)) + except ImportError: + pass + # ------------------------------------------------------------------ # # Public API # # ------------------------------------------------------------------ # @@ -398,6 +425,9 @@ async def compile( concurrency: Max parallel file compilations. """ import time + + self._configure_thread_limits() + t0 = time.monotonic() report = CompileReport() @@ -441,9 +471,11 @@ async def compile( f"(concurrency={concurrency})" ) - # Phase 2: compile files with bounded concurrency + # Phase 2 + 3 (fused): compile files, aggregate inline, release heavy objects + # Fusing Phase 3 into the completion loop avoids retaining all + # DocumentTree / EvidenceUnit objects until the end of the pipeline. 
semaphore = asyncio.Semaphore(concurrency) - results: List[FileCompileResult] = [] + _xref_pairs: List[Tuple[str, List[str]]] = [] # lightweight (path, cluster_ids) for Phase 4 _files_since_flush = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: @@ -453,7 +485,6 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: tasks = [_bounded(f) for f in to_compile] for coro in asyncio.as_completed(tasks): result = await coro - results.append(result) if result.error: report.errors.append(f"{result.path}: {result.error}") else: @@ -475,6 +506,16 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: _mentry = manifest.files[result.path] print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) + # Phase 3 inline: aggregate while the result is still alive + if not result.error and result.summary: + created, merged = await self._aggregate_to_knowledge_network(result) + report.clusters_created += created + report.clusters_merged += merged + + # Retain only lightweight cross-ref data, then drop the heavy result + _xref_pairs.append((result.path, list(result.cluster_ids))) + del result + # Incremental manifest flush to survive interrupted compiles _files_since_flush += 1 if _files_since_flush >= _MANIFEST_FLUSH_INTERVAL: @@ -482,22 +523,15 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: self._save_manifest(manifest) _files_since_flush = 0 - # Phase 2 checkpoint: persist manifest before knowledge aggregation + # Phase 2 checkpoint: persist manifest before cross-references manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) - # Phase 3: aggregate results into knowledge network - await self._log.info("[Compile] Phase 3: Knowledge aggregation") - for r in results: - if r.error or not r.summary: - continue - created, merged = await self._aggregate_to_knowledge_network(r) - report.clusters_created 
+= created - report.clusters_merged += merged - - # Phase 4: cross-references + # Phase 4: cross-references (uses only lightweight path+cluster_ids pairs) await self._log.info("[Compile] Phase 4: Building cross-references") - report.cross_refs_built = await self._build_cross_references(results) + report.cross_refs_built = await self._build_cross_references_from_pairs( + _xref_pairs, manifest, + ) # Phase 5: persist final manifest + derived indices # Catalog and summary index are rebuilt from the manifest, so even @@ -1265,25 +1299,23 @@ async def _create_cluster( # Cross-references # # ------------------------------------------------------------------ # - async def _build_cross_references( - self, results: List[FileCompileResult], + async def _build_cross_references_from_pairs( + self, + pairs: List[Tuple[str, List[str]]], + manifest: CompileManifest, ) -> int: """Build co-occurrence edges between clusters that share source files. - Two clusters are co-occurring when the same source file contributed - evidence to both (e.g., different sections compiled into different - clusters). Includes historical data from the manifest. + Accepts lightweight ``(path, cluster_ids)`` pairs instead of full + ``FileCompileResult`` objects to avoid retaining heavy compile results. + Includes historical data from the manifest. 
""" - # Build a complete map: cluster_id -> set of source file paths cluster_to_files: Dict[str, Set[str]] = {} - # From current compile results - for r in results: - for cid in r.cluster_ids: - cluster_to_files.setdefault(cid, set()).add(r.path) + for path, cluster_ids in pairs: + for cid in cluster_ids: + cluster_to_files.setdefault(cid, set()).add(path) - # From manifest (historical data) - manifest = self._load_manifest() for fp, entry in manifest.files.items(): for cid in entry.cluster_ids: cluster_to_files.setdefault(cid, set()).add(fp) @@ -2288,37 +2320,44 @@ async def _pypdf_page_scan( total_pages: int, covered_pages: Set[int], ) -> list[dict]: - """Primary scan: per-page pypdf extraction with density heuristics.""" - all_page_nums = list(range(1, total_pages + 1)) - try: - pages = DocumentExtractor.extract_pages(file_path, all_page_nums) - except Exception as exc: - await self._log.warning( - f"[Compile] Content-based scan: page read failed for " - f"{Path(file_path).name}: {exc}" - ) - return [] + """Primary scan: per-page pypdf extraction with density heuristics. + Pages are loaded in batches of ``_PAGE_SCAN_BATCH_SIZE`` to bound + peak memory when processing large PDFs (200-400+ pages). 
+ """ results: list[dict] = [] poor_line_count = 0 - for pc in pages: - if len(pc.content.split("\n")) <= 3: - poor_line_count += 1 - if pc.page_number in covered_pages: - continue - if not self._page_has_table_density(pc.content): - continue - for region in self._identify_table_regions(pc.content): - results.append({ - "page": pc.page_number, - "content": region[:_TARGETED_TABLE_MAX_CHARS], - "source": "content_scan", - }) + + for batch_start in range(1, total_pages + 1, _PAGE_SCAN_BATCH_SIZE): + batch_end = min(batch_start + _PAGE_SCAN_BATCH_SIZE, total_pages + 1) + batch_pages = list(range(batch_start, batch_end)) + try: + pages = DocumentExtractor.extract_pages(file_path, batch_pages) + except Exception as exc: + await self._log.warning( + f"[Compile] Content-based scan: page read failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + for pc in pages: + if len(pc.content.split("\n")) <= 3: + poor_line_count += 1 + if pc.page_number in covered_pages: + continue + if not self._page_has_table_density(pc.content): + continue + for region in self._identify_table_regions(pc.content): + results.append({ + "page": pc.page_number, + "content": region[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan", + }) + del pages if results: return results - # Signal that pypdf line quality is poor — caller should try fallback if poor_line_count > total_pages * 0.5: return [] @@ -2498,7 +2537,8 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: The index is saved to .cache/compile/summary_index.json and consumed by search.py as a last-resort fallback when rga keyword search fails. - Skips gracefully if dependencies (EmbeddingUtil/TokenizerUtil) are unavailable. + Reuses ``self._embedding`` when available to avoid loading a duplicate + model into memory. Falls back to a fresh instance otherwise. 
""" try: from sirchmunk.utils.tokenizer_util import TokenizerUtil @@ -2518,7 +2558,6 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: if not entries: return - # Tokenize summaries + compute TF (always available) tokenizer = TokenizerUtil() for idx, entry in enumerate(entries): tokens = tokenizer.segment(entry.summary) @@ -2527,12 +2566,14 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: for t in tokens: entry.token_freqs[t] = entry.token_freqs.get(t, 0) + 1 - # Compute embeddings (optional — requires EmbeddingUtil) + # Reuse the compiler's embedding client to avoid duplicate model load try: - from sirchmunk.utils.embedding_util import EmbeddingUtil - embedding_util = EmbeddingUtil() - embedding_util.start_loading() - # Wait up to 60 seconds for model load + embedding_util = self._embedding + if embedding_util is None: + from sirchmunk.utils.embedding_util import EmbeddingUtil + embedding_util = EmbeddingUtil() + embedding_util.start_loading() + await embedding_util._ensure_model_async(timeout=60) if embedding_util.is_ready(): From 929fbc523250c276990d2cc53ab3aa6334836129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:45:36 +0800 Subject: [PATCH 49/70] improve compile mem usage --- src/sirchmunk/cli/cli.py | 20 +++ src/sirchmunk/learnings/compiler.py | 182 +++++++++++++++------------- 2 files changed, 118 insertions(+), 84 deletions(-) diff --git a/src/sirchmunk/cli/cli.py b/src/sirchmunk/cli/cli.py index 99d6843..4aec43f 100644 --- a/src/sirchmunk/cli/cli.py +++ b/src/sirchmunk/cli/cli.py @@ -1242,6 +1242,22 @@ def cmd_mcp_version(args: argparse.Namespace) -> int: # sirchmunk compile # ------------------------------------------------------------------ + +def _configure_compile_threads() -> None: + """Set sensible thread-count defaults for CPU-bound ML workloads. 
+ + Must be called early — before PyTorch, OpenMP, or kreuzberg's Rust + core are imported — so the environment variables are read at library + init time. User-provided overrides are always respected. + """ + cpu_count = os.cpu_count() or 4 + cap = str(max(1, min(cpu_count // 2, 4))) + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", + "RAYON_NUM_THREADS"): + if var not in os.environ: + os.environ[var] = cap + + def cmd_compile(args: argparse.Namespace) -> int: """Compile document collections into structured knowledge indices. @@ -1254,6 +1270,10 @@ def cmd_compile(args: argparse.Namespace) -> int: Returns: Exit code (0 for success, non-zero for failure) """ + # Cap thread counts BEFORE heavy libraries are imported, so OpenMP/MKL + # read the correct values at init time. User-set vars are respected. + _configure_compile_threads() + try: work_path = Path( getattr(args, "work_path", None) or str(_get_default_work_path()) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index f21939a..5bd5bef 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -384,21 +384,19 @@ def __init__( @staticmethod def _configure_thread_limits() -> None: - """Cap PyTorch / OpenMP / MKL thread count to avoid runaway CPU and memory. + """Cap PyTorch thread count to reduce per-thread memory allocation. - Only sets defaults when the user has not already configured them via - environment variables, so explicit overrides are always respected. - The cap is half the available CPU cores, clamped to [1, 4]. + Environment variables (OMP_NUM_THREADS, etc.) are set in the CLI + entry point before libraries are imported. This method handles the + PyTorch-specific runtime API that works retroactively. 
""" cpu_count = os.cpu_count() or 4 - cap = str(max(1, min(cpu_count // 2, 4))) - for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"): - if var not in os.environ: - os.environ[var] = cap + cap = max(1, min(cpu_count // 2, 4)) try: import torch - torch.set_num_threads(int(cap)) - except ImportError: + torch.set_num_threads(cap) + torch.set_num_interop_threads(max(1, cap // 2)) + except (ImportError, RuntimeError): pass # ------------------------------------------------------------------ # @@ -651,6 +649,10 @@ async def _compile_single_file( When *shallow* is True (or file is ineligible for tree indexing), the pipeline skips tree building and summarises via a direct LLM call. + + Large intermediate objects (extraction output, enriched content, + raw tables) are explicitly released after their last use to keep + per-file peak memory bounded. """ result = FileCompileResult(path=entry.path) print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) @@ -666,6 +668,11 @@ async def _compile_single_file( result.error = "Insufficient text content" return result + # Extract scalar metadata from extraction before releasing it + page_count = extraction.page_count + raw_tables = extraction.tables + del extraction + use_tree = ( not shallow and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) @@ -677,7 +684,7 @@ async def _compile_single_file( from sirchmunk.learnings.toc_extractor import TOCExtractor toc_entries = await TOCExtractor.extract( entry.path, content, - total_pages=extraction.page_count, + total_pages=page_count, ) if toc_entries: await self._log.info( @@ -689,47 +696,53 @@ async def _compile_single_file( result.tree = await self._tree_indexer.build_tree( entry.path, content, toc_entries=toc_entries, - total_pages=extraction.page_count, + total_pages=page_count, ) - # Record TOC / tree metrics on the result for manifest persistence - result.has_explicit_toc = toc_entries 
is not None and len(toc_entries) > 0 + result.has_explicit_toc = bool(toc_entries) + del toc_entries result.tree_node_count = self._count_tree_nodes(result.tree) print(f"SEARCH_WIKI_DEBUG [C2] tree_build: success={result.tree is not None}, nodes={result.tree_node_count}, tree.file_path={result.tree.file_path if result.tree else 'N/A'}", flush=True) - # Enrich content with structural metadata for non-text types + # --- Summary + topics + evidence (needs content) --- ext = Path(entry.path).suffix.lower() evidence_digest = "" if ext in (".xlsx", ".xls"): - # Excel: use adaptive sampling for both metadata and evidence metadata_prefix, evidence_digest = self._extract_xlsx_sampling(entry.path) - enriched_content = metadata_prefix + content if metadata_prefix else content else: metadata_prefix = self._extract_structured_metadata(entry.path, content) - enriched_content = metadata_prefix + content if metadata_prefix else content - result.summary = await self._extract_summary( - entry.path, enriched_content, result.tree, - ) + # Build enriched_content only for the summary LLM call, then release + if metadata_prefix: + result.summary = await self._extract_summary( + entry.path, metadata_prefix + content, result.tree, + ) + else: + result.summary = await self._extract_summary( + entry.path, content, result.tree, + ) + del metadata_prefix + result.topics = await self._extract_topics(result.summary) result.evidence = self._build_evidence(entry, content, result) - # Persist Excel evidence digest for search-time consumption + # Persist Excel evidence digest if evidence_digest.strip(): try: digest_dir = self._compile_dir / "xlsx_digests" digest_dir.mkdir(parents=True, exist_ok=True) file_hash = get_fast_hash(entry.path) or "" if file_hash: - digest_path = digest_dir / f"{file_hash}.txt" - digest_path.write_text(evidence_digest, encoding="utf-8") + (digest_dir / f"{file_hash}.txt").write_text( + evidence_digest, encoding="utf-8", + ) result.has_xlsx_digest = True except Exception: 
pass + del evidence_digest - # Cache compile-time ENHANCED content so search can slice - # char_range from the same text the tree was built from. + # Cache ENHANCED content to disk try: file_hash_content = get_fast_hash(entry.path) or "" if file_hash_content and content: @@ -741,83 +754,84 @@ async def _compile_single_file( except Exception: pass - # Persist table digest for documents with extracted tables - if extraction.tables: + # --- Table digest + integration (needs raw_tables, then release) --- + if raw_tables: try: - table_digest = self._build_table_digest(extraction.tables) + table_digest = self._build_table_digest(raw_tables) if table_digest: digest_dir = self._compile_dir / "table_digests" digest_dir.mkdir(parents=True, exist_ok=True) file_hash = get_fast_hash(entry.path) or "" if file_hash: - digest_path = digest_dir / f"{file_hash}.json" - digest_path.write_text( + (digest_dir / f"{file_hash}.json").write_text( json.dumps(table_digest, ensure_ascii=False), encoding="utf-8", ) result.has_table_digest = True - result.table_count = len(extraction.tables) + result.table_count = len(raw_tables) except Exception: pass - print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) - - # Integrate tables into tree: annotate counts + create table child nodes - if result.tree and result.tree.root and extraction.tables: - self._integrate_tables_into_tree( - result.tree.root, extraction.tables, - content=content, total_pages=extraction.page_count, - ) - - # Phase 2.5: Targeted table extraction via tree-node structural signals - if result.tree and result.tree.root and ext == ".pdf": - targeted_tables = await self._targeted_table_extraction( - entry.path, result.tree, - ) - await self._supplement_table_digest( - entry.path, targeted_tables, result, - source_label="Targeted extraction", - ) + if result.tree and result.tree.root: + self._integrate_tables_into_tree( + result.tree.root, raw_tables, + 
content=content, total_pages=page_count, + ) - # Phase 2.6: Content-based full-page table scan (tree-independent) - if ext == ".pdf" and extraction.page_count: - covered_pages = self._get_covered_table_pages(entry.path) - tree_root = ( - result.tree.root - if result.tree and result.tree.root else None - ) - content_tables = await self._content_based_table_scan( - entry.path, - extraction.page_count, - covered_pages, - enhanced_content=content, - tree_root=tree_root, - ) - await self._supplement_table_digest( - entry.path, content_tables, result, - source_label="Content-based scan", - ) + print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) + del raw_tables + + # --- Phases 2.5-2.8: secondary table extraction (PDF only) --- + # These phases re-read from the PDF file; `content` is only + # needed for Phase 2.6 fallback and Phase 2.8 enrichment. + if ext == ".pdf": + if result.tree and result.tree.root: + targeted_tables = await self._targeted_table_extraction( + entry.path, result.tree, + ) + await self._supplement_table_digest( + entry.path, targeted_tables, result, + source_label="Targeted extraction", + ) + del targeted_tables - # Phase 2.7: Selective force-OCR for high-density gap pages - if ext == ".pdf" and extraction.page_count: - covered_after_scan = self._get_covered_table_pages(entry.path) - gap_pages = self._find_force_ocr_candidates( - entry.path, extraction.page_count, covered_after_scan, - ) - if gap_pages: - ocr_tables = await self._selective_force_ocr_tables( - entry.path, gap_pages, + if page_count: + covered_pages = self._get_covered_table_pages(entry.path) + tree_root = ( + result.tree.root + if result.tree and result.tree.root else None + ) + content_tables = await self._content_based_table_scan( + entry.path, page_count, covered_pages, + enhanced_content=content, tree_root=tree_root, ) await self._supplement_table_digest( - entry.path, ocr_tables, result, - 
source_label="Selective force-OCR", + entry.path, content_tables, result, + source_label="Content-based scan", ) + del content_tables - # Phase 2.8: Enrich targeted-extraction tables with ENHANCED content - if ext == ".pdf" and result.has_table_digest: - self._enrich_table_digest_content( - entry.path, content, tree_root=None, - ) + covered_after_scan = self._get_covered_table_pages(entry.path) + gap_pages = self._find_force_ocr_candidates( + entry.path, page_count, covered_after_scan, + ) + if gap_pages: + ocr_tables = await self._selective_force_ocr_tables( + entry.path, gap_pages, + ) + await self._supplement_table_digest( + entry.path, ocr_tables, result, + source_label="Selective force-OCR", + ) + del ocr_tables + + if result.has_table_digest: + self._enrich_table_digest_content( + entry.path, content, tree_root=None, + ) + + # Content is no longer needed — release before returning + del content except Exception as exc: result.error = str(exc) From 207fe59fcac81f3fe004fd15c29027bae9de45a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 20:42:23 +0800 Subject: [PATCH 50/70] improve extractor multi-processing --- src/sirchmunk/learnings/compiler.py | 74 ++++++++++------- src/sirchmunk/utils/document_extractor.py | 96 +++++++++++++++++++++++ 2 files changed, 143 insertions(+), 27 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 5bd5bef..ee44f8e 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -9,9 +9,12 @@ import asyncio import bisect +import ctypes +import gc import json import math import os +import platform import random import re import hashlib @@ -75,6 +78,20 @@ # Prevents loading all 200-400 pages of a large PDF at once. _PAGE_SCAN_BATCH_SIZE = 50 +# How often to run gc.collect() inside the compile loop (every N files). 
+_GC_INTERVAL = 5 + + +def _force_gc() -> None: + """Aggressively reclaim Python-managed memory and nudge the C allocator.""" + gc.collect() + if platform.system() == "Linux": + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except (OSError, AttributeError): + pass + + # Shared numeric-token regex for table detection heuristics. # Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -475,6 +492,7 @@ async def compile( semaphore = asyncio.Semaphore(concurrency) _xref_pairs: List[Tuple[str, List[str]]] = [] # lightweight (path, cluster_ids) for Phase 4 _files_since_flush = 0 + _files_since_gc = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: async with semaphore: @@ -521,6 +539,11 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: self._save_manifest(manifest) _files_since_flush = 0 + _files_since_gc += 1 + if _files_since_gc >= _GC_INTERVAL: + _force_gc() + _files_since_gc = 0 + # Phase 2 checkpoint: persist manifest before cross-references manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) @@ -659,7 +682,7 @@ async def _compile_single_file( try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") - extraction = await DocumentExtractor.extract( + extraction = await DocumentExtractor.extract_isolated( entry.path, DocumentExtractor.ENHANCED, ) content = extraction.content @@ -1218,10 +1241,11 @@ async def _aggregate_to_knowledge_network( def _encode_text(self, text: str) -> Optional[Any]: """Encode text to embedding vector, returns None on failure.""" - if not self._embedding: + if not self._embedding or not self._embedding.is_ready(): return None try: - return self._embedding.encode(text) + vectors = self._embedding._encode_sync([text]) + return vectors[0] if len(vectors) > 0 else None except Exception: return None @@ -2497,44 +2521,40 @@ async def _selective_force_ocr_tables( file_path: str, gap_pages: List[int], ) -> list[dict[str, Any]]: - """Re-extract 
specific pages with forced OCR + layout detection. + """Extract text from gap pages using pypdf (no kreuzberg re-call). - For pages where the native text layer was not recognized as tables - by kreuzberg's RT-DETR model, re-rendering as images may yield - better layout detection results. Uses ``force_ocr_pages`` so only - the targeted pages are OCR'd (no full-document penalty). + Earlier versions spawned a second kreuzberg extraction with + ``force_ocr_pages``, which doubled native memory pressure. + Using pypdf instead avoids Rust/native allocations entirely + while still capturing page text for the table digest. Args: file_path: Path to the PDF. - gap_pages: 0-indexed page numbers to force OCR on. Capped at - :data:`_FORCE_OCR_MAX_PAGES` to bound compile time. + gap_pages: 0-indexed page numbers. Returns: - List of kreuzberg-format table dicts (with ``markdown``, - ``cells``, ``page_number``). + List of table-compatible dicts (``markdown``, ``page_number``). """ - from sirchmunk.utils.document_extractor import ExtractionProfile - if not gap_pages: return [] capped = sorted(gap_pages)[:_FORCE_OCR_MAX_PAGES] - - profile = ExtractionProfile( - output_format="markdown", - extract_tables=True, - force_ocr_pages=tuple(capped), - ) + one_indexed = [p + 1 for p in capped] try: - extraction = await DocumentExtractor.extract(file_path, profile) - except Exception as exc: - await self._log.warning( - f"[Compile] Selective force-OCR failed for " - f"{Path(file_path).name}: {exc}" - ) + pages = DocumentExtractor.extract_pages(file_path, one_indexed) + except Exception: return [] - return extraction.tables + tables: list[dict[str, Any]] = [] + for pc in pages: + text = (pc.content or "").strip() + if text and self._page_has_table_density(text): + tables.append({ + "markdown": text, + "cells": [], + "page_number": pc.page_number, + }) + return tables # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # diff --git 
a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index b2835f5..68670a3 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -11,6 +11,9 @@ from __future__ import annotations import asyncio +import concurrent.futures +import dataclasses +import os from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar, List, Optional, Sequence, Union @@ -18,6 +21,41 @@ from loguru import logger +# --------------------------------------------------------------------------- +# Top-level helper for subprocess-based extraction (must be picklable) +# --------------------------------------------------------------------------- + +def _extract_in_worker( + file_path: str, + profile_dict: dict[str, Any], +) -> dict[str, Any]: + """Run kreuzberg extraction inside a worker process. + + Returns a plain dict so the result crosses the process boundary + without dragging native kreuzberg objects (and their Rust allocations) + back into the parent process. + """ + import asyncio as _aio + + async def _run() -> dict[str, Any]: + from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionProfile, + ) + profile = ExtractionProfile(**profile_dict) + output = await DocumentExtractor.extract(file_path, profile) + return { + "content": output.content, + "mime_type": output.mime_type, + "metadata": output.metadata, + "tables": output.tables, + "detected_languages": output.detected_languages, + "page_count": output.page_count, + } + + return _aio.run(_run()) + + # --------------------------------------------------------------------------- # Configuration profile # --------------------------------------------------------------------------- @@ -231,6 +269,64 @@ async def extract( ) raise + # Shared process pool — lazily created, workers exit after every task + # so the OS reclaims all native memory (Rust arenas, layout-model caches). 
+ _process_pool: ClassVar[Optional[concurrent.futures.ProcessPoolExecutor]] = None + _POOL_WORKERS: ClassVar[int] = max(1, min(os.cpu_count() or 4, 3)) + + @classmethod + def _get_process_pool(cls) -> concurrent.futures.ProcessPoolExecutor: + if cls._process_pool is None: + cls._process_pool = concurrent.futures.ProcessPoolExecutor( + max_workers=cls._POOL_WORKERS, + max_tasks_per_child=1, + ) + return cls._process_pool + + @staticmethod + async def extract_isolated( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content in an isolated subprocess. + + Identical to :meth:`extract` but runs kreuzberg inside a child + process. ``max_tasks_per_child=1`` ensures each worker exits + after one extraction, allowing the OS to reclaim all native + memory (Rust arenas, layout-model buffers, image caches). + + Falls back to in-process extraction on subprocess failure. + """ + profile = profile or DocumentExtractor.BASIC + profile_dict = { + f.name: getattr(profile, f.name) + for f in dataclasses.fields(profile) + } + + loop = asyncio.get_event_loop() + pool = DocumentExtractor._get_process_pool() + try: + raw = await loop.run_in_executor( + pool, + _extract_in_worker, + str(file_path), + profile_dict, + ) + return ExtractionOutput( + content=raw["content"], + mime_type=raw.get("mime_type", ""), + metadata=raw.get("metadata", {}), + tables=raw.get("tables", []), + detected_languages=raw.get("detected_languages", {}), + page_count=raw.get("page_count"), + ) + except Exception as exc: + logger.warning( + "Subprocess extraction failed for {}, falling back to in-process: {}", + file_path, exc, + ) + return await DocumentExtractor.extract(file_path, profile) + @staticmethod async def extract_bytes( data: bytes, From bbc2bbd6d3d2bd955b60d5fe7510c63c2a0daa6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 21:02:57 +0800 Subject: [PATCH 51/70] fix ProcessPoolExecutor --- 
src/sirchmunk/utils/document_extractor.py | 135 ++++++++++++++-------- 1 file changed, 89 insertions(+), 46 deletions(-) diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index 68670a3..f115687 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -11,8 +11,8 @@ from __future__ import annotations import asyncio -import concurrent.futures import dataclasses +import multiprocessing as mp import os from dataclasses import dataclass, field from pathlib import Path @@ -22,38 +22,93 @@ # --------------------------------------------------------------------------- -# Top-level helper for subprocess-based extraction (must be picklable) +# Subprocess extraction helpers (module-level for picklability) # --------------------------------------------------------------------------- -def _extract_in_worker( +_EXTRACT_TIMEOUT_S = 600 + + +def _extraction_worker( file_path: str, profile_dict: dict[str, Any], -) -> dict[str, Any]: - """Run kreuzberg extraction inside a worker process. + pipe_w: mp.connection.Connection, +) -> None: + """Child process entry point: run kreuzberg, send result via pipe, exit. - Returns a plain dict so the result crosses the process boundary - without dragging native kreuzberg objects (and their Rust allocations) - back into the parent process. + Sends a plain dict so no native kreuzberg/Rust objects cross the + process boundary. On failure sends ``{"_error": ""}``. 
""" - import asyncio as _aio + try: + import asyncio as _aio + + async def _run() -> dict[str, Any]: + from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionProfile, + ) + profile = ExtractionProfile(**profile_dict) + output = await DocumentExtractor.extract(file_path, profile) + return { + "content": output.content, + "mime_type": output.mime_type, + "metadata": output.metadata, + "tables": output.tables, + "detected_languages": output.detected_languages, + "page_count": output.page_count, + } + + pipe_w.send(_aio.run(_run())) + except BaseException as exc: + try: + pipe_w.send({"_error": str(exc)}) + except Exception: + pass + finally: + pipe_w.close() + + +def _run_extraction_in_child( + file_path: str, + profile_dict: dict[str, Any], +) -> dict[str, Any]: + """Spawn an isolated child process, wait for its result. - async def _run() -> dict[str, Any]: - from sirchmunk.utils.document_extractor import ( - DocumentExtractor, - ExtractionProfile, + Unlike ``ProcessPoolExecutor``, a crash in one child never + poisons future extractions — each call spawns a fresh process. 
+ """ + pipe_r, pipe_w = mp.Pipe(duplex=False) + proc = mp.Process( + target=_extraction_worker, + args=(file_path, profile_dict, pipe_w), + daemon=True, + ) + proc.start() + pipe_w.close() + + try: + if not pipe_r.poll(timeout=_EXTRACT_TIMEOUT_S): + proc.kill() + proc.join(timeout=10) + raise RuntimeError( + f"Extraction timed out after {_EXTRACT_TIMEOUT_S}s" + ) + result = pipe_r.recv() + except EOFError: + proc.join(timeout=10) + raise RuntimeError( + f"Worker crashed (exit code {proc.exitcode})" ) - profile = ExtractionProfile(**profile_dict) - output = await DocumentExtractor.extract(file_path, profile) - return { - "content": output.content, - "mime_type": output.mime_type, - "metadata": output.metadata, - "tables": output.tables, - "detected_languages": output.detected_languages, - "page_count": output.page_count, - } + finally: + pipe_r.close() + + proc.join(timeout=30) + if proc.is_alive(): + proc.kill() + proc.join() - return _aio.run(_run()) + if isinstance(result, dict) and "_error" in result: + raise RuntimeError(result["_error"]) + return result # --------------------------------------------------------------------------- @@ -269,31 +324,20 @@ async def extract( ) raise - # Shared process pool — lazily created, workers exit after every task - # so the OS reclaims all native memory (Rust arenas, layout-model caches). - _process_pool: ClassVar[Optional[concurrent.futures.ProcessPoolExecutor]] = None - _POOL_WORKERS: ClassVar[int] = max(1, min(os.cpu_count() or 4, 3)) - - @classmethod - def _get_process_pool(cls) -> concurrent.futures.ProcessPoolExecutor: - if cls._process_pool is None: - cls._process_pool = concurrent.futures.ProcessPoolExecutor( - max_workers=cls._POOL_WORKERS, - max_tasks_per_child=1, - ) - return cls._process_pool - @staticmethod async def extract_isolated( file_path: Union[str, Path], profile: Optional[ExtractionProfile] = None, ) -> ExtractionOutput: - """Extract content in an isolated subprocess. 
+ """Extract content in a fully isolated child process. + + Each call spawns a fresh ``multiprocessing.Process``. When the + child exits (normally or via crash), the OS reclaims **all** of + its native memory — Rust arenas, layout-model buffers, image + caches — guaranteeing zero accumulation in the parent. - Identical to :meth:`extract` but runs kreuzberg inside a child - process. ``max_tasks_per_child=1`` ensures each worker exits - after one extraction, allowing the OS to reclaim all native - memory (Rust arenas, layout-model buffers, image caches). + Unlike ``ProcessPoolExecutor``, a crash in one extraction never + poisons future calls. Falls back to in-process extraction on subprocess failure. """ @@ -304,11 +348,10 @@ async def extract_isolated( } loop = asyncio.get_event_loop() - pool = DocumentExtractor._get_process_pool() try: raw = await loop.run_in_executor( - pool, - _extract_in_worker, + None, + _run_extraction_in_child, str(file_path), profile_dict, ) From 5af51df12884aa5798ed23fc618ef0f11e579029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 21:19:39 +0800 Subject: [PATCH 52/70] clean methods for compiler --- src/sirchmunk/learnings/compiler.py | 56 ++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ee44f8e..e812bef 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -458,8 +458,10 @@ async def compile( to_compile = changes.added + changes.modified report.files_skipped = len(changes.unchanged) report.files_deleted = len(changes.deleted) - for deleted_path in changes.deleted: - manifest.files.pop(deleted_path, None) + + stale_paths = changes.deleted + [e.path for e in changes.modified] + if stale_paths: + await self._purge_stale_artifacts(stale_paths, manifest) else: to_compile = discovered report.files_skipped = 0 @@ -519,8 +521,6 @@ async def 
_bounded(entry: FileEntry) -> FileCompileResult: has_table_digest=result.has_table_digest, table_count=result.table_count, ) - _mentry = manifest.files[result.path] - print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) # Phase 3 inline: aggregate while the result is still alive if not result.error and result.summary: @@ -658,6 +658,53 @@ def _detect_changes( return changes + # ------------------------------------------------------------------ # + # Stale artifact cleanup # + # ------------------------------------------------------------------ # + + async def _purge_stale_artifacts( + self, + file_paths: List[str], + manifest: CompileManifest, + ) -> None: + """Remove disk artifacts and DuckDB clusters for deleted/modified files. + + Called before recompilation so that modified files start with a + clean slate and deleted files leave no residue. + """ + artifact_dirs = { + "trees": ".json", + "content": ".txt", + "table_digests": ".json", + "xlsx_digests": ".txt", + } + + for file_path in file_paths: + entry = manifest.files.get(file_path) + if entry is None: + continue + + file_hash = entry.file_hash + + # 1. Remove disk artifacts keyed by file_hash + if file_hash: + for subdir, ext in artifact_dirs.items(): + artifact = self._compile_dir / subdir / f"{file_hash}{ext}" + try: + artifact.unlink(missing_ok=True) + except OSError: + pass + + # 2. Remove associated knowledge clusters from DuckDB + for cluster_id in entry.cluster_ids: + try: + await self._storage.remove(cluster_id) + except Exception: + pass + + # 3. Drop the manifest entry + manifest.files.pop(file_path, None) + # ------------------------------------------------------------------ # # Single-file compilation # # ------------------------------------------------------------------ # @@ -678,7 +725,6 @@ async def _compile_single_file( per-file peak memory bounded. 
""" result = FileCompileResult(path=entry.path) - print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") From af5f7e16fd10cfb226c2e9799a8aa156ad8b5e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 14:21:04 +0800 Subject: [PATCH 53/70] improve all corpus --- src/sirchmunk/search.py | 126 +++++++++++++++++++++++++++++++++++----- 1 file changed, 113 insertions(+), 13 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 207b837..64a702d 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -87,6 +87,54 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 + +class _PathScope: + """Immutable search-path scope for filtering compile artifacts. + + Resolves the provided search paths into absolute file paths and + directory prefixes, then offers ``contains()`` to test whether a + given artifact path falls within this scope. + + When the scope is empty (no paths provided), ``contains()`` always + returns True — i.e. *no filtering* is applied. 
+ """ + + __slots__ = ("_files", "_dirs", "_empty") + + def __init__(self, search_paths: Optional[List[str]] = None) -> None: + files: Set[str] = set() + dirs: List[str] = [] + if search_paths: + for p in search_paths: + resolved = str(Path(p).expanduser().resolve()) + if Path(resolved).is_file(): + files.add(resolved) + elif Path(resolved).is_dir(): + dirs.append( + resolved if resolved.endswith(os.sep) + else resolved + os.sep + ) + else: + files.add(resolved) + self._files = frozenset(files) + self._dirs = tuple(dirs) + self._empty = not files and not dirs + + def contains(self, file_path: str) -> bool: + """Return True when *file_path* falls within the search scope.""" + if self._empty: + return True + if not file_path: + return False + resolved = str(Path(file_path).expanduser().resolve()) + if resolved in self._files: + return True + return any(resolved.startswith(d) for d in self._dirs) + + @property + def is_empty(self) -> bool: + return self._empty + # Pure tree search mode for ablation experiments. # When enabled, search relies solely on tree index navigation, skipping rga keyword search. 
_PURE_TREE_SEARCH: bool = os.getenv("SIRCHMUNK_PURE_TREE_SEARCH", "false").lower() == "true" @@ -1556,7 +1604,8 @@ async def _search_deep( _llm_usage_start = len(self.llm_usages) # --- Adaptive compile artifact detection (shared with FAST) --- - artifacts = self._detect_compile_artifacts() + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) # ============================================================== # Phase 0a: Direct document analysis (intent-gated short-circuit) @@ -1591,8 +1640,8 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), - self._probe_compile_hints([query]), # query-level hints; keyword-level runs post-Phase 1 - self._probe_summary_index(query, artifacts), # GAP 2: zero-LLM BM25 + self._probe_compile_hints([query], scope=_scope), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts, scope=_scope), # GAP 2: zero-LLM BM25 self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap return_exceptions=True, ) @@ -2322,7 +2371,8 @@ async def _search_fast( self._tree_nav_cache = _TreeNavCache() # --- Adaptive compile artifact detection (one-shot, zero LLM) --- - artifacts = self._detect_compile_artifacts() + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) if artifacts.catalog or artifacts.tree_available_paths: await self._logger.info( f"[FAST:Artifacts] catalog={'yes' if artifacts.catalog else 'no'} " @@ -2375,7 +2425,7 @@ async def _search_fast( messages=[{"role": "user", "content": prompt}], stream=False, ) - _compile_hints_task = self._probe_compile_hints([query]) + _compile_hints_task = self._probe_compile_hints([query], scope=_scope) _tree_probe_task = self._probe_tree_for_fast(query, artifacts) _parallel_results = await asyncio.gather( @@ -2494,7 +2544,7 @@ async def _search_fast( keyword_idfs.setdefault(p, 0.6) # P4: 
compile hints — pre-fetched (query-level) + keyword-level supplement - _kw_compile_hints = await self._probe_compile_hints(primary + fallback) + _kw_compile_hints = await self._probe_compile_hints(primary + fallback, scope=_scope) compile_hints = self._merge_compile_hints(_early_compile_hints, _kw_compile_hints) for kw in compile_hints.extra_keywords: if kw not in all_kw_set: @@ -2515,7 +2565,7 @@ async def _search_fast( seen_hint_paths.add(fp) compile_hint_files.append(fp) # Summary index BM25 files: proactive zero-LLM discovery (GAP 2) - _summary_hint_files = await self._probe_summary_index(query, artifacts) + _summary_hint_files = await self._probe_summary_index(query, artifacts, scope=_scope) for fp in _summary_hint_files: if fp not in seen_hint_paths: seen_hint_paths.add(fp) @@ -3551,7 +3601,10 @@ def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: pass return None - def _detect_compile_artifacts(self) -> CompileArtifacts: + def _detect_compile_artifacts( + self, + search_paths: Optional[List[str]] = None, + ) -> CompileArtifacts: """One-shot probe of all compile artifacts for adaptive FAST activation. Reads the document catalog and scans the tree cache directory to @@ -3559,12 +3612,19 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: start of ``_search_fast()``; the result is passed to downstream helpers so they can enable enhanced logic only when artifacts exist. + When *search_paths* is provided, all returned artifacts are filtered + to only include entries whose file paths fall within the search scope. + This ensures downstream consumers (catalog routing, tree probing, + summary index) never see documents outside the requested scope. + Cost: one JSON read (catalog) + one directory listing (tree cache). Tree path results are cached in ``_tree_paths_cache`` so subsequent calls within the same instance avoid re-parsing every JSON file. Returns a ``CompileArtifacts`` with ``None``/empty fields when compile has not been run. 
""" + scope = _PathScope(search_paths) + catalog = self._load_document_catalog() catalog_map: Dict[str, Dict[str, str]] = {} if catalog: @@ -3623,6 +3683,14 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: except Exception: pass + # --- Apply search-path scope filtering --- + if not scope.is_empty: + if catalog: + catalog = [e for e in catalog if scope.contains(e.get("path", ""))] + catalog_map = {p: e for p, e in catalog_map.items() if scope.contains(p)} + tree_paths = {p for p in tree_paths if scope.contains(p)} + manifest_map = {p: e for p, e in manifest_map.items() if scope.contains(p)} + print(f"SEARCH_WIKI_DEBUG [D1] manifest_map: {len(manifest_map)} entries, keys={list(manifest_map.keys())[:3]}", flush=True) print(f"SEARCH_WIKI_DEBUG [D2] tree_available_paths: {tree_paths}", flush=True) print(f"SEARCH_WIKI_DEBUG [D3] manifest_fallback_executed: {manifest_map and not tree_paths}", flush=True) @@ -5126,8 +5194,16 @@ def _prefilter_trees_by_query( if not tokens: return trees - year_tokens = {t for t in tokens if re.fullmatch(r"(?:19|20)\d{2}", t)} - entity_tokens = {t for t in tokens if len(t) >= 3 and t not in year_tokens} + # Extract years: bare "2018" and compound prefixed forms "fy2018", "cy2023" + year_tokens: Set[str] = set() + for t in tokens: + if re.fullmatch(r"(?:19|20)\d{2}", t): + year_tokens.add(t) + else: + m = re.search(r"((?:19|20)\d{2})", t) + if m: + year_tokens.add(m.group(1)) + entity_tokens = {t for t in tokens if len(t) >= 2 and t not in year_tokens} scored: List[Tuple[float, int]] = [] for idx, tree in enumerate(trees): @@ -5238,12 +5314,20 @@ async def _probe_tree_index(self, query: str) -> List[str]: except Exception: return [] - async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: + async def _probe_compile_hints( + self, + keywords: List[str], + *, + scope: Optional["_PathScope"] = None, + ) -> CompileHints: """Zero-LLM enrichment from compile manifest and tree cache. 
Scans the compile manifest for clusters whose patterns overlap with the query keywords, and scans cached tree root summaries for keyword matches. No LLM calls — only local JSON reads and in-memory DB lookups. + + When *scope* is provided, only file paths falling within the scope + are included in the returned hints. """ empty = CompileHints([], []) if not keywords: @@ -5255,6 +5339,11 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: seen_paths: set = set() seen_kw: set = set(kw_lower) + def _accept(fp: str) -> bool: + return bool(fp) and fp not in seen_paths and Path(fp).exists() and ( + scope is None or scope.contains(fp) + ) + # --- Cluster pattern matching via manifest --- manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" if manifest_path.exists(): @@ -5280,7 +5369,7 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: if kw_lower & set(cluster_patterns): for ev in getattr(c, "evidences", []): fp = str(getattr(ev, "file_or_url", "")) - if fp and fp not in seen_paths and Path(fp).exists(): + if _accept(fp): seen_paths.add(fp) file_paths.append(fp) for p in cluster_patterns: @@ -5307,7 +5396,7 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: summary_lower = (tree.root.summary or "").lower() if any(kw in summary_lower for kw in kw_lower): fp = tree.file_path - if fp not in seen_paths and Path(fp).exists(): + if _accept(fp): seen_paths.add(fp) file_paths.append(fp) except Exception: @@ -5339,6 +5428,8 @@ async def _probe_summary_index( self, query: str, artifacts: Optional["CompileArtifacts"] = None, + *, + scope: Optional["_PathScope"] = None, ) -> List[str]: """Zero-LLM file discovery via compile-time summary index (BM25 only). @@ -5346,9 +5437,13 @@ async def _probe_summary_index( summaries are lexically similar to the query. No LLM or embedding calls — pure local computation. 
+ When *scope* is provided, results are post-filtered to only include + file paths within the search scope. + Args: query: User query string. artifacts: Compile artifacts (uses summary_index field). + scope: Optional path scope for filtering results. Returns: File paths of top-k matching documents, or empty list. @@ -5374,6 +5469,7 @@ async def _probe_summary_index( file_paths = [ fp for fp, score in results if score > 0.0 and Path(fp).exists() + and (scope is None or scope.contains(fp)) ] if file_paths: @@ -5459,6 +5555,10 @@ async def _probe_tree_for_fast( try: trees = self._load_cached_trees() + # Scope-filter: only keep trees whose files are in artifacts + if artifacts and artifacts.tree_available_paths: + scoped = artifacts.tree_available_paths + trees = [t for t in trees if t.file_path in scoped] print(f"SEARCH_WIKI_DEBUG [D5] loaded_trees: {len(trees)} trees, paths={[t.file_path for t in trees][:3]}", flush=True) if not trees: return [] From 7439521a810db8fe897ccc8cf520eabc0700a83b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 16:04:10 +0800 Subject: [PATCH 54/70] tree index and rga fusion --- src/sirchmunk/search.py | 154 +++++++++++++++++++++++++++++++++++----- 1 file changed, 138 insertions(+), 16 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 64a702d..3507577 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2225,7 +2225,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", } _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit - _FAST_MAX_EVIDENCE_CHARS = 20_000 # Plan 5: expanded from 15K to accommodate richer table evidence + _FAST_MAX_EVIDENCE_CHARS = 40_000 _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- @@ -2261,7 +2261,7 @@ async def _llm_chunked_summarize(self, combined: str, 
query: str) -> str: """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- - _TREE_SAMPLE_MAX_SECTIONS = 5 + _TREE_SAMPLE_MAX_SECTIONS = 8 """Max tree sections to include per file in tree-guided sampling.""" _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 """Max chars per tree section.""" @@ -2280,6 +2280,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Tree probe / RGA fusion --- + _TREE_PROBE_RANKING_BOOST: float = 3.0 + """Score boost (0-10 scale) for files selected by LLM tree probing.""" + # --- Hierarchical file selection for large tree pools --- _TREE_PREFILTER_THRESHOLD: int = 15 """Tree pool size above which rule-based pre-filtering is applied.""" @@ -2288,10 +2292,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _TREE_PREFILTER_MIN_SCORE: float = 0.5 """Minimum relevance score for a tree to survive pre-filtering.""" - # --- Tree navigation retry (Plan 3) --- + # --- Tree navigation --- + _TREE_NAV_MAX_RESULTS: int = 8 + """Primary max_results for LLM-driven tree navigation.""" _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 """Evidence below this length triggers a retry with expanded results.""" - _NAV_RETRY_EXPANDED_RESULTS: int = 8 + _NAV_RETRY_EXPANDED_RESULTS: int = 12 """Expanded max_results for retry navigation pass.""" _CHAR_RANGE_MIN_SPAN: int = 200 @@ -2305,12 +2311,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 """Minimum query decomposition components to trigger complementary navigation.""" - # --- Table evidence budgets (Plan 5) --- - _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 - """Default max_chars for _format_table_evidence (was 6000).""" + # --- Table evidence budgets --- + _TABLE_EVIDENCE_DEFAULT_CHARS: int = 20_000 + """Default 
max_chars for _format_table_evidence.""" _TABLE_EVIDENCE_PER_RANGE_CHARS: int = 8_000 - """Max chars for per-page-range table supplement in tree nav (was 4000).""" - _TABLE_EVIDENCE_STANDALONE_CHARS: int = 12_000 + """Max chars for per-page-range table supplement in tree nav.""" + _TABLE_EVIDENCE_STANDALONE_CHARS: int = 20_000 """Max chars for standalone table digest fallback when tree nav evidence is thin.""" # --- Self-correction expanded sampling --- @@ -2445,6 +2451,7 @@ async def _search_fast( if isinstance(_tree_probed_files, Exception): await self._logger.warning(f"[FAST:Step1] Tree probe failed: {_tree_probed_files}") _tree_probed_files = [] + _tree_probed_set: frozenset[str] = frozenset(_tree_probed_files) self.llm_usages.append(resp.usage) if resp.usage and isinstance(resp.usage, dict): @@ -2671,10 +2678,26 @@ async def _search_fast( for p in catalog_routed_files[:top_k_files] ] + # Narrow-scope RGA: search within tree-probed files first + if not best_files and _tree_probed_set and primary: + best_files = await self._fast_find_best_file( + primary, paths=list(_tree_probed_set), + top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + ) + if best_files: + used_level = "tree_rga" + await self._logger.info( + f"[FAST:Step2] Narrow-scope tree+rga hit → " + f"{[Path(f['path']).name for f in best_files]}" + ) + + # Full-scope RGA with tree probe boost if not best_files and primary: best_files = await self._fast_find_best_file( primary, top_k=top_k_files, keyword_idfs=keyword_idfs, query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, **rga_kwargs, ) @@ -2686,6 +2709,7 @@ async def _search_fast( best_files = await self._fast_find_best_file( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, **rga_kwargs, ) @@ -2756,7 +2780,11 @@ async def _search_fast( print(f"SEARCH_WIKI_DEBUG [D11] MISMATCH! 
tree_available_paths={artifacts.tree_available_paths}", flush=True) if artifacts and tree_nav_target in artifacts.tree_available_paths: - tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + tree_task = self._navigate_tree_for_evidence( + tree_nav_target, query, + max_results=self._TREE_NAV_MAX_RESULTS, + match_objects=best_files[0].get("matches"), + ) tree_nav_done.add(tree_nav_target) else: tree_task = self._async_noop(None) @@ -2818,12 +2846,15 @@ async def _rga_evidence() -> str: if _all_tables: _table_ev = self._format_table_evidence( - _all_tables, query=query, + _all_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, ) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" - # 1. Tree-guided sampling FIRST for tree-indexed files + # 1. Tree-guided sampling for tree-indexed files + # (skipped when a parallel tree_task already covers this file) _tree_cond = artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done print(f"SEARCH_WIKI_DEBUG [D14] tree_sample: cond={_tree_cond}, in_tree_paths={fp in (artifacts.tree_available_paths if artifacts else set())}, in_nav_done={fp in tree_nav_done}", flush=True) if ( @@ -2839,7 +2870,10 @@ async def _rga_evidence() -> str: artifacts=artifacts, ) if tree_ev_inner: - ev = tree_ev_inner + if ev: + ev = ev + "\n\n" + tree_ev_inner + else: + ev = tree_ev_inner await self._logger.info( f"[FAST:Step3] Tree-guided sample for {fn} " f"({len(tree_ev_inner)} chars)" @@ -2883,6 +2917,8 @@ async def _rga_evidence() -> str: rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) # Merge: tree evidence first (highest quality), then rga + if tree_ev and rga_ev: + rga_ev = self._deduplicate_table_sections(tree_ev, rga_ev) evidence_parts_final: List[str] = [] if tree_ev: evidence_parts_final.append(tree_ev) @@ -3257,11 +3293,16 @@ async def _fast_find_best_file( keyword_idfs: Optional[Dict[str, float]] = None, query: str = "", artifacts: Optional["CompileArtifacts"] = 
None, + tree_probed_paths: Optional[Set[str]] = None, ) -> Optional[List[Dict[str, Any]]]: """Search per keyword via rga and return the top-k best-matching files ranked by IDF-weighted log-TF scoring, optionally enhanced with wiki-derived relevance from compile artifacts. + When *tree_probed_paths* is provided, files that were selected by + LLM-driven tree probing receive a ranking boost, ensuring the tree + probe's high-quality signal influences the final file ordering. + Args: keywords: Search keywords from FAST Step 1. paths: Search paths. @@ -3272,6 +3313,7 @@ async def _fast_find_best_file( keyword_idfs: Pre-computed IDF values for keywords. query: Original user query (used for wiki relevance scoring). artifacts: Compile artifacts for adaptive wiki-enhanced ranking. + tree_probed_paths: File paths selected by tree probing (receive boost). Returns: List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. @@ -3427,6 +3469,11 @@ async def _fast_find_best_file( + (1 - self._WIKI_BLEND_ALPHA) * wiki_score ) + if tree_probed_paths: + for f in merged: + if f["path"] in tree_probed_paths: + f["weighted_score"] += self._TREE_PROBE_RANKING_BOOST + merged.sort(key=lambda f: f["weighted_score"], reverse=True) pruned = self._prune_by_score(merged, top_k=top_k) @@ -4385,10 +4432,44 @@ def _score_table_relevance( return hits / len(query_tokens) + @staticmethod + def _deduplicate_table_sections( + primary_ev: str, secondary_ev: str, + ) -> str: + """Remove table sections from *secondary_ev* whose pages already + appear in *primary_ev*. + + Matching is based on ``[Table from page N]`` and ``[Tables pp.X-Y]`` + headers. Non-table content in *secondary_ev* is preserved intact. 
+ """ + if not primary_ev or not secondary_ev: + return secondary_ev + + covered: Set[int] = { + int(m.group(1)) + for m in re.finditer(r"\[Table from page (\d+)\]", primary_ev) + } + for m in re.finditer(r"\[Tables pp\.(\d+)-(\d+)\]", primary_ev): + covered.update(range(int(m.group(1)), int(m.group(2)) + 1)) + + if not covered: + return secondary_ev + + blocks = secondary_ev.split("\n\n") + kept: List[str] = [] + for block in blocks: + page_m = re.search(r"\[Table from page (\d+)\]", block) + if page_m and int(page_m.group(1)) in covered: + continue + kept.append(block) + + result = "\n\n".join(kept) + return result if result.strip() else "" + @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 10_000, + max_chars: int = 20_000, query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. @@ -4410,7 +4491,7 @@ def _format_table_evidence( ordered = tables if query: query_tokens = frozenset( - tok for tok in query.lower().split() if len(tok) > 2 + tok for tok in query.lower().split() if len(tok) >= 2 ) if query_tokens: scored = [ @@ -4465,7 +4546,12 @@ def _append_evidence_part( parts.append(f"{header}\n{text}") async def _navigate_tree_for_evidence( - self, file_path: str, query: str, *, max_results: int = 5, + self, + file_path: str, + query: str, + *, + max_results: int = 8, + match_objects: Optional[List[Dict[str, Any]]] = None, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections and read leaf content. @@ -4473,6 +4559,10 @@ async def _navigate_tree_for_evidence( *file_path*, returning concatenated leaf content as evidence. Returns None when no tree cache is available. + When *match_objects* (RGA hit dicts) are provided, keyword-level + context windows are appended as supplementary evidence after tree + navigation, fusing structural and keyword signals. + Extraction priority (highest first): 1. char_range – compile-time ENHANCED content slice (preserves tables) 2. 
page_range – page-level extraction via DocumentExtractor (fallback) @@ -4796,6 +4886,38 @@ async def _navigate_tree_for_evidence( print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + # --- RGA keyword supplement: fuse keyword hits into tree evidence --- + if match_objects: + _ev_len = sum(len(p) for p in parts) + _rga_budget = max(0, self._FAST_MAX_EVIDENCE_CHARS - _ev_len) + if _rga_budget > 200: + hit_lines: List[int] = [ + m.get("data", {}).get("line_number") + for m in match_objects + if isinstance(m.get("data", {}).get("line_number"), int) + ] + ext = Path(file_path).suffix.lower() + rga_ctx: Optional[str] = None + if ext in self._FAST_TEXT_EXTENSIONS and hit_lines: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=_rga_budget, + ) + else: + snippet_parts: List[str] = [] + snippet_total = 0 + for m in match_objects: + text = m.get("data", {}).get("lines", {}).get("text", "").rstrip() + if text and snippet_total + len(text) < _rga_budget: + snippet_parts.append(text) + snippet_total += len(text) + if snippet_parts: + rga_ctx = "\n".join(snippet_parts) + if rga_ctx: + parts.append(f"[{fname} \u2192 keyword hits]\n{rga_ctx}") + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " From cec209d9f5c57ad6b0de933793786f10d0c1ac14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 18:12:22 +0800 Subject: [PATCH 55/70] fallback hybrid tree indexing --- src/sirchmunk/llm/prompts.py | 4 ++ src/sirchmunk/search.py | 82 +++++++++++++++++++++++++++++++++--- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 909402d..89ea9e8 100644 --- a/src/sirchmunk/llm/prompts.py 
+++ b/src/sirchmunk/llm/prompts.py @@ -424,6 +424,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data - **User Input**: {user_input} @@ -465,6 +467,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. 
**Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context {document_context} diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 3507577..55d25c7 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2311,6 +2311,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 """Minimum query decomposition components to trigger complementary navigation.""" + _NAV_PAGE_MARGIN: int = 1 + """Extra pages to extract on each side of a leaf's page_range.""" + + _NAV_REF_PAGE_MAX: int = 5 + """Maximum referenced-but-uncovered pages to extract as gap-fill.""" + # --- Table evidence budgets --- _TABLE_EVIDENCE_DEFAULT_CHARS: int = 20_000 """Default max_chars for _format_table_evidence.""" @@ -2911,10 +2917,8 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D15] ev_source={_ev_source}, ev_len={len(ev) if ev else 0}", flush=True) return "\n\n---\n\n".join(parts) - # Launch tree navigation for the primary file alongside rga - rga_task = _rga_evidence() - - rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) + # Launch tree navigation alongside rga evidence collection. 
+ rga_ev, tree_ev = await asyncio.gather(_rga_evidence(), tree_task) # Merge: tree evidence first (highest quality), then rga if tree_ev and rga_ev: @@ -4326,6 +4330,29 @@ def _check_leaf_coverage( missing = [c for c in components if c not in leaf_text] return covered, missing + @staticmethod + def _extract_referenced_pages(text: str) -> Set[int]: + """Extract page numbers referenced in evidence text. + + Detects cross-references like 'page 60', 'pages 45-47', 'pp. 12-15' + that hint at data-bearing pages not yet included in evidence. + """ + pages: Set[int] = set() + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\s*[-\u2013]\s*(\d+)", + text, re.IGNORECASE, + ): + start, end = int(m.group(1)), int(m.group(2)) + if 0 < start <= end and end - start <= 10: + pages.update(range(start, end + 1)) + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\b", text, re.IGNORECASE, + ): + p = int(m.group(1)) + if 0 < p <= 500: + pages.add(p) + return pages + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4602,7 +4629,10 @@ async def _navigate_tree_for_evidence( if page_leaves: all_pages: set = set() for _leaf, (sp, ep) in page_leaves: - all_pages.update(range(sp, ep + 1)) + all_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) try: page_contents = DocumentExtractor.extract_pages( file_path, sorted(all_pages), @@ -4697,7 +4727,10 @@ async def _navigate_tree_for_evidence( if page_fallback_leaves: all_fb_pages: set = set() for _lf, (sp, ep) in page_fallback_leaves: - all_fb_pages.update(range(sp, ep + 1)) + all_fb_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) try: fb_contents = DocumentExtractor.extract_pages( file_path, sorted(all_fb_pages), @@ -4886,6 +4919,43 @@ async def _navigate_tree_for_evidence( print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + # ── Phase 6: 
Referenced-page gap-fill ── + # Scan evidence for page cross-references (e.g. TOC entries + # pointing to financial statements) and extract any that were + # not covered by the navigated leaves. + if parts: + _covered_pages: Set[int] = set() + for leaf in leaves: + pr = getattr(leaf, "page_range", None) + if pr and len(pr) == 2 and pr[0] is not None: + _covered_pages.update(range( + max(1, pr[0] - self._NAV_PAGE_MARGIN), + pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _referenced = self._extract_referenced_pages("\n\n".join(parts)) + _gap_pages = sorted(_referenced - _covered_pages)[ + : self._NAV_REF_PAGE_MAX + ] + if _gap_pages: + try: + _gap_contents = DocumentExtractor.extract_pages( + file_path, _gap_pages, + ) + for pc in _gap_contents: + if pc.content and pc.content.strip(): + parts.append( + f"[{fname} \u2192 referenced p.{pc.page_number}]" + f"\n{pc.content}" + ) + evidence = "\n\n".join(parts) + print( + f"SEARCH_WIKI_DEBUG [N5.2] ref_page_gap_fill: " + f"pages={_gap_pages}", + flush=True, + ) + except Exception: + pass + # --- RGA keyword supplement: fuse keyword hits into tree evidence --- if match_objects: _ev_len = sum(len(p) for p in parts) From 59beaeac766a99da0aa2753b1c8878e05f195255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 13 May 2026 01:50:10 +0800 Subject: [PATCH 56/70] improve search pipeline for hybrid --- src/sirchmunk/llm/prompts.py | 10 ++-- src/sirchmunk/search.py | 111 +++++++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 89ea9e8..8c8c049 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,8 +424,9 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. 
**Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. -7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. 
Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Input Data - **User Input**: {user_input} @@ -467,8 +468,9 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. -7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. 
Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Document Context {document_context} diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 55d25c7..58aa0a7 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2324,6 +2324,16 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max chars for per-page-range table supplement in tree nav.""" _TABLE_EVIDENCE_STANDALONE_CHARS: int = 20_000 """Max chars for standalone table digest fallback when tree nav evidence is thin.""" + _TABLE_CROSS_SECTION_CHARS: int = 6_000 + """Max chars for cross-section table supplement drawn from pages outside + the navigated leaf ranges. Ensures data-dense tables in distant + document sections (e.g. financial statements when leaves are in + management discussion) are included.""" + _TABLE_EVIDENCE_NAV_OVERLAP_CHARS: int = 8_000 + """Reduced table evidence budget for files that are already receiving + parallel tree navigation. 
Since tree_ev will provide targeted evidence, + the RGA path uses a smaller budget to supply incremental tables, + leaving room for more diverse evidence.""" # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 @@ -2851,9 +2861,14 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) if _all_tables: + _td_budget = ( + self._TABLE_EVIDENCE_NAV_OVERLAP_CHARS + if fp in tree_nav_done + else self._TABLE_EVIDENCE_DEFAULT_CHARS + ) _table_ev = self._format_table_evidence( _all_tables, - max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + max_chars=_td_budget, query=query, ) if _table_ev: @@ -4412,6 +4427,16 @@ def _filter_tables_by_page_range( ] _TABLE_RELEVANCE_MIN_PREFIX = 5 + _TABLE_STRUCTURE_BONUS: float = 0.25 + """Bonus score for tables exhibiting structured data characteristics + (high row count, numeric density). Applied additively to the keyword + relevance score so that data-rich tables are preferred when keyword + scores tie.""" + _TABLE_STRUCTURE_MIN_ROWS: int = 5 + """Minimum ``|``-delimited rows for a table to qualify for the + structure bonus.""" + _TABLE_STRUCTURE_MIN_NUMERIC_RATIO: float = 0.15 + """Minimum ratio of numeric tokens to total tokens for the bonus.""" @staticmethod def _score_table_relevance( @@ -4459,6 +4484,39 @@ def _score_table_relevance( return hits / len(query_tokens) + @staticmethod + def _score_table_structure(markdown: str) -> float: + """Score a table's structural richness (row count + numeric density). 
+ + Data-dense tables (financial statements, balance sheets) score + higher than narrative paragraphs that happen to contain a small + embedded table. The score is in [0, 1] and is added as a bonus + to the keyword relevance score during table ranking. + """ + if not markdown: + return 0.0 + + rows = markdown.count("\n") + if rows < AgenticSearch._TABLE_STRUCTURE_MIN_ROWS: + return 0.0 + + tokens = markdown.split() + if not tokens: + return 0.0 + + numeric_count = sum( + 1 for t in tokens + if any(c.isdigit() for c in t) + ) + numeric_ratio = numeric_count / len(tokens) + + if numeric_ratio < AgenticSearch._TABLE_STRUCTURE_MIN_NUMERIC_RATIO: + return 0.0 + + row_score = min(rows / 30.0, 1.0) + num_score = min(numeric_ratio / 0.4, 1.0) + return (row_score * 0.5 + num_score * 0.5) + @staticmethod def _deduplicate_table_sections( primary_ev: str, secondary_ev: str, @@ -4521,10 +4579,18 @@ def _format_table_evidence( tok for tok in query.lower().split() if len(tok) >= 2 ) if query_tokens: + struct_bonus = AgenticSearch._TABLE_STRUCTURE_BONUS scored = [ - (AgenticSearch._score_table_relevance( - t.get("markdown", ""), query_tokens, - ), idx, t) + ( + AgenticSearch._score_table_relevance( + t.get("markdown", ""), query_tokens, + ) + + struct_bonus * AgenticSearch._score_table_structure( + t.get("markdown", ""), + ), + idx, + t, + ) for idx, t in enumerate(tables) ] scored.sort(key=lambda x: (-x[0], x[1])) @@ -4897,6 +4963,43 @@ async def _navigate_tree_for_evidence( except Exception: pass + # ── Phase 5.5: Cross-section table supplement ── + # The leaf-scoped supplement (above) only includes tables from + # pages matching selected leaves. When leaves cluster in one + # region (e.g. management discussion), data-dense tables from + # other sections (e.g. financial statements) are missed. + # Fix: include top-ranked tables from UNCOVERED pages. 
+ if _all_tables and leaves: + _leaf_page_set: Set[int] = set() + for _lf in leaves: + _pr = getattr(_lf, "page_range", None) + if _pr and len(_pr) == 2 and _pr[0] is not None: + _leaf_page_set.update(range( + max(1, _pr[0] - self._NAV_PAGE_MARGIN), + _pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _cross_tables = [ + t for t in _all_tables + if t.get("page_number") is not None + and t["page_number"] not in _leaf_page_set + ] + if _cross_tables: + _cross_ev = self._format_table_evidence( + _cross_tables, + max_chars=self._TABLE_CROSS_SECTION_CHARS, + query=query, + ) + if _cross_ev: + parts.append( + f"[{fname} - Cross-section Tables]\n{_cross_ev}" + ) + print( + f"SEARCH_WIKI_DEBUG [N5.3] cross_section_tables: " + f"uncovered_tables={len(_cross_tables)}, " + f"ev_len={len(_cross_ev)}", + flush=True, + ) + # Plan 3: If evidence is still too thin, add full table digest as standalone evidence = "\n\n".join(parts) if ( From ec6f6b1c1310a92b4e86cea8aebf820b2c51e2cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 13 May 2026 17:37:45 +0800 Subject: [PATCH 57/70] Add compile tree index for DEEP mode --- src/sirchmunk/llm/prompts.py | 83 +++- src/sirchmunk/search.py | 756 ++++++++++++++++++++++++++++++----- 2 files changed, 745 insertions(+), 94 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 8c8c049..3bd545d 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -426,7 +426,6 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. 6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. 
"in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. -8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Input Data - **User Input**: {user_input} @@ -470,7 +469,6 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. 6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. 
Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. -8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Document Context {document_context} @@ -505,6 +503,87 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +# --------------------------------------------------------------------------- +# Deep Structured Reasoning prompts +# --------------------------------------------------------------------------- + +DEEP_SECTION_SELECT = """Given the user query and a document section map, select the sections most likely to contain the answer. + +### User Query +{query} + +### Document Section Map +{section_map} + +### Instructions +1. Identify which sections contain data needed to answer the query. +2. For questions requiring computation (ratios, growth rates, comparisons), select ALL sections containing the required input data. +3. Prefer sections containing structured data (tables, financial statements) over narrative sections. +4. Select 1-5 sections. Fewer is better if you are confident. + +### Output +Return ONLY a JSON array of section indices (0-based) from the map above: +[0, 3, 5] +""" + + +DEEP_STRUCTURED_EXTRACT = """Extract all data relevant to the query from the provided document content. 
+ +### User Query +{query} + +### Document Content +{evidence} + +### Instructions +1. Extract every data point, number, or fact that could help answer the query. +2. Preserve exact values, units, and context (e.g. fiscal year, line item name). +3. For tables, extract the specific rows and columns relevant to the query. +4. Note the source location (section title or page) for each extracted item. +5. If the query requires a calculation, identify and extract ALL input values needed. + +### Output Format + +- [source]: [data point name] = [value] [unit] +- [source]: [data point name] = [value] [unit] +... + +complete|partial|insufficient + +[List any data items needed to answer the query that were NOT found. Empty if complete.] + +""" + + +DEEP_COT_REASONING = """Answer the query using ONLY the extracted data below. Show complete reasoning. + +### User Query +{query} + +### Extracted Data +{structured_data} + +### Instructions +1. State which data points you will use and why. +2. Show ALL calculation steps explicitly (one operation per line). +3. After computing the result, verify it: + a. Check units and order of magnitude. + b. Cross-check with any alternative data if available. + c. Ensure the result directly answers what was asked. +4. Match the precision and units implied by the query. 
+ +### Output Format + +[Step-by-step reasoning with explicit calculations] + + +[Sanity checks and cross-validation] + +[Your final answer here] +high|medium|low +""" + + # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 58aa0a7..6d9a17e 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -27,6 +27,9 @@ DOC_SUMMARY, DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, + DEEP_SECTION_SELECT, + DEEP_STRUCTURED_EXTRACT, + DEEP_COT_REASONING, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -1833,111 +1836,234 @@ async def _search_deep( cluster.content = f"{cluster.content}\n\n{graph_ctx}" # ============================================================== - # Phase 4: Generate answer — cluster summary or ReAct refinement + # Phase 3.6: Adaptive depth — fast triage for simple queries + # When pre-nav evidence is sufficient, attempt FAST-style synthesis. + # The LLM decides via SHOULD_ANSWER + PRECISE_ANSWER whether the + # evidence is adequate — no hardcoded heuristic. If the LLM + # returns a precise answer with acceptance, we skip the heavier + # structured reasoning pipeline. 
# ============================================================== - context.increment_loop() - answer: str = "" - should_save: bool = True - - # Inject catalog context for wiki-enhanced answer (GAP 4) - if artifacts and artifacts.catalog_map and cluster and cluster.content: - _catalog_ctx_parts = [] - for fp in (cluster.search_results or merged_files)[:3]: - ctx = self._build_answer_context(fp, artifacts) - if ctx: - _catalog_ctx_parts.append(ctx) - if _catalog_ctx_parts: - _catalog_context = "\n".join(_catalog_ctx_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n[Document Context]\n{_catalog_context}" - await self._logger.info( - f"[Phase 4] Injected catalog context for {len(_catalog_ctx_parts)} documents" + _fast_triage_answer: Optional[str] = None + _fast_triage_accepted = False + + if _pre_nav_evidence: + _pre_nav_total = sum(len(v) for v in _pre_nav_evidence.values()) + await self._logger.info( + f"[Phase 3.6] Fast triage: {_pre_nav_total} chars of pre-nav evidence" + ) + _triage_evidence = "\n\n---\n\n".join( + f"[{Path(fp).name}]\n{ev}" + for fp, ev in _pre_nav_evidence.items() + ) + + doc_context = None + if artifacts and artifacts.catalog_map: + _ctx_parts = [ + self._build_answer_context(fp, artifacts) + for fp in list(_pre_nav_evidence)[:2] + ] + _ctx_parts = [c for c in _ctx_parts if c] + if _ctx_parts: + doc_context = "\n".join(_ctx_parts) + + if doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + _triage_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=_triage_evidence, + document_context=doc_context, + ) + else: + _triage_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=_triage_evidence, ) - if cluster and cluster.content: - await self._logger.info("[Phase 4] Evidence sufficient, generating summary") - answer, should_save, should_answer = await self._summarise_cluster(query, 
cluster) + _triage_resp = await self.llm.achat( + messages=[{"role": "user", "content": _triage_prompt}], + stream=True, + ) + self.llm_usages.append(_triage_resp.usage) + context.increment_loop() - # --- Multi-factor evidence acceptance --- - cluster_evidence = str(cluster.content) if cluster and cluster.content else "" - accepted, accept_reason = self._evaluate_evidence_acceptance( - query, cluster_evidence, should_answer, + _triage_raw = _triage_resp.content or "" + _fast_triage_answer, _ft_save, _ft_should = ( + self._parse_summary_response(_triage_raw) ) - await self._logger.info( - f"[Phase 4] Evidence acceptance: {accepted} ({accept_reason})" + _fast_triage_accepted, _ft_reason = ( + self._evaluate_evidence_acceptance( + query, _triage_evidence, _ft_should, + ) ) - if not accepted: - if llm_fallback: - await self._logger.info( - "[Phase 4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" - ) - answer, should_save = await self._summarise_cluster_fallback(query) - else: - await self._logger.warning( - "[Phase 4] Summary gate rejected evidence and llm_fallback=False " - "→ returning no results" - ) - return _NO_RESULTS_MESSAGE, None, context - if not cluster.search_results: - cluster.search_results = list(merged_files) - elif llm_fallback: - await self._logger.info( - "[Phase 4] Evidence insufficient, llm_fallback=True \u2192 LLM summary" + # Require a PRECISE_ANSWER for true triage acceptance — if the + # LLM could not produce one, the query likely needs deeper + # reasoning even if evidence was nominally accepted. 
+ _has_precise = bool( + re.search( + r"(.+?)", + _triage_raw, re.DOTALL, + ) ) - answer, should_save = await self._summarise_cluster_fallback(query) - else: - await self._logger.info("[Phase 4] Evidence insufficient, launching ReAct refinement") - # P5: enrich ReAct context with graph knowledge - react_spec = f"{spec_context}\n\n{graph_ctx}" if graph_ctx else spec_context - react_answer, context = await self._react_refinement( - query=query, paths=paths, - initial_keywords=initial_keywords, spec_context=react_spec, - enable_dir_scan=enable_dir_scan, - max_loops=max_loops, max_token_budget=max_token_budget, - max_depth=max_depth, include=include, exclude=exclude, + if _fast_triage_accepted and not _has_precise: + _fast_triage_accepted = False + _ft_reason = "no_precise_answer" + + await self._logger.info( + f"[Phase 3.6] Fast triage: accepted={_fast_triage_accepted} " + f"({_ft_reason})" ) + if _fast_triage_accepted and _fast_triage_answer: + answer = _fast_triage_answer + should_save = True if not cluster: - cluster = await self._build_cluster_from_context( - query=query, answer=react_answer, context=context, - query_keywords=query_keywords, top_k_files=top_k_files, + cluster = self._make_answer_cluster( + query, answer, "DT", + file_paths=list(_pre_nav_evidence), ) - elif react_answer and not cluster.content: - cluster.content = react_answer + elif not cluster.content: + cluster.content = answer + await self._logger.info( + "[Phase 3.6] Fast triage accepted → skipping structured reasoning" + ) + else: + # ============================================================== + # Phase 4: Deep Structured Reasoning or cluster-based fallback + # ============================================================== + context.increment_loop() + answer = "" + should_save = True + + # Determine tree-indexed files for structured reasoning. + # tree_hits come from Phase 1 probe and are always tree-indexed; + # also check artifacts when available for broader coverage. 
+ _sr_files: List[str] = [] + if tree_hits: + _sr_files = list(tree_hits[: self._DEEP_STRUCTURED_MAX_FILES]) + elif artifacts and artifacts.tree_available_paths: + _sr_files = list(artifacts.tree_available_paths)[ + : self._DEEP_STRUCTURED_MAX_FILES + ] - if not cluster: - await self._logger.warning( - "[Phase 4] ReAct found no buildable evidence and llm_fallback=False " - "→ returning no results" + if _sr_files: + await self._logger.info( + f"[Phase 4] Launching structured reasoning for " + f"{len(_sr_files)} tree-indexed files" + ) + sr_answer, sr_cluster = await self._deep_structured_reasoning( + query, _sr_files, artifacts, context, ) - return _NO_RESULTS_MESSAGE, None, context - - # Final DEEP decision is always made in the summary call. - answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - # --- Multi-factor evidence acceptance --- - final_cluster_evidence = str(cluster.content) if cluster and cluster.content else "" - final_accepted, final_reason = self._evaluate_evidence_acceptance( - query, final_cluster_evidence, should_answer, - ) - await self._logger.info( - f"[Phase 4] Final evidence acceptance: {final_accepted} ({final_reason})" - ) + if sr_answer: + answer, should_save, should_answer = self._parse_summary_response( + sr_answer + ) + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, + sr_answer, + should_answer, + ) + await self._logger.info( + f"[Phase 4] Structured reasoning acceptance: " + f"{accepted} ({accept_reason})" + ) + if accepted: + cluster = sr_cluster or cluster + else: + answer = "" # Fall through to cluster/ReAct path + + # Fallback: original cluster summary or ReAct + if not answer: + # Inject catalog context for wiki-enhanced answer + if artifacts and artifacts.catalog_map and cluster and cluster.content: + _catalog_ctx_parts = [] + for fp in (cluster.search_results or merged_files)[:3]: + ctx = self._build_answer_context(fp, artifacts) + if ctx: + _catalog_ctx_parts.append(ctx) + 
if _catalog_ctx_parts: + _catalog_context = "\n".join(_catalog_ctx_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = ( + f"{cluster.content}\n\n" + f"[Document Context]\n{_catalog_context}" + ) - if not final_accepted: - if llm_fallback: + if cluster and cluster.content: await self._logger.info( - "[Phase 4] Final summary gate rejected evidence, llm_fallback=True → LLM fallback" + "[Phase 4:Fallback] Generating summary from cluster" + ) + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + cluster_evidence = ( + str(cluster.content) if cluster.content else "" + ) + accepted, accept_reason = ( + self._evaluate_evidence_acceptance( + query, cluster_evidence, should_answer, + ) + ) + if not accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) + ) + else: + return _NO_RESULTS_MESSAGE, None, context + if not cluster.search_results: + cluster.search_results = list(merged_files) + elif llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) ) - answer, should_save = await self._summarise_cluster_fallback(query) else: - await self._logger.warning( - "[Phase 4] Final summary gate rejected evidence and llm_fallback=False " - "→ returning no results" + await self._logger.info( + "[Phase 4:Fallback] Launching ReAct refinement" ) - return _NO_RESULTS_MESSAGE, None, context + react_spec = ( + f"{spec_context}\n\n{graph_ctx}" + if graph_ctx else spec_context + ) + react_answer, context = await self._react_refinement( + query=query, paths=paths, + initial_keywords=initial_keywords, + spec_context=react_spec, + enable_dir_scan=enable_dir_scan, + max_loops=max_loops, + max_token_budget=max_token_budget, + max_depth=max_depth, + include=include, exclude=exclude, + ) + if not cluster: + cluster = await self._build_cluster_from_context( + query=query, answer=react_answer, + context=context, + 
query_keywords=query_keywords, + top_k_files=top_k_files, + ) + elif react_answer and not cluster.content: + cluster.content = react_answer + if not cluster: + return _NO_RESULTS_MESSAGE, None, context + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + final_evidence = ( + str(cluster.content) if cluster.content else "" + ) + final_accepted, _ = self._evaluate_evidence_acceptance( + query, final_evidence, should_answer, + ) + if not final_accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) + ) + else: + return _NO_RESULTS_MESSAGE, None, context # Sync LLM token accounting into context new_usages = self.llm_usages[_llm_usage_start:] @@ -2334,6 +2460,9 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: parallel tree navigation. Since tree_ev will provide targeted evidence, the RGA path uses a smaller budget to supply incremental tables, leaving room for more diverse evidence.""" + _DEEP_CROSS_SECTION_MIN_EVIDENCE: int = 8_000 + """Cross-section table supplement is skipped when existing tree-nav + evidence already exceeds this threshold (chars), preventing overload.""" # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 @@ -2341,6 +2470,18 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _SELF_CORRECT_EXPANDED_SECTIONS: int = 8 """Expanded tree sample sections for same-file re-sampling (default uses 5).""" + # --- Deep Structured Reasoning --- + _DEEP_SECTION_MAP_MAX_DEPTH: int = 2 + """Maximum tree depth for section map construction (top-N layers).""" + _DEEP_MAX_EXTRACT_PAGES: int = 12 + """Maximum pages to extract per file in targeted page extraction.""" + _DEEP_STRUCTURED_MAX_CHARS: int = 30_000 + """Maximum character budget for structured evidence per file.""" + _DEEP_MAX_RECOVERY_ROUNDS: int = 2 + """Maximum rounds of missing-data recovery before final answer.""" + 
_DEEP_STRUCTURED_MAX_FILES: int = 3 + """Maximum files to process through structured reasoning pipeline.""" + # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" @@ -4963,13 +5104,11 @@ async def _navigate_tree_for_evidence( except Exception: pass - # ── Phase 5.5: Cross-section table supplement ── - # The leaf-scoped supplement (above) only includes tables from - # pages matching selected leaves. When leaves cluster in one - # region (e.g. management discussion), data-dense tables from - # other sections (e.g. financial statements) are missed. - # Fix: include top-ranked tables from UNCOVERED pages. - if _all_tables and leaves: + # ── Phase 5.5: Cross-section table supplement (conditional) ── + # Only supplements when existing evidence is below threshold + # to prevent evidence overload for queries already well-served. + _current_ev_len = sum(len(p) for p in parts) + if _all_tables and leaves and _current_ev_len < self._DEEP_CROSS_SECTION_MIN_EVIDENCE: _leaf_page_set: Set[int] = set() for _lf in leaves: _pr = getattr(_lf, "page_range", None) @@ -6180,6 +6319,439 @@ async def _summarise_fast_fallback( answer, _, _ = self._parse_summary_response(answer_resp.content or "") return answer, False # Never save fallback answers + # ------------------------------------------------------------------ + # Deep Structured Reasoning pipeline + # ------------------------------------------------------------------ + + @staticmethod + def _build_section_map( + root: Any, + max_depth: int = 2, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build a lightweight section map from the top layers of a tree index. + + Args: + root: A ``TreeNode`` root from a ``DocumentTree``. + + Returns a human-readable map string (with numbered indices so the LLM + can reference specific sections) and a parallel list of section + metadata dicts for programmatic use. 
+ """ + sections: List[Dict[str, Any]] = [] + + def _walk(node: Any, depth: int) -> None: + if depth > max_depth: + return + pr = node.page_range + idx = len(sections) + sections.append({ + "idx": idx, + "title": node.title, + "page_range": list(pr) if pr else None, + "char_range": list(node.char_range) if getattr(node, "char_range", None) else None, + "depth": depth, + "node_id": node.node_id, + "summary": (node.summary or "")[:120], + }) + for child in node.children: + _walk(child, depth + 1) + + children = root.children if root.children else [root] + while len(children) == 1 and children[0].children and not children[0].leaf: + children = children[0].children + + for child in children: + _walk(child, 0) + + map_lines: List[str] = [] + for sec in sections: + indent = " " * sec["depth"] + pr = sec.get("page_range") + page_str = f"(p{pr[0]}-{pr[1]})" if pr and pr[0] else "" + map_lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") + + return "\n".join(map_lines), sections + + async def _select_evidence_sections( + self, + query: str, + section_map: str, + sections_meta: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """LLM-driven selection of relevant sections from a section map. + + Returns the metadata dicts for the selected sections. 
+ """ + prompt = DEEP_SECTION_SELECT.format( + query=query, + section_map=section_map, + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + self.llm_usages.append(resp.usage) + + raw = (resp.content or "").strip() + # Parse JSON array of indices + try: + match = re.search(r"\[[\s\d,]*\]", raw) + if match: + indices = json.loads(match.group(0)) + return [ + sections_meta[i] + for i in indices + if isinstance(i, int) and 0 <= i < len(sections_meta) + ] + except (json.JSONDecodeError, IndexError): + pass + + # Fallback: return sections that have page_range data + return [s for s in sections_meta if s.get("page_range")][:3] + + async def _extract_targeted_pages( + self, + file_path: str, + selected_sections: List[Dict[str, Any]], + query: str, + ) -> str: + """Extract content for LLM-selected sections. + + Two extraction strategies (tried in order): + 1. **Page-based** — ``DocumentExtractor.extract_pages`` for PDFs. + 2. **Char-range** — direct text slice from compile cache or + fast_extract for any file type. + + Table digests are appended when available. Caps output at + ``_DEEP_STRUCTURED_MAX_CHARS``. 
+ """ + parts: List[str] = [] + + # Strategy 1: page-based extraction (PDF) + pages_needed: Set[int] = set() + for sec in selected_sections: + pr = sec.get("page_range") + if pr and len(pr) == 2 and pr[0]: + pages_needed.update(range( + max(1, pr[0] - self._NAV_PAGE_MARGIN), + pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + + if pages_needed: + sorted_pages = sorted(pages_needed)[: self._DEEP_MAX_EXTRACT_PAGES] + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted_pages, + ) + for pc in page_contents: + if pc.content and pc.content.strip(): + parts.append(f"[Page {pc.page_number}]\n{pc.content}") + except Exception as e: + await self._logger.warning( + f"[DeepStructured] Page extraction failed for " + f"{Path(file_path).name}: {e}" + ) + + # Strategy 2: char_range fallback (non-PDF or when pages failed) + if not parts: + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + if full_text: + for sec in selected_sections: + cr = sec.get("char_range") + if cr and len(cr) == 2 and cr[0] is not None: + start, end = cr + if 0 <= start < end <= len(full_text): + segment = full_text[start:end] + if segment.strip(): + parts.append( + f"[{sec.get('title', 'Section')}]\n{segment}" + ) + + # Append relevant table digests when available + if pages_needed: + try: + from sirchmunk.utils.file_utils import get_fast_hash + fhash = get_fast_hash(file_path) + if fhash: + tables = self._load_table_digest(self.work_path, fhash) + if tables: + page_tables = [ + t for t in tables + if t.get("page_number") in pages_needed + ] + if page_tables: + table_ev = self._format_table_evidence( + page_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, + ) + if table_ev: + parts.append(f"[Table Evidence]\n{table_ev}") + except Exception: + 
pass + + evidence = "\n\n".join(parts) + return evidence[: self._DEEP_STRUCTURED_MAX_CHARS] + + async def _extract_structured_data( + self, + query: str, + raw_evidence: str, + ) -> Tuple[str, str, List[str]]: + """LLM extraction of structured data from raw page evidence. + + Returns (structured_data, completeness, missing_items). + """ + prompt = DEEP_STRUCTURED_EXTRACT.format( + query=query, + evidence=raw_evidence, + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + content = resp.content or "" + + # Parse structured output + data_match = re.search( + r"(.*?)", content, re.DOTALL, + ) + structured = data_match.group(1).strip() if data_match else content + + comp_match = re.search( + r"\s*(complete|partial|insufficient)\s*", + content, re.IGNORECASE, + ) + completeness = comp_match.group(1).lower() if comp_match else "partial" + + missing_match = re.search( + r"(.*?)", content, re.DOTALL, + ) + missing_items: List[str] = [] + if missing_match: + raw_missing = missing_match.group(1).strip() + if raw_missing and raw_missing.lower() not in ("", "none", "n/a", "empty"): + missing_items = [ + line.strip().lstrip("- ") + for line in raw_missing.split("\n") + if line.strip() and line.strip() != "-" + ] + + return structured, completeness, missing_items + + async def _reason_with_verification( + self, + query: str, + structured_data: str, + ) -> Tuple[str, str, str]: + """CoT reasoning with self-verification. + + Returns (answer, confidence, full_reasoning). 
+ """ + prompt = DEEP_COT_REASONING.format( + query=query, + structured_data=structured_data, + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + content = resp.content or "" + + answer_match = re.search( + r"(.*?)", content, re.DOTALL, + ) + answer = answer_match.group(1).strip() if answer_match else "" + + conf_match = re.search( + r"\s*(high|medium|low)\s*", + content, re.IGNORECASE, + ) + confidence = conf_match.group(1).lower() if conf_match else "low" + + return answer, confidence, content + + async def _deep_structured_reasoning( + self, + query: str, + tree_files: List[str], + artifacts: Any, + context: "SearchContext", + ) -> Tuple[str, Optional["KnowledgeCluster"]]: + """Orchestrate the full Deep Structured Reasoning pipeline. + + Phases: + 1. Section map — build from tree index top layers + 2. Section select — LLM picks relevant sections + 3. Targeted extraction — pull pages + tables for selected sections + 4. Structured data — LLM extracts key-value data + 5. CoT reasoning + verification + 6. Missing data recovery (conditional, up to N rounds) + + Returns ``(answer, cluster)`` suitable for the DEEP search return path. 
+ """ + indexer = self._get_tree_indexer() + if indexer is None: + return "", None + + all_structured: List[str] = [] + + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES]: + fname = Path(fp).name + tree = indexer.load_tree(fp) + if tree is None or tree.root is None: + continue + + # Phase 1: Section map + section_map, sections_meta = self._build_section_map( + tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH, + ) + if not sections_meta: + continue + + await self._logger.info( + f"[DeepStructured] Section map for {fname}: " + f"{len(sections_meta)} sections" + ) + + # Phase 2: LLM selects relevant sections + selected = await self._select_evidence_sections( + query, section_map, sections_meta, + ) + context.increment_loop() + if not selected: + continue + + await self._logger.info( + f"[DeepStructured] Selected {len(selected)} sections: " + f"{[s['title'][:30] for s in selected]}" + ) + + # Phase 3: Targeted page extraction + raw_evidence = await self._extract_targeted_pages( + fp, selected, query, + ) + if not raw_evidence: + continue + + # Per-file evidence accumulator for recovery rounds + file_raw_parts: List[str] = [raw_evidence] + + await self._logger.info( + f"[DeepStructured] Extracted {len(raw_evidence)} chars " + f"from {fname}" + ) + + # Phase 4: Structured data extraction + structured, completeness, missing_items = ( + await self._extract_structured_data(query, raw_evidence) + ) + context.increment_loop() + + await self._logger.info( + f"[DeepStructured] Data extraction: " + f"completeness={completeness}, " + f"missing={len(missing_items)}" + ) + + # Phase 4.5: Missing data recovery (per-file loop) + recovery_round = 0 + while ( + missing_items + and completeness != "complete" + and recovery_round < self._DEEP_MAX_RECOVERY_ROUNDS + ): + recovery_round += 1 + await self._logger.info( + f"[DeepStructured] Recovery round {recovery_round}: " + f"seeking {missing_items}" + ) + + recovery_query = ( + f"{query} — specifically find: " + f"{', 
'.join(missing_items[:5])}" + ) + + recovery_selected = await self._select_evidence_sections( + recovery_query, section_map, sections_meta, + ) + context.increment_loop() + + existing_ids = {s["node_id"] for s in selected} + new_sections = [ + s for s in recovery_selected + if s["node_id"] not in existing_ids + ] + if not new_sections: + break + + recovery_evidence = await self._extract_targeted_pages( + fp, new_sections, recovery_query, + ) + if not recovery_evidence: + break + file_raw_parts.append(recovery_evidence) + + combined = "\n\n".join(file_raw_parts) + structured, completeness, missing_items = ( + await self._extract_structured_data( + query, + combined[: self._DEEP_STRUCTURED_MAX_CHARS], + ) + ) + context.increment_loop() + selected.extend(new_sections) + + if structured: + all_structured.append(f"[Source: {fname}]\n{structured}") + + if not all_structured: + return "", None + + # Phase 5: CoT reasoning with verification + combined_data = "\n\n".join(all_structured) + answer, confidence, full_reasoning = await self._reason_with_verification( + query, combined_data, + ) + context.increment_loop() + + await self._logger.info( + f"[DeepStructured] Reasoning complete: " + f"confidence={confidence}, answer_len={len(answer)}" + ) + + # Build a cluster for persistence (strip XML tags for clean content) + _clean_reasoning = re.sub( + r"", + "", full_reasoning, + ).strip() + cluster = self._make_answer_cluster( + query, _clean_reasoning, "DSR", + file_paths=tree_files[: self._DEEP_STRUCTURED_MAX_FILES], + ) + + if not answer: + return "", cluster + + formatted = ( + f"\n{_clean_reasoning}\n\n" + f"{answer}\n" + f"true\n" + f"{'true' if confidence != 'low' else 'false'}" + f"" + ) + + return formatted, cluster + async def _react_refinement( self, query: str, From ec08cd8d14c2e6ab564ef45ce2d0ef08aceb20ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 13 May 2026 21:25:03 +0800 Subject: [PATCH 58/70] improve DEEP --- 
src/sirchmunk/llm/prompts.py | 57 --- src/sirchmunk/search.py | 710 +++++++++++++++-------------------- 2 files changed, 309 insertions(+), 458 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 3bd545d..17c7e9f 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -527,63 +527,6 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ -DEEP_STRUCTURED_EXTRACT = """Extract all data relevant to the query from the provided document content. - -### User Query -{query} - -### Document Content -{evidence} - -### Instructions -1. Extract every data point, number, or fact that could help answer the query. -2. Preserve exact values, units, and context (e.g. fiscal year, line item name). -3. For tables, extract the specific rows and columns relevant to the query. -4. Note the source location (section title or page) for each extracted item. -5. If the query requires a calculation, identify and extract ALL input values needed. - -### Output Format - -- [source]: [data point name] = [value] [unit] -- [source]: [data point name] = [value] [unit] -... - -complete|partial|insufficient - -[List any data items needed to answer the query that were NOT found. Empty if complete.] - -""" - - -DEEP_COT_REASONING = """Answer the query using ONLY the extracted data below. Show complete reasoning. - -### User Query -{query} - -### Extracted Data -{structured_data} - -### Instructions -1. State which data points you will use and why. -2. Show ALL calculation steps explicitly (one operation per line). -3. After computing the result, verify it: - a. Check units and order of magnitude. - b. Cross-check with any alternative data if available. - c. Ensure the result directly answers what was asked. -4. Match the precision and units implied by the query. 
- -### Output Format - -[Step-by-step reasoning with explicit calculations] - - -[Sanity checks and cross-validation] - -[Your final answer here] -high|medium|low -""" - - # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 6d9a17e..e5492da 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -23,13 +23,10 @@ FAST_QUERY_ANALYSIS, FAST_QUERY_ANALYSIS_WITH_CATALOG, ROI_RESULT_SUMMARY, - SEARCH_RESULT_SUMMARY, DOC_SUMMARY, DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, DEEP_SECTION_SELECT, - DEEP_STRUCTURED_EXTRACT, - DEEP_COT_REASONING, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -572,6 +569,14 @@ async def _try_reuse_cluster(self, query: str, paths: Optional[List[str]] = None ) return None + # P3: skip clusters whose cached answer is a refusal + if self._is_refusal_answer(content): + await self._logger.info( + f"Cluster {existing_cluster.id} contains a refusal answer, " + "falling back to full search" + ) + return None + # Mutate only after validation passes self._add_query_to_cluster(existing_cluster, query) existing_cluster.hotness = min(1.0, (existing_cluster.hotness or 0.5) + 0.1) @@ -988,6 +993,25 @@ async def _search_by_filename( re.IGNORECASE, ) + _REFUSAL_PATTERN = re.compile( + r'cannot\s+(?:be\s+)?determin' + r'|data\s+(?:not\s+available|insufficient)' + r'|not\s+(?:possible|available)\s+to\s+(?:determin|calculat|answer)' + r'|information\s+(?:is\s+)?not\s+(?:available|provided|found)' + r'|no\s+(?:relevant|sufficient)\s+(?:data|information|evidence)', + re.IGNORECASE, + ) + + @classmethod + def _is_refusal_answer(cls, text: str) -> bool: + """Detect whether *text* is a refusal / no-data answer.""" + if not text or len(text.strip()) < 20: + return True + head = text[:500] + if 
re.search(r'\bN/?A\b', head): + return True + return bool(cls._REFUSAL_PATTERN.search(head)) + @classmethod def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: """Parse LLM response to extract summary, precise answer, and quality decisions. @@ -1027,6 +1051,10 @@ def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: should_answer = False should_save = False + # P3: Never persist refusal/no-data answers to cluster cache + if should_save and cls._is_refusal_answer(precise or summary): + should_save = False + return summary, should_save, should_answer # ------------------------------------------------------------------ @@ -1078,6 +1106,43 @@ def _detect_numeric_evidence(query: str, evidence: str) -> bool: ) return has_numeric + _COMPLEX_QUERY_PATTERNS = [ + re.compile(p, re.IGNORECASE) for p in [ + r'\d+[- ]year average', + r'year[- ]over[- ]year', + r'compare.*between|between.*and.*fy', + r'trend|trajectory', + r'fy\d{4}.*(?:to|and|vs).*fy\d{4}', + r'(?:3|5|10)[- ]year', + r'average.*(?:margin|ratio|growth)', + r'change.*from.*to', + ] + ] + _MODERATE_QUERY_PATTERNS = [ + re.compile(p, re.IGNORECASE) for p in [ + r'ratio|margin|percentage', + r'calculate|compute', + r'turnover|conversion|coverage', + r'capex|ebitda|eps|roe|roa|dpo', + r'what is (?:the )?fy\d{4}', + r'how (?:much|many)', + ] + ] + + @classmethod + def _classify_query_complexity(cls, query: str) -> str: + """Classify *query* as ``simple``, ``moderate``, or ``complex``. + + Used by DEEP mode to decide whether to invoke the heavier + section-map structured reasoning pipeline or go straight to + cluster-level synthesis. 
+ """ + if any(p.search(query) for p in cls._COMPLEX_QUERY_PATTERNS): + return "complex" + if any(p.search(query) for p in cls._MODERATE_QUERY_PATTERNS): + return "moderate" + return "simple" + @staticmethod def _evaluate_evidence_acceptance( query: str, @@ -1836,109 +1901,24 @@ async def _search_deep( cluster.content = f"{cluster.content}\n\n{graph_ctx}" # ============================================================== - # Phase 3.6: Adaptive depth — fast triage for simple queries - # When pre-nav evidence is sufficient, attempt FAST-style synthesis. - # The LLM decides via SHOULD_ANSWER + PRECISE_ANSWER whether the - # evidence is adequate — no hardcoded heuristic. If the LLM - # returns a precise answer with acceptance, we skip the heavier - # structured reasoning pipeline. + # Phase 4: Structured Reasoning → Cluster Summary fallback + # P0: DEEP mode always goes through full reasoning pipeline — + # no fast triage short-circuit. P4: query complexity determines + # whether the heavier section-map SR fires or we go straight to + # cluster synthesis. 
# ============================================================== - _fast_triage_answer: Optional[str] = None - _fast_triage_accepted = False - - if _pre_nav_evidence: - _pre_nav_total = sum(len(v) for v in _pre_nav_evidence.values()) - await self._logger.info( - f"[Phase 3.6] Fast triage: {_pre_nav_total} chars of pre-nav evidence" - ) - _triage_evidence = "\n\n---\n\n".join( - f"[{Path(fp).name}]\n{ev}" - for fp, ev in _pre_nav_evidence.items() - ) - - doc_context = None - if artifacts and artifacts.catalog_map: - _ctx_parts = [ - self._build_answer_context(fp, artifacts) - for fp in list(_pre_nav_evidence)[:2] - ] - _ctx_parts = [c for c in _ctx_parts if c] - if _ctx_parts: - doc_context = "\n".join(_ctx_parts) - - if doc_context: - from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT - _triage_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( - user_input=query, - text_content=_triage_evidence, - document_context=doc_context, - ) - else: - _triage_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=_triage_evidence, - ) - - _triage_resp = await self.llm.achat( - messages=[{"role": "user", "content": _triage_prompt}], - stream=True, - ) - self.llm_usages.append(_triage_resp.usage) - context.increment_loop() - - _triage_raw = _triage_resp.content or "" - _fast_triage_answer, _ft_save, _ft_should = ( - self._parse_summary_response(_triage_raw) - ) - _fast_triage_accepted, _ft_reason = ( - self._evaluate_evidence_acceptance( - query, _triage_evidence, _ft_should, - ) - ) - - # Require a PRECISE_ANSWER for true triage acceptance — if the - # LLM could not produce one, the query likely needs deeper - # reasoning even if evidence was nominally accepted. 
- _has_precise = bool( - re.search( - r"(.+?)", - _triage_raw, re.DOTALL, - ) - ) - if _fast_triage_accepted and not _has_precise: - _fast_triage_accepted = False - _ft_reason = "no_precise_answer" - - await self._logger.info( - f"[Phase 3.6] Fast triage: accepted={_fast_triage_accepted} " - f"({_ft_reason})" - ) + context.increment_loop() + answer = "" + should_save = True - if _fast_triage_accepted and _fast_triage_answer: - answer = _fast_triage_answer - should_save = True - if not cluster: - cluster = self._make_answer_cluster( - query, answer, "DT", - file_paths=list(_pre_nav_evidence), - ) - elif not cluster.content: - cluster.content = answer - await self._logger.info( - "[Phase 3.6] Fast triage accepted → skipping structured reasoning" - ) - else: - # ============================================================== - # Phase 4: Deep Structured Reasoning or cluster-based fallback - # ============================================================== - context.increment_loop() - answer = "" - should_save = True + _query_complexity = self._classify_query_complexity(query) + await self._logger.info( + f"[Phase 4] Query complexity: {_query_complexity}" + ) - # Determine tree-indexed files for structured reasoning. - # tree_hits come from Phase 1 probe and are always tree-indexed; - # also check artifacts when available for broader coverage. 
- _sr_files: List[str] = [] + # Attempt structured reasoning for moderate/complex queries + _sr_files: List[str] = [] + if _query_complexity != "simple": if tree_hits: _sr_files = list(tree_hits[: self._DEEP_STRUCTURED_MAX_FILES]) elif artifacts and artifacts.tree_available_paths: @@ -1946,124 +1926,121 @@ async def _search_deep( : self._DEEP_STRUCTURED_MAX_FILES ] - if _sr_files: - await self._logger.info( - f"[Phase 4] Launching structured reasoning for " - f"{len(_sr_files)} tree-indexed files" + if _sr_files: + await self._logger.info( + f"[Phase 4] Launching structured reasoning for " + f"{len(_sr_files)} tree-indexed files" + ) + sr_answer, sr_cluster = await self._deep_structured_reasoning( + query, _sr_files, artifacts, context, + ) + + if sr_answer: + answer, should_save, should_answer = self._parse_summary_response( + sr_answer ) - sr_answer, sr_cluster = await self._deep_structured_reasoning( - query, _sr_files, artifacts, context, + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, sr_answer, should_answer, ) + await self._logger.info( + f"[Phase 4] Structured reasoning: " + f"accepted={accepted} ({accept_reason})" + ) + if accepted: + cluster = sr_cluster or cluster + else: + answer = "" - if sr_answer: - answer, should_save, should_answer = self._parse_summary_response( - sr_answer - ) - accepted, accept_reason = self._evaluate_evidence_acceptance( - query, - sr_answer, - should_answer, - ) - await self._logger.info( - f"[Phase 4] Structured reasoning acceptance: " - f"{accepted} ({accept_reason})" + # Fallback: cluster summary with ROI prompt or ReAct + if not answer: + if artifacts and artifacts.catalog_map and cluster and cluster.content: + _catalog_ctx_parts = [] + for fp in (cluster.search_results or merged_files)[:3]: + ctx = self._build_answer_context(fp, artifacts) + if ctx: + _catalog_ctx_parts.append(ctx) + if _catalog_ctx_parts: + _catalog_context = "\n".join(_catalog_ctx_parts) + if isinstance(cluster.content, list): + 
cluster.content = "\n".join(cluster.content) + cluster.content = ( + f"{cluster.content}\n\n" + f"[Document Context]\n{_catalog_context}" ) - if accepted: - cluster = sr_cluster or cluster - else: - answer = "" # Fall through to cluster/ReAct path - - # Fallback: original cluster summary or ReAct - if not answer: - # Inject catalog context for wiki-enhanced answer - if artifacts and artifacts.catalog_map and cluster and cluster.content: - _catalog_ctx_parts = [] - for fp in (cluster.search_results or merged_files)[:3]: - ctx = self._build_answer_context(fp, artifacts) - if ctx: - _catalog_ctx_parts.append(ctx) - if _catalog_ctx_parts: - _catalog_context = "\n".join(_catalog_ctx_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = ( - f"{cluster.content}\n\n" - f"[Document Context]\n{_catalog_context}" - ) - if cluster and cluster.content: - await self._logger.info( - "[Phase 4:Fallback] Generating summary from cluster" - ) - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) - ) - cluster_evidence = ( - str(cluster.content) if cluster.content else "" + if cluster and cluster.content: + await self._logger.info( + "[Phase 4:Fallback] Generating summary from cluster" + ) + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + cluster_evidence = ( + str(cluster.content) if cluster.content else "" + ) + accepted, accept_reason = ( + self._evaluate_evidence_acceptance( + query, cluster_evidence, should_answer, ) - accepted, accept_reason = ( - self._evaluate_evidence_acceptance( - query, cluster_evidence, should_answer, + ) + if not accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) ) + else: + return _NO_RESULTS_MESSAGE, None, context + if not cluster.search_results: + cluster.search_results = list(merged_files) + elif llm_fallback: + answer, should_save = ( + await 
self._summarise_cluster_fallback(query) + ) + else: + await self._logger.info( + "[Phase 4:Fallback] Launching ReAct refinement" + ) + react_spec = ( + f"{spec_context}\n\n{graph_ctx}" + if graph_ctx else spec_context + ) + react_answer, context = await self._react_refinement( + query=query, paths=paths, + initial_keywords=initial_keywords, + spec_context=react_spec, + enable_dir_scan=enable_dir_scan, + max_loops=max_loops, + max_token_budget=max_token_budget, + max_depth=max_depth, + include=include, exclude=exclude, + ) + if not cluster: + cluster = await self._build_cluster_from_context( + query=query, answer=react_answer, + context=context, + query_keywords=query_keywords, + top_k_files=top_k_files, ) - if not accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - return _NO_RESULTS_MESSAGE, None, context - if not cluster.search_results: - cluster.search_results = list(merged_files) - elif llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - await self._logger.info( - "[Phase 4:Fallback] Launching ReAct refinement" - ) - react_spec = ( - f"{spec_context}\n\n{graph_ctx}" - if graph_ctx else spec_context - ) - react_answer, context = await self._react_refinement( - query=query, paths=paths, - initial_keywords=initial_keywords, - spec_context=react_spec, - enable_dir_scan=enable_dir_scan, - max_loops=max_loops, - max_token_budget=max_token_budget, - max_depth=max_depth, - include=include, exclude=exclude, - ) - if not cluster: - cluster = await self._build_cluster_from_context( - query=query, answer=react_answer, - context=context, - query_keywords=query_keywords, - top_k_files=top_k_files, + elif react_answer and not cluster.content: + cluster.content = react_answer + if not cluster: + return _NO_RESULTS_MESSAGE, None, context + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + final_evidence = ( + 
str(cluster.content) if cluster.content else "" + ) + final_accepted, _ = self._evaluate_evidence_acceptance( + query, final_evidence, should_answer, + ) + if not final_accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) ) - elif react_answer and not cluster.content: - cluster.content = react_answer - if not cluster: + else: return _NO_RESULTS_MESSAGE, None, context - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) - ) - final_evidence = ( - str(cluster.content) if cluster.content else "" - ) - final_accepted, _ = self._evaluate_evidence_acceptance( - query, final_evidence, should_answer, - ) - if not final_accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - return _NO_RESULTS_MESSAGE, None, context # Sync LLM token accounting into context new_usages = self.llm_usages[_llm_usage_start:] @@ -6248,6 +6225,9 @@ async def _summarise_cluster( ) -> Tuple[str, bool, bool]: """Generate a final answer summary from a KnowledgeCluster. + Uses ``ROI_RESULT_SUMMARY`` (with precision / best-effort constraints) + for both FAST and DEEP modes, ensuring consistent answer quality. + Returns: ``(summary_text, should_save, should_answer)`` where: - should_save: quality verdict for persistence @@ -6260,7 +6240,7 @@ async def _summarise_cluster( f"{cluster.content if isinstance(cluster.content, str) else sep.join(cluster.content)}" ) - result_sum_prompt = SEARCH_RESULT_SUMMARY.format( + result_sum_prompt = ROI_RESULT_SUMMARY.format( user_input=query, text_content=cluster_text_content, ) @@ -6276,13 +6256,12 @@ async def _summarise_cluster( return summary, should_save, should_answer async def _summarise_cluster_fallback(self, query: str) -> Tuple[str, bool]: - """Generate an answer using the DEEP summary prompt with fallback evidence. + """Generate an answer using the ROI summary prompt with fallback evidence. 
- Reuses the existing ``SEARCH_RESULT_SUMMARY`` prompt, feeding it the - standard fallback text so that the LLM answers from its own knowledge - without adding an extra LLM call to the pipeline. + Feeds the standard fallback text so the LLM answers from its own + knowledge without adding an extra LLM call to the pipeline. """ - result_sum_prompt = SEARCH_RESULT_SUMMARY.format( + result_sum_prompt = ROI_RESULT_SUMMARY.format( user_input=query, text_content=self._LLM_FALLBACK_EVIDENCE, ) @@ -6500,86 +6479,6 @@ async def _extract_targeted_pages( evidence = "\n\n".join(parts) return evidence[: self._DEEP_STRUCTURED_MAX_CHARS] - async def _extract_structured_data( - self, - query: str, - raw_evidence: str, - ) -> Tuple[str, str, List[str]]: - """LLM extraction of structured data from raw page evidence. - - Returns (structured_data, completeness, missing_items). - """ - prompt = DEEP_STRUCTURED_EXTRACT.format( - query=query, - evidence=raw_evidence, - ) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=True, - ) - self.llm_usages.append(resp.usage) - content = resp.content or "" - - # Parse structured output - data_match = re.search( - r"(.*?)", content, re.DOTALL, - ) - structured = data_match.group(1).strip() if data_match else content - - comp_match = re.search( - r"\s*(complete|partial|insufficient)\s*", - content, re.IGNORECASE, - ) - completeness = comp_match.group(1).lower() if comp_match else "partial" - - missing_match = re.search( - r"(.*?)", content, re.DOTALL, - ) - missing_items: List[str] = [] - if missing_match: - raw_missing = missing_match.group(1).strip() - if raw_missing and raw_missing.lower() not in ("", "none", "n/a", "empty"): - missing_items = [ - line.strip().lstrip("- ") - for line in raw_missing.split("\n") - if line.strip() and line.strip() != "-" - ] - - return structured, completeness, missing_items - - async def _reason_with_verification( - self, - query: str, - structured_data: str, - ) -> Tuple[str, 
str, str]: - """CoT reasoning with self-verification. - - Returns (answer, confidence, full_reasoning). - """ - prompt = DEEP_COT_REASONING.format( - query=query, - structured_data=structured_data, - ) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=True, - ) - self.llm_usages.append(resp.usage) - content = resp.content or "" - - answer_match = re.search( - r"(.*?)", content, re.DOTALL, - ) - answer = answer_match.group(1).strip() if answer_match else "" - - conf_match = re.search( - r"\s*(high|medium|low)\s*", - content, re.IGNORECASE, - ) - confidence = conf_match.group(1).lower() if conf_match else "low" - - return answer, confidence, content - async def _deep_structured_reasoning( self, query: str, @@ -6587,23 +6486,23 @@ async def _deep_structured_reasoning( artifacts: Any, context: "SearchContext", ) -> Tuple[str, Optional["KnowledgeCluster"]]: - """Orchestrate the full Deep Structured Reasoning pipeline. + """Orchestrate the Deep Structured Reasoning pipeline. Phases: - 1. Section map — build from tree index top layers - 2. Section select — LLM picks relevant sections - 3. Targeted extraction — pull pages + tables for selected sections - 4. Structured data — LLM extracts key-value data - 5. CoT reasoning + verification - 6. Missing data recovery (conditional, up to N rounds) - - Returns ``(answer, cluster)`` suitable for the DEEP search return path. + 1. Section map — build from tree index top layers (no LLM) + 2. Section select — LLM picks relevant sections (1 LLM) + 3. Targeted extraction — pull pages + tables for sections (no LLM) + 4. Synthesis — ROI_RESULT_SUMMARY on targeted evidence (1 LLM) + 5. Recovery — if refused, expand sections and re-synthesize + + Returns the raw LLM output (compatible with ``_parse_summary_response``) + and a cluster for persistence. 
""" indexer = self._get_tree_indexer() if indexer is None: return "", None - all_structured: List[str] = [] + all_evidence_parts: List[str] = [] for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES]: fname = Path(fp).name @@ -6611,7 +6510,6 @@ async def _deep_structured_reasoning( if tree is None or tree.root is None: continue - # Phase 1: Section map section_map, sections_meta = self._build_section_map( tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH, ) @@ -6619,11 +6517,10 @@ async def _deep_structured_reasoning( continue await self._logger.info( - f"[DeepStructured] Section map for {fname}: " + f"[DeepSR] Section map for {fname}: " f"{len(sections_meta)} sections" ) - # Phase 2: LLM selects relevant sections selected = await self._select_evidence_sections( query, section_map, sections_meta, ) @@ -6632,125 +6529,136 @@ async def _deep_structured_reasoning( continue await self._logger.info( - f"[DeepStructured] Selected {len(selected)} sections: " + f"[DeepSR] Selected {len(selected)} sections: " f"{[s['title'][:30] for s in selected]}" ) - # Phase 3: Targeted page extraction raw_evidence = await self._extract_targeted_pages( fp, selected, query, ) if not raw_evidence: continue - # Per-file evidence accumulator for recovery rounds - file_raw_parts: List[str] = [raw_evidence] - await self._logger.info( - f"[DeepStructured] Extracted {len(raw_evidence)} chars " - f"from {fname}" + f"[DeepSR] Extracted {len(raw_evidence)} chars from {fname}" ) - # Phase 4: Structured data extraction - structured, completeness, missing_items = ( - await self._extract_structured_data(query, raw_evidence) - ) - context.increment_loop() + all_evidence_parts.append(f"[Source: {fname}]\n{raw_evidence}") - await self._logger.info( - f"[DeepStructured] Data extraction: " - f"completeness={completeness}, " - f"missing={len(missing_items)}" - ) + if not all_evidence_parts: + return "", None - # Phase 4.5: Missing data recovery (per-file loop) - recovery_round = 0 - while ( - 
missing_items - and completeness != "complete" - and recovery_round < self._DEEP_MAX_RECOVERY_ROUNDS - ): - recovery_round += 1 - await self._logger.info( - f"[DeepStructured] Recovery round {recovery_round}: " - f"seeking {missing_items}" - ) + combined_evidence = "\n\n---\n\n".join(all_evidence_parts) - recovery_query = ( - f"{query} — specifically find: " - f"{', '.join(missing_items[:5])}" - ) + # Build document context from artifacts when available + doc_context: Optional[str] = None + if artifacts and artifacts.catalog_map: + ctx_parts = [ + self._build_answer_context(fp, artifacts) + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES] + ] + ctx_parts = [c for c in ctx_parts if c] + if ctx_parts: + doc_context = "\n".join(ctx_parts) - recovery_selected = await self._select_evidence_sections( - recovery_query, section_map, sections_meta, - ) - context.increment_loop() + # Synthesize answer using the unified ROI prompt + if doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=combined_evidence, + document_context=doc_context, + ) + else: + synth_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=combined_evidence, + ) - existing_ids = {s["node_id"] for s in selected} - new_sections = [ - s for s in recovery_selected - if s["node_id"] not in existing_ids - ] - if not new_sections: - break + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + context.increment_loop() + + raw_response = resp.content or "" + _, _, should_answer = self._parse_summary_response(raw_response) - recovery_evidence = await self._extract_targeted_pages( - fp, new_sections, recovery_query, + await self._logger.info( + f"[DeepSR] Synthesis complete: should_answer={should_answer}, " + f"len={len(raw_response)}" + ) + + # Recovery: if the answer is a refusal, try 
expanding sections + if (not should_answer or self._is_refusal_answer(raw_response[:500])): + for recovery_round in range(1, self._DEEP_MAX_RECOVERY_ROUNDS + 1): + await self._logger.info( + f"[DeepSR] Recovery round {recovery_round}" ) - if not recovery_evidence: + expanded_parts: List[str] = list(all_evidence_parts) + found_new = False + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES]: + tree = indexer.load_tree(fp) + if tree is None or tree.root is None: + continue + section_map, sections_meta = self._build_section_map( + tree.root, + max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 1, + ) + if not sections_meta: + continue + recovery_selected = await self._select_evidence_sections( + query, section_map, sections_meta, + ) + context.increment_loop() + if not recovery_selected: + continue + recovery_ev = await self._extract_targeted_pages( + fp, recovery_selected, query, + ) + if recovery_ev and recovery_ev not in combined_evidence: + expanded_parts.append( + f"[Recovery source: {Path(fp).name}]\n{recovery_ev}" + ) + found_new = True + if not found_new: break - file_raw_parts.append(recovery_evidence) - - combined = "\n\n".join(file_raw_parts) - structured, completeness, missing_items = ( - await self._extract_structured_data( - query, - combined[: self._DEEP_STRUCTURED_MAX_CHARS], + combined_evidence = "\n\n---\n\n".join(expanded_parts) + if doc_context: + synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=combined_evidence[ + : self._DEEP_STRUCTURED_MAX_CHARS + ], + document_context=doc_context, ) + else: + synth_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=combined_evidence[ + : self._DEEP_STRUCTURED_MAX_CHARS + ], + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], + stream=True, ) + self.llm_usages.append(resp.usage) context.increment_loop() - selected.extend(new_sections) - - if structured: - all_structured.append(f"[Source: {fname}]\n{structured}") - 
- if not all_structured: - return "", None - - # Phase 5: CoT reasoning with verification - combined_data = "\n\n".join(all_structured) - answer, confidence, full_reasoning = await self._reason_with_verification( - query, combined_data, - ) - context.increment_loop() - - await self._logger.info( - f"[DeepStructured] Reasoning complete: " - f"confidence={confidence}, answer_len={len(answer)}" - ) + raw_response = resp.content or "" + _, _, should_answer = self._parse_summary_response(raw_response) + if should_answer and not self._is_refusal_answer( + raw_response[:500] + ): + break - # Build a cluster for persistence (strip XML tags for clean content) - _clean_reasoning = re.sub( - r"", - "", full_reasoning, - ).strip() cluster = self._make_answer_cluster( - query, _clean_reasoning, "DSR", + query, combined_evidence[:5000], "DSR", file_paths=tree_files[: self._DEEP_STRUCTURED_MAX_FILES], ) - if not answer: - return "", cluster - - formatted = ( - f"\n{_clean_reasoning}\n\n" - f"{answer}\n" - f"true\n" - f"{'true' if confidence != 'low' else 'false'}" - f"" - ) - - return formatted, cluster + return raw_response, cluster async def _react_refinement( self, From e8edb2f780eb04f5a4560cc21dddbfd18e4d0da1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 13 May 2026 21:35:02 +0800 Subject: [PATCH 59/70] add tree navi for react loop in DEEP mode --- src/sirchmunk/agentic/tools.py | 117 +++++++++++++++++++++++++++++++++ src/sirchmunk/search.py | 32 +++++++-- 2 files changed, 144 insertions(+), 5 deletions(-) diff --git a/src/sirchmunk/agentic/tools.py b/src/sirchmunk/agentic/tools.py index b13e762..c79cf5a 100644 --- a/src/sirchmunk/agentic/tools.py +++ b/src/sirchmunk/agentic/tools.py @@ -568,3 +568,120 @@ async def execute( ) return result_text, {"query": query, "clusters_found": len(clusters)} + + +# --------------------------------------------------------------------------- +# Tool 5: Tree Navigation (medium cost — LLM-guided tree index 
navigation) +# --------------------------------------------------------------------------- + +class TreeNavigationTool(BaseTool): + """Navigate a document's compiled tree index to extract targeted evidence. + + Uses an LLM-driven tree navigation strategy: the model selects + the most relevant branches/sections from a hierarchical document + index, then extracts the corresponding page or char-range content. + + This tool bridges the gap between keyword search (which finds + *where* a term appears) and file read (which returns *everything*). + Tree navigation returns the most relevant *sections* of a document + without reading the whole file. + + Requires compile artifacts (tree indices) to be available for the + target files. + """ + + def __init__( + self, + navigate_fn: Any, + available_paths: Optional[set] = None, + max_chars: int = 15_000, + ) -> None: + self._navigate_fn = navigate_fn + self._available_paths = available_paths or set() + self._max_chars = max_chars + + @property + def name(self) -> str: + return "tree_navigate" + + def get_schema(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": ( + "Navigate a document's compiled tree index to extract " + "targeted sections relevant to the query. More precise " + "than file_read — returns only relevant sections instead " + "of the entire file. Works with PDF, DOCX, and other " + "compiled document types. Medium token cost." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Absolute path of the document to navigate." + ), + }, + "query": { + "type": "string", + "description": ( + "What information to look for in the document." 
+ ), + }, + }, + "required": ["file_path", "query"], + }, + } + + async def execute( + self, + context: SearchContext, + **kwargs, + ) -> Tuple[str, Dict[str, Any]]: + file_path: str = kwargs.get("file_path", "") + query: str = kwargs.get("query", "") + if not file_path or not query: + return "file_path and query are required.", {} + + if ( + self._available_paths + and file_path not in self._available_paths + ): + return ( + f"No tree index available for {Path(file_path).name}. " + "Use file_read instead." + ), {"file_path": file_path, "indexed": False} + + try: + result = await self._navigate_fn( + file_path, query, max_chars=self._max_chars, + ) + except Exception as exc: + return ( + f"Tree navigation failed: {exc}" + ), {"file_path": file_path, "error": str(exc)} + + if not result: + return ( + f"No relevant sections found in " + f"{Path(file_path).name} for this query." + ), {"file_path": file_path, "chars": 0} + + total_chars = len(result) + approx_tokens = total_chars // 4 + context.add_log( + tool_name=self.name, + tokens=approx_tokens, + metadata={ + "file_path": file_path, + "chars": total_chars, + }, + ) + + header = f"[Tree navigation: {Path(file_path).name}]" + return f"{header}\n{result}", { + "file_path": file_path, + "chars": total_chars, + "tokens": approx_tokens, + } diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index e5492da..b61e0a1 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1294,6 +1294,7 @@ def _ensure_tool_registry( FileReadTool, KeywordSearchTool, KnowledgeQueryTool, + TreeNavigationTool, ToolRegistry, ) @@ -1336,12 +1337,23 @@ def _ensure_tool_registry( from sirchmunk.scan.dir_scanner import DirectoryScanner if self._dir_scanner is None: - self._dir_scanner = DirectoryScanner(llm=self.llm, max_files=500) + self._dir_scanner = DirectoryScanner( + llm=self.llm, max_files=500, + ) registry.register(DirScanTool( scanner=self._dir_scanner, paths=paths, )) + # Tool 5: Tree navigation (when compile 
artifacts exist) + artifacts = self._detect_compile_artifacts(paths) + if artifacts and artifacts.tree_available_paths: + registry.register(TreeNavigationTool( + navigate_fn=self._tree_guided_sample, + available_paths=artifacts.tree_available_paths, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + )) + self._tool_registry = registry self._tool_registry_key = cache_key return registry @@ -2000,10 +2012,20 @@ async def _search_deep( await self._logger.info( "[Phase 4:Fallback] Launching ReAct refinement" ) - react_spec = ( - f"{spec_context}\n\n{graph_ctx}" - if graph_ctx else spec_context - ) + # Seed ReAct with all available prior context so it + # doesn't start from scratch. + react_parts: List[str] = [] + if spec_context: + react_parts.append(spec_context) + if graph_ctx: + react_parts.append(graph_ctx) + if _pre_nav_evidence: + nav_seed = "\n\n".join( + f"[Pre-navigated: {Path(fp).name}]\n{ev}" + for fp, ev in _pre_nav_evidence.items() + ) + react_parts.append(nav_seed) + react_spec = "\n\n".join(react_parts) react_answer, context = await self._react_refinement( query=query, paths=paths, initial_keywords=initial_keywords, From 5c359b249fb0c6af938a677a9a0913b1e37049b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 14 May 2026 12:45:12 +0800 Subject: [PATCH 60/70] update deep --- benchmarks/financebench/judge.py | 22 +++- src/sirchmunk/llm/prompts.py | 14 ++- src/sirchmunk/search.py | 206 ++++++++++++++++++++++++++++--- 3 files changed, 218 insertions(+), 24 deletions(-) diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py index 8140669..1e5e1ca 100644 --- a/benchmarks/financebench/judge.py +++ b/benchmarks/financebench/judge.py @@ -440,14 +440,32 @@ def _validated_result( @staticmethod def _is_refusal(text: str) -> bool: - """Quick check whether *text* looks like a refusal / non-answer.""" + """Quick check whether *text* looks like a refusal / non-answer. 
+ + When the text contains an explicit ``**Answer: xxx**`` marker, + only the answer value is checked for refusal phrases so that + reasoning text containing phrases like "insufficient data" (as + analytical context) does not trigger a false positive. + """ if not text or not text.strip(): return True lower = text.strip().lower() if lower in ("unknown", "n/a", "none", ""): return True + + # If there is an explicit **Answer: xxx** marker, only check that value + answer_match = re.search(r'\*\*answer:\s*(.+?)\*\*', lower) + if answer_match: + answer_val = answer_match.group(1).strip() + for phrase in _REFUSAL_INDICATORS: + if phrase in answer_val: + return True + return False + + # No structured answer marker — check the leading portion only + check_region = lower[:300] for phrase in _REFUSAL_INDICATORS: - if phrase in lower: + if phrase in check_region: return True return False diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 17c7e9f..4b4c3bb 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -99,8 +99,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: # Define granularity characteristics if i == 1: granularity = "Coarse-grained" - desc_text = "Multi-word phrases, compound expressions, broader concepts" - examples = '"machine learning algorithms", "data processing pipeline", "neural network training"' + desc_text = "Multi-word phrases (2-3 words) that are likely to appear **verbatim** in the target document. Prioritize standard domain terminology (e.g. financial statement headings, technical section titles)" + examples = '"capital expenditure", "net income", "accounts payable", "operating cash flow", "total revenue"' elif i == num_levels: granularity = "Fine-grained" desc_text = "Single words, precise terms, atomic concepts" @@ -517,9 +517,15 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Instructions 1. 
Identify which sections contain data needed to answer the query. -2. For questions requiring computation (ratios, growth rates, comparisons), select ALL sections containing the required input data. +2. For questions requiring computation (ratios, growth rates, comparisons), select ALL sections containing the required input data — even if you think some may be redundant. 3. Prefer sections containing structured data (tables, financial statements) over narrative sections. -4. Select 1-5 sections. Fewer is better if you are confident. +4. For financial/annual report queries, ALWAYS include sections matching these types when available: + - Income Statement / Consolidated Statements of Operations (revenue, expenses, net income) + - Balance Sheet / Consolidated Balance Sheets (assets, liabilities, equity) + - Cash Flow Statement / Consolidated Statements of Cash Flows (capex, operating cash flow) + - Notes to Financial Statements (breakdowns, segment data, detailed schedules) + - Management's Discussion and Analysis (context, trends, explanations) +5. Select 2-6 sections. When in doubt, select MORE rather than fewer — missing data causes answer failure. ### Output Return ONLY a JSON array of section indices (0-based) from the map above: diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index b61e0a1..1acc220 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1048,8 +1048,24 @@ def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: if not summary: summary = llm_response.strip() - should_answer = False - should_save = False + # Fallback: detect **Answer: xxx** markdown format used by models + # that ignore / tags (e.g. qwen). 
+ _answer_match = re.search( + r'\*\*Answer:\s*(.+?)\*\*', llm_response, re.DOTALL, + ) + if _answer_match: + _answer_val = _answer_match.group(1).strip() + if _answer_val and not cls._is_refusal_answer(_answer_val): + should_answer = True + should_save = True + if not precise: + precise = _answer_val + else: + should_answer = False + should_save = False + else: + should_answer = False + should_save = False # P3: Never persist refusal/no-data answers to cluster cache if should_save and cls._is_refusal_answer(precise or summary): @@ -1943,7 +1959,7 @@ async def _search_deep( f"[Phase 4] Launching structured reasoning for " f"{len(_sr_files)} tree-indexed files" ) - sr_answer, sr_cluster = await self._deep_structured_reasoning( + sr_answer, sr_cluster, sr_evidence = await self._deep_structured_reasoning( query, _sr_files, artifacts, context, ) @@ -1952,7 +1968,7 @@ async def _search_deep( sr_answer ) accepted, accept_reason = self._evaluate_evidence_acceptance( - query, sr_answer, should_answer, + query, sr_evidence or sr_answer, should_answer, ) await self._logger.info( f"[Phase 4] Structured reasoning: " @@ -2001,7 +2017,28 @@ async def _search_deep( await self._summarise_cluster_fallback(query) ) else: - return _NO_RESULTS_MESSAGE, None, context + # DEEP self-correction before giving up + sc_evidence = await self._deep_self_correct( + query, merged_files, query_keywords, context, + ) + if sc_evidence: + sc_cluster = self._make_answer_cluster( + query, sc_evidence[:5000], "DSC", + file_paths=list(merged_files)[:3], + ) + sc_cluster.content = sc_evidence + answer, should_save, should_answer = ( + await self._summarise_cluster(query, sc_cluster) + ) + sc_accepted, _ = self._evaluate_evidence_acceptance( + query, sc_evidence, should_answer, + ) + if sc_accepted: + cluster = sc_cluster + else: + return _NO_RESULTS_MESSAGE, None, context + else: + return _NO_RESULTS_MESSAGE, None, context if not cluster.search_results: cluster.search_results = list(merged_files) elif 
llm_fallback: @@ -2062,7 +2099,21 @@ async def _search_deep( await self._summarise_cluster_fallback(query) ) else: - return _NO_RESULTS_MESSAGE, None, context + sc_evidence = await self._deep_self_correct( + query, merged_files, query_keywords, context, + ) + if sc_evidence: + sc_cluster = self._make_answer_cluster( + query, sc_evidence[:5000], "DSC", + file_paths=list(merged_files)[:3], + ) + sc_cluster.content = sc_evidence + answer, should_save, _ = ( + await self._summarise_cluster(query, sc_cluster) + ) + cluster = sc_cluster + else: + return _NO_RESULTS_MESSAGE, None, context # Sync LLM token accounting into context new_usages = self.llm_usages[_llm_usage_start:] @@ -2470,13 +2521,13 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Expanded tree sample sections for same-file re-sampling (default uses 5).""" # --- Deep Structured Reasoning --- - _DEEP_SECTION_MAP_MAX_DEPTH: int = 2 + _DEEP_SECTION_MAP_MAX_DEPTH: int = 3 """Maximum tree depth for section map construction (top-N layers).""" _DEEP_MAX_EXTRACT_PAGES: int = 12 """Maximum pages to extract per file in targeted page extraction.""" _DEEP_STRUCTURED_MAX_CHARS: int = 30_000 """Maximum character budget for structured evidence per file.""" - _DEEP_MAX_RECOVERY_ROUNDS: int = 2 + _DEEP_MAX_RECOVERY_ROUNDS: int = 3 """Maximum rounds of missing-data recovery before final answer.""" _DEEP_STRUCTURED_MAX_FILES: int = 3 """Maximum files to process through structured reasoning pipeline.""" @@ -5345,6 +5396,11 @@ async def _probe_keywords( Also extracts cross-lingual alternative keywords from the ```` block and merges them into the result list. + Additionally synthesises rga-friendly compound phrases from + Level 1 keywords so that downstream ``_retrieve_by_keywords`` + tries exact multi-word matches before falling back to atomic + terms (mirrors the strategy used by FAST mode). + Returns: Tuple of (keyword_idf_dict, keyword_list). 
""" @@ -5368,9 +5424,26 @@ async def _probe_keywords( for kw_set in keyword_sets: if kw_set: merged = {**kw_set, **alt_keywords} - kw_list = list(merged.keys()) - await self._logger.info(f"[Probe:Keywords] Extracted: {kw_list}") - return merged, kw_list + # Synthesise rga-friendly compound phrases: promote + # multi-word Level-1 keywords to the front with boosted + # IDF so _retrieve_by_keywords tries them first as exact + # phrases (similar to FAST's primary/fallback strategy). + compound_phrases: Dict[str, float] = {} + atomic_terms: Dict[str, float] = {} + for kw, idf in merged.items(): + if " " in kw.strip() and len(kw.split()) >= 2: + compound_phrases[kw] = max(idf, 7.0) + else: + atomic_terms[kw] = idf + # Compounds first, then atomics — preserves ordering for + # _retrieve_by_keywords which iterates keywords in order. + ordered = {**compound_phrases, **atomic_terms} + kw_list = list(ordered.keys()) + await self._logger.info( + f"[Probe:Keywords] Extracted: {kw_list} " + f"(compounds={len(compound_phrases)})" + ) + return ordered, kw_list if alt_keywords: return alt_keywords, list(alt_keywords.keys()) @@ -6507,7 +6580,7 @@ async def _deep_structured_reasoning( tree_files: List[str], artifacts: Any, context: "SearchContext", - ) -> Tuple[str, Optional["KnowledgeCluster"]]: + ) -> Tuple[str, Optional["KnowledgeCluster"], str]: """Orchestrate the Deep Structured Reasoning pipeline. Phases: @@ -6517,12 +6590,14 @@ async def _deep_structured_reasoning( 4. Synthesis — ROI_RESULT_SUMMARY on targeted evidence (1 LLM) 5. Recovery — if refused, expand sections and re-synthesize - Returns the raw LLM output (compatible with ``_parse_summary_response``) - and a cluster for persistence. + Returns ``(raw_llm_output, cluster, combined_evidence)`` where + *combined_evidence* is the raw document text fed to the LLM so + callers can use it for evidence-acceptance checks instead of + the LLM's answer text. 
""" indexer = self._get_tree_indexer() if indexer is None: - return "", None + return "", None, "" all_evidence_parts: List[str] = [] @@ -6568,7 +6643,7 @@ async def _deep_structured_reasoning( all_evidence_parts.append(f"[Source: {fname}]\n{raw_evidence}") if not all_evidence_parts: - return "", None + return "", None, "" combined_evidence = "\n\n---\n\n".join(all_evidence_parts) @@ -6626,7 +6701,7 @@ async def _deep_structured_reasoning( continue section_map, sections_meta = self._build_section_map( tree.root, - max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 1, + max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + recovery_round, ) if not sections_meta: continue @@ -6680,7 +6755,102 @@ async def _deep_structured_reasoning( file_paths=tree_files[: self._DEEP_STRUCTURED_MAX_FILES], ) - return raw_response, cluster + return raw_response, cluster, combined_evidence + + async def _deep_self_correct( + self, + query: str, + merged_files: List[str], + query_keywords: Dict[str, float], + context: "SearchContext", + ) -> Optional[str]: + """Gather alternative evidence when DEEP Phase 4 answer is rejected. + + Four strategies tried in order, stopping at first success: + A) Expanded tree-guided sampling on the primary file. + B) rga keyword window extraction on primary files using + Phase-1 keywords (reuses the rga infrastructure). + C) Semantically similar cluster from knowledge storage. + D) Tree-guided sampling on secondary merged files. + + Returns alternative evidence text, or ``None`` when every + strategy fails. 
+ """ + primary_files = merged_files[:2] + secondary_files = merged_files[2:5] + + # Strategy A: expanded tree sampling on primary file + for fp in primary_files: + expanded_ev = await self._tree_guided_sample( + fp, query, + max_chars=self._FAST_MAX_EVIDENCE_CHARS * 2, + ) + if isinstance(expanded_ev, str) and len(expanded_ev.strip()) > 100: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy A succeeded: " + f"expanded tree sample from {Path(fp).name}" + ) + return expanded_ev + + # Strategy B: tree-navigated evidence with expanded parameters + for fp in primary_files: + try: + nav_ev = await self._navigate_tree_for_evidence( + fp, query, + max_results=self._SELF_CORRECT_EXPANDED_NAV_RESULTS, + ) + if nav_ev and len(nav_ev.strip()) > 100: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy B succeeded: " + f"expanded tree navigation on {Path(fp).name}" + ) + return nav_ev + except Exception: + pass + + # Strategy C: semantically similar cluster from knowledge storage + if self.embedding_client and self.knowledge_storage: + try: + qe = self.embedding_client.encode(query) + if qe is not None: + vec = qe.tolist() if hasattr(qe, "tolist") else list(qe) + hits = await self.knowledge_storage.search_similar_clusters( + query_embedding=vec, top_k=2, similarity_threshold=0.50, + ) + if hits: + parts: List[str] = [] + for h in hits[:2]: + c = await self.knowledge_storage.get(h["id"]) + if c and c.content: + parts.append(str(c.content)[:3000]) + for ev in (c.evidences or [])[:3]: + for s in (ev.snippets or [])[:2]: + parts.append(s[:500]) + if parts: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy C succeeded: " + "knowledge storage cluster" + ) + return "\n\n---\n\n".join(parts) + except Exception: + pass + + # Strategy D: tree sampling on secondary files + for fp in secondary_files: + tree_ev = await self._tree_guided_sample( + fp, query, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) + if isinstance(tree_ev, str) and len(tree_ev.strip()) > 100: + 
context.mark_file_read(fp) + await self._logger.info( + "[DEEP:SelfCorrect] Strategy D succeeded: " + f"secondary file {Path(fp).name}" + ) + return tree_ev + + await self._logger.info("[DEEP:SelfCorrect] All strategies exhausted") + return None async def _react_refinement( self, From 9b3447873c7a57cfb13c3f01438d188c290e9466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 14 May 2026 19:21:07 +0800 Subject: [PATCH 61/70] refactor deep mode --- src/sirchmunk/llm/prompts.py | 72 +++ src/sirchmunk/search.py | 1030 +++++++++++++++++++++------------- 2 files changed, 710 insertions(+), 392 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 4b4c3bb..dcfbdc6 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -533,6 +533,78 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +# --------------------------------------------------------------------------- +# DEEP mode: Question Decomposition +# --------------------------------------------------------------------------- + +DEEP_QUESTION_DECOMPOSE = """Analyze the user query and decompose it into a structured plan for document-based answering. + +### User Query +{query} + +### Available Evidence Summary +{evidence_summary} + +### Output +Return JSON only, no extra text: +{{ + "query_type": "lookup|calculation|comparison|synthesis", + "sub_questions": ["sub-question 1", "sub-question 2"], + "required_data": ["data point 1", "data point 2"], + "time_periods": ["FY2021", "FY2022"], + "entities": ["Company A"], + "calculation_steps": ["step 1: find X", "step 2: compute Y = X / Z"] +}} + +Rules: +- **query_type**: "lookup" for direct fact retrieval; "calculation" for queries needing arithmetic (ratios, growth rates, differences); "comparison" for year-over-year or entity-vs-entity; "synthesis" for multi-fact integration. +- **sub_questions**: Break compound queries into atomic retrievable questions. 
Single-fact queries get one sub-question. +- **required_data**: Specific data points needed from the document (e.g. "FY2022 total revenue", "FY2021 net income"). +- **time_periods**: Fiscal years, quarters, or date ranges mentioned or implied. +- **entities**: Company names, subsidiaries, product lines, or segments referenced. +- **calculation_steps**: For "calculation" type only — ordered steps. For other types, empty array. +""" + + +# --------------------------------------------------------------------------- +# DEEP mode: Calculation-Aware Synthesis +# --------------------------------------------------------------------------- + +DEEP_CALCULATION_SYNTHESIS = """ +### Task +Answer the user's question by performing precise calculations on the provided evidence. + +### Constraints +1. **Language Continuity**: Reply in the SAME language as the User Input. +2. **Computation-first**: Extract ALL required numbers from the evidence BEFORE computing. List each number with its source (page, table, section). +3. **Show work**: Write out each calculation step explicitly. Use the format: `variable = value (source)`. +4. **Unit consistency**: Verify all numbers use compatible units before computing. Convert if needed — state the conversion. +5. **Rounding**: Match the precision implied by the query. For percentages, use at most one decimal place. For dollar amounts, round to the nearest whole number in the stated unit. +6. **Cross-check**: After computing, verify the result by a different method or sanity check (e.g. "Revenue growth of 50% seems high — let me re-verify the base figures"). +7. **Best-effort**: Compute from whatever relevant data is available. Only refuse when evidence contains NO related numbers at all. 
+ +### Calculation Plan +{calculation_steps} + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +[List all extracted values with sources, then show each calculation step] + + +[Concise Markdown summary of the analysis and result] + + +[Final numeric answer only, matching the query's expected format] + +true/false +true/false +""" + + # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 1acc220..b717aa4 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -27,6 +27,8 @@ DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, DEEP_SECTION_SELECT, + DEEP_QUESTION_DECOMPOSE, + DEEP_CALCULATION_SYNTHESIS, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -182,6 +184,29 @@ class CompileHints: extra_keywords: List[str] +@dataclass +class DeepRetrieval: + """Structured output of DEEP Stage 0: file discovery and ranking.""" + + file_paths: List[str] + keywords: List[str] + keyword_idfs: Dict[str, float] + catalog_routed: List[str] + tree_probed: List[str] + + +@dataclass +class DeepDecomposition: + """Structured output of DEEP Stage 2: question analysis.""" + + query_type: str # "lookup", "calculation", "comparison", "synthesis" + sub_questions: List[str] = field(default_factory=list) + required_data: List[str] = field(default_factory=list) + time_periods: List[str] = field(default_factory=list) + entities: List[str] = field(default_factory=list) + calculation_steps: List[str] = field(default_factory=list) + + @dataclass class CompileArtifacts: """Compile artifact availability context for adaptive activation in FAST mode. 
@@ -1670,7 +1695,7 @@ async def search( return answer # ------------------------------------------------------------------ - # DEEP mode — parallel multi-path retrieval with ReAct fallback + # DEEP mode — staged evidence-first pipeline # ------------------------------------------------------------------ async def _search_deep( @@ -1688,7 +1713,15 @@ async def _search_deep( spec_stale_hours: float = 72.0, llm_fallback: bool = False, ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: - """Parallel multi-path retrieval pipeline (Phases 0a–5). + """Evidence-first DEEP pipeline: retrieve → saturate → decompose → synthesize. + + Stages: + 0. FAST-style retrieval (1 LLM: query analysis + catalog routing) + 1. Evidence saturation (tree nav + table digest + rga per file) + 2. Question decomposition (1 LLM: classify + plan) + 3. Evidence adequacy check (0 LLM: rule-based gap detection) + 4. Strategy-routed synthesis (1 LLM: answer generation) + 5. Persistence (quality-gated cluster save) Returns: ``(answer, cluster, context)`` tuple. 
@@ -1699,451 +1732,652 @@ async def _search_deep( ) _llm_usage_start = len(self.llm_usages) - # --- Adaptive compile artifact detection (shared with FAST) --- _scope = _PathScope(paths) artifacts = self._detect_compile_artifacts(paths) + self._tree_nav_cache = _TreeNavCache() - # ============================================================== - # Phase 0a: Direct document analysis (intent-gated short-circuit) - # ============================================================== + # --- Short-circuits (Phase 0a + Phase 0) --- direct = await self._try_direct_doc_analysis(query, paths) if direct is not None: return direct, self._make_answer_cluster(query, direct, "DQ", file_paths=paths), context - # ============================================================== - # Phase 0: Cluster reuse (instant short-circuit) - # When reuse_knowledge=True and a similar cluster is found, we - # return here — Phase 5 (Persistence) is not executed for that path. - # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: return self._enrich_reused_content(reused), reused, context - # P2: gradient reuse — extract hints from moderately similar clusters - soft_hit = await self._try_soft_reuse(query, paths) + await self._logger.info(f"[DEEP] Starting evidence-first pipeline for: '{query[:80]}'") - await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") + # ==================== Stage 0: Retrieval ==================== + retrieval = await self._deep_retrieve( + query, paths, artifacts, _scope, context, + top_k_files=top_k_files, enable_dir_scan=enable_dir_scan, + max_depth=max_depth, include=include, exclude=exclude, + ) - # ============================================================== - # Phase 1: Parallel probing — five paths fire concurrently - # ============================================================== - await self._logger.info("[Phase 1] Parallel probing: keywords + 
dir_scan + knowledge + spec_cache + tree_index") - context.increment_loop() + if not retrieval.file_paths: + if llm_fallback: + answer, _ = await self._summarise_cluster_fallback(query) + return answer, None, context + return _NO_RESULTS_MESSAGE, None, context + + # ==================== Stage 1: Evidence saturation ==================== + evidence = await self._deep_gather_evidence( + query, retrieval, artifacts, context, + ) + + if not evidence or len(evidence.strip()) < 50: + if llm_fallback: + answer, _ = await self._summarise_cluster_fallback(query) + return answer, None, context + return _NO_RESULTS_MESSAGE, None, context + + # ==================== Stage 2: Question decomposition ==================== + decomposition = await self._deep_decompose_question(query, evidence, context) + + # ==================== Stage 3: Adequacy check + gap-fill ==================== + adequate, gaps = self._deep_check_adequacy(query, evidence, decomposition) + if not adequate and gaps: + await self._logger.info( + f"[DEEP:S3] Gaps detected ({len(gaps)}): {gaps[:3]}" + ) + extra = await self._deep_fill_evidence_gaps( + query, gaps, retrieval, artifacts, context, + ) + if extra: + evidence = f"{evidence}\n\n---\n\n{extra}" + + # ==================== Stage 4: Synthesis ==================== + answer, should_save, should_answer = await self._deep_synthesize( + query, evidence, decomposition, artifacts, retrieval, context, + ) + + # Self-correction: if synthesis rejected, try expanded evidence + if not answer: + await self._logger.info("[DEEP:S4] First synthesis rejected, trying self-correction") + sc_evidence = await self._deep_self_correct( + query, retrieval.file_paths, retrieval.keyword_idfs, context, + ) + if sc_evidence: + evidence = sc_evidence + answer, should_save, should_answer = await self._deep_synthesize( + query, sc_evidence, decomposition, artifacts, retrieval, context, + ) + + # Final fallback + if not answer: + if llm_fallback: + answer, should_save = await 
self._summarise_cluster_fallback(query) + else: + return _NO_RESULTS_MESSAGE, None, context + + # ==================== Stage 5: Verification ==================== + if answer and decomposition.query_type == "calculation": + answer, _ = self._deep_verify_answer(query, answer, evidence) + + # --- Token accounting --- + new_usages = self.llm_usages[_llm_usage_start:] + for usage in new_usages: + if usage and isinstance(usage, dict): + total_tok = usage.get("total_tokens", 0) + if total_tok == 0: + total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) + context.add_llm_tokens(total_tok, usage=usage) + + # --- Persistence --- + cluster: Optional[KnowledgeCluster] = None + if should_save and answer: + cluster = self._make_answer_cluster( + query, evidence[:5000], "DEEP", + file_paths=retrieval.file_paths[:5], + ) + cluster.content = evidence[:10000] + self._add_query_to_cluster(cluster, query) + try: + await self._save_cluster_with_embedding(cluster) + except Exception as exc: + _loguru_logger.warning(f"[DEEP:S5] Cluster save failed: {exc}") + + await self._logger.success(f"[DEEP] Complete: {context.summary()}") + return answer, cluster, context + + # ------------------------------------------------------------------ + # DEEP v2: Staged pipeline methods + # ------------------------------------------------------------------ + + async def _deep_retrieve( + self, + query: str, + paths: List[str], + artifacts: "CompileArtifacts", + scope: "_PathScope", + context: "SearchContext", + *, + top_k_files: int = 5, + enable_dir_scan: bool = False, + max_depth: Optional[int] = 5, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> "DeepRetrieval": + """Stage 0: FAST-style file discovery and ranking. + + Reuses the proven FAST retrieval pipeline: query analysis (1 LLM) + + keyword search (rga) + tree probe (scope-filtered) + catalog + routing + compile hints + summary index. 
Returns a structured + DeepRetrieval with ranked file paths. + """ + catalog = artifacts.catalog + catalog_routed_files: List[str] = [] + + tree_hints = "" + if artifacts and artifacts.tree_available_paths: + tree_hints = self._build_tree_root_hints(artifacts) + + if catalog: + listing = self._build_enriched_catalog_listing(catalog) + prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( + user_input=query, document_listing=listing, + ) + else: + prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + if tree_hints: + prompt = prompt + tree_hints + + llm_task = self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + compile_task = self._probe_compile_hints([query], scope=scope) + tree_task = self._probe_tree_index(query, scope=scope, artifacts=artifacts) + summary_task = self._probe_summary_index(query, artifacts, scope=scope) + catalog_deep_task = self._probe_catalog_for_deep(query, artifacts) - phase1_results = await asyncio.gather( - self._probe_keywords(query), - self._probe_dir_scan(paths, enable_dir_scan), - self._probe_knowledge_cache(query), - self._load_spec_context(paths, stale_hours=spec_stale_hours), - self._probe_tree_index(query), - self._probe_compile_hints([query], scope=_scope), # query-level hints; keyword-level runs post-Phase 1 - self._probe_summary_index(query, artifacts, scope=_scope), # GAP 2: zero-LLM BM25 - self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap + results = await asyncio.gather( + llm_task, compile_task, tree_task, summary_task, catalog_deep_task, return_exceptions=True, ) - kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) - scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None - knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") - spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" - 
tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] - compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) - summary_index_hits = phase1_results[6] if not isinstance(phase1_results[6], Exception) else [] - catalog_deep_hits = phase1_results[7] if not isinstance(phase1_results[7], Exception) else [] + resp = results[0] if not isinstance(results[0], Exception) else None + early_hints = results[1] if not isinstance(results[1], Exception) else CompileHints([], []) + tree_probed = results[2] if not isinstance(results[2], Exception) else [] + summary_hits = results[3] if not isinstance(results[3], Exception) else [] + catalog_deep_hits = results[4] if not isinstance(results[4], Exception) else [] - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints", "summary_index", "catalog_deep"]): - if isinstance(phase1_results[i], Exception): - await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") + for i, label in enumerate(["llm", "compile", "tree", "summary", "catalog"]): + if isinstance(results[i], Exception): + await self._logger.warning(f"[DEEP:S0] {label} failed: {results[i]}") - # Backwards compat: knowledge_probe may be a plain list from old code paths - if isinstance(knowledge_probe, list): - knowledge_probe = KnowledgeProbeResult(file_paths=knowledge_probe, extra_keywords=[], background_context="") + if resp and not isinstance(resp, Exception): + self.llm_usages.append(resp.usage) + if resp.usage and isinstance(resp.usage, dict): + context.add_llm_tokens( + resp.usage.get("total_tokens", 0), usage=resp.usage, + ) - query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) + analysis = self._parse_fast_json(resp.content if resp else "") + primary = analysis.get("primary", [])[:2] + fallback = analysis.get("fallback", [])[:3] + primary_alt = analysis.get("primary_alt", 
[])[:2] + fallback_alt = analysis.get("fallback_alt", [])[:3] + if primary_alt: + primary = primary + primary_alt + if fallback_alt: + fallback = fallback + fallback_alt + keyword_idfs: Dict[str, float] = analysis.get("idf", {}) + all_keywords = primary + fallback - # P2: inject soft-hit patterns into keywords - if soft_hit: - for p in soft_hit.patterns: - if p not in initial_keywords: - initial_keywords.append(p) - if p not in query_keywords: - query_keywords[p] = 0.6 - - # P3: inject extra keywords from structured knowledge probe - for kw in knowledge_probe.extra_keywords: - if kw not in initial_keywords: - initial_keywords.append(kw) - if kw not in query_keywords: - query_keywords[kw] = 0.5 - - # P2 + P3: append background context for Phase 4 LLM prompt - if soft_hit and soft_hit.context_summary: - spec_context = f"{spec_context}\n\n{soft_hit.context_summary}" if spec_context else soft_hit.context_summary - if knowledge_probe.background_context: - spec_context = f"{spec_context}\n\n{knowledge_probe.background_context}" if spec_context else knowledge_probe.background_context + if catalog: + for idx in analysis.get("selected_docs", []): + if isinstance(idx, int) and 0 <= idx < len(catalog): + fp = catalog[idx]["path"] + if Path(fp).exists(): + catalog_routed_files.append(fp) + kw_hints = await self._probe_compile_hints(all_keywords, scope=scope) + compile_hints = self._merge_compile_hints(early_hints, kw_hints) + for kw in compile_hints.extra_keywords: + if kw not in all_keywords: + all_keywords.append(kw) + keyword_idfs.setdefault(kw, 0.5) + + context.increment_loop() await self._logger.info( - f"[Phase 1] Results: keywords={len(initial_keywords)}, " - f"dir_scan={'OK' if scan_result else 'N/A'}, " - f"knowledge_files={len(knowledge_probe.file_paths)}, " - f"tree_hits={len(tree_hits)}, " - f"compile_hints={len(compile_hints.file_paths)}, " - f"summary_index={len(summary_index_hits)}, " - f"catalog_deep={len(catalog_deep_hits)}, " - f"soft_hit={'YES' if soft_hit 
else 'NO'}, " - f"spec_cache={'YES' if spec_context else 'NO'}" + f"[DEEP:S0] keywords={len(all_keywords)}, " + f"catalog_routed={len(catalog_routed_files)}, " + f"tree_probed={len(tree_probed)}, " + f"summary={len(summary_hits)}, " + f"catalog_deep={len(catalog_deep_hits)}" ) - # ============================================================== - # Phase 2: Parallel retrieval — keyword search + dir_scan rank - # ============================================================== - keyword_files: List[str] = [] - dir_scan_files: List[str] = [] + rga_kwargs = dict( + paths=paths, max_depth=max_depth, + include=list(include or []), exclude=exclude, + ) + tree_probed_set = frozenset(tree_probed) - if _PURE_TREE_SEARCH: - # Pure tree search mode: skip rga and dir_scan, rely solely on tree hits - await self._logger.info("[Phase 2:PureTree] Skipping rga keyword search and dir_scan") - context.increment_loop() - else: - await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") - context.increment_loop() + best_files: Optional[List[Dict[str, Any]]] = None - phase2_tasks = [] + if catalog_routed_files and analysis.get("doc_confidence") == "high": + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] - if initial_keywords: - phase2_tasks.append( - self._retrieve_by_keywords( - initial_keywords, paths, - max_depth=max_depth, include=include, exclude=exclude, - ) - ) - else: - phase2_tasks.append(self._async_noop([])) + if not best_files and tree_probed_set and primary: + best_files = await self._fast_find_best_file( + primary, paths=list(tree_probed_set), + top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + ) - if scan_result is not None and enable_dir_scan: - phase2_tasks.append( - self._rank_dir_scan_candidates(query, scan_result) - ) - else: - phase2_tasks.append(self._async_noop([])) + if not best_files and primary: + 
best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + tree_probed_paths=tree_probed_set or None, + **rga_kwargs, + ) + + if not best_files and fallback: + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + tree_probed_paths=tree_probed_set or None, + **rga_kwargs, + ) - phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) + hint_files: List[str] = [] + seen: set = set() + for fp in catalog_routed_files + list(tree_probed) + summary_hits + catalog_deep_hits + compile_hints.file_paths: + if fp and fp not in seen: + seen.add(fp) + hint_files.append(fp) + + if not best_files and hint_files: + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in hint_files[:top_k_files] + ] - keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] - dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] + if not best_files and enable_dir_scan: + ranked = await self._scan_and_rank_paths( + query, paths, top_k=top_k_files, include_medium=True, + ) + if ranked: + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in ranked[:top_k_files] + ] - for i, label in enumerate(["keyword_search", "dir_scan_rank"]): - if isinstance(phase2_results[i], Exception): - await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") + file_paths: List[str] = [] + if best_files: + fp_set: set = set() + for bf in best_files: + fp = bf["path"] + if fp not in fp_set: + fp_set.add(fp) + file_paths.append(fp) + for fp in hint_files: + if fp not in fp_set and len(file_paths) < top_k_files: + fp_set.add(fp) + file_paths.append(fp) await self._logger.info( - f"[Phase 2] Results: keyword_files={len(keyword_files)}, " - 
f"dir_scan_files={len(dir_scan_files)}" + f"[DEEP:S0] Retrieved {len(file_paths)} files: " + f"{[Path(p).name for p in file_paths[:5]]}" + ) + return DeepRetrieval( + file_paths=file_paths, + keywords=all_keywords, + keyword_idfs=keyword_idfs, + catalog_routed=catalog_routed_files, + tree_probed=list(tree_probed), ) - # --- Phase 2.5: Parallel tree pre-navigation for top tree hits --- - _pre_nav_evidence: Dict[str, str] = {} - if tree_hits: - _nav_fps = [fp for fp in tree_hits[:self._DEEP_PRE_NAV_MAX_FILES]] - if _nav_fps: - _nav_results = await asyncio.gather( - *[self._tree_guided_sample( - fp, query, max_chars=self._FAST_MAX_EVIDENCE_CHARS, - ) for fp in _nav_fps], - return_exceptions=True, - ) - for fp, nav_res in zip(_nav_fps, _nav_results): - if isinstance(nav_res, Exception): - await self._logger.warning( - f"[Phase 2.5] Tree pre-nav failed for {Path(fp).name}: {nav_res}" + async def _deep_gather_evidence( + self, + query: str, + retrieval: "DeepRetrieval", + artifacts: "CompileArtifacts", + context: "SearchContext", + *, + max_chars: int = 80_000, + ) -> str: + """Stage 1: Evidence saturation for all retrieved files. + + Gathers evidence from multiple sources per file, in priority order: + 1. Tree navigation (LLM-guided section targeting) + 2. Table digest (pre-compiled structured tables) + 3. Tree-guided sampling (section-level content) + 4. rga keyword sampling (grep-based snippets) + + Runs tree navigation in parallel across files for efficiency. 
+ """ + if not retrieval.file_paths: + return "" + + file_paths = retrieval.file_paths + tree_paths = artifacts.tree_available_paths if artifacts else set() + + async def _gather_for_file(fp: str) -> str: + parts: List[str] = [] + fname = Path(fp).name + + nav_ev = "" + if fp in tree_paths: + try: + nav_ev = await self._navigate_tree_for_evidence( + fp, query, + max_results=self._TREE_NAV_MAX_RESULTS, + ) or "" + except Exception: + pass + if nav_ev: + parts.append(nav_ev) + + table_ev = "" + try: + from sirchmunk.utils.file_utils import get_fast_hash + fh = get_fast_hash(fp) + if fh: + tables = self._load_table_digest(self.work_path, fh) + if tables: + budget = ( + self._TABLE_EVIDENCE_NAV_OVERLAP_CHARS if nav_ev + else self._TABLE_EVIDENCE_DEFAULT_CHARS ) - elif isinstance(nav_res, str) and nav_res: - _pre_nav_evidence[fp] = nav_res - if _pre_nav_evidence: - await self._logger.info( - f"[Phase 2.5] Pre-navigated {len(_pre_nav_evidence)} tree files" + table_ev = self._format_table_evidence( + tables, max_chars=budget, query=query, + ) or "" + except Exception: + pass + if table_ev: + parts.append(f"[{fname} - Table Evidence]\n{table_ev}") + + if not nav_ev and fp in tree_paths: + try: + tree_sample = await self._tree_guided_sample( + fp, query, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + artifacts=artifacts, ) + if tree_sample: + parts.append(tree_sample) + except Exception: + pass + + if not parts: + try: + rga_ev = await self._fast_sample_evidence(fp, []) + if rga_ev: + parts.append(rga_ev) + except Exception: + pass + + context.mark_file_read(fp) + if parts: + return f"[Source: {fname}]\n" + "\n\n".join(parts) + return "" + + tasks = [_gather_for_file(fp) for fp in file_paths] + results = await asyncio.gather(*tasks, return_exceptions=True) + + evidence_parts: List[str] = [] + total_chars = 0 + for r in results: + if isinstance(r, Exception): + continue + if r and total_chars < max_chars: + remaining = max_chars - total_chars + 
evidence_parts.append(r[:remaining]) + total_chars += len(evidence_parts[-1]) - # ============================================================== - # Phase 3: Merge file paths + build KnowledgeCluster - # P1 tree hits get highest priority; P2 soft-hit files next - # ============================================================== context.increment_loop() - extra_knowledge_files = knowledge_probe.file_paths - if soft_hit: - extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files + combined = "\n\n---\n\n".join(evidence_parts) + await self._logger.info( + f"[DEEP:S1] Evidence: {len(combined)} chars from " + f"{len(evidence_parts)} files" + ) + return combined - if _PURE_TREE_SEARCH: - # Pure tree search: only use tree hits (+ soft-hit fallback if no tree hits) - pure_tree_files = list(tree_hits) - if not pure_tree_files and soft_hit: - pure_tree_files = soft_hit.file_paths - await self._logger.info( - f"[Phase 3:PureTree] No tree hits, using {len(pure_tree_files)} soft-hit files" - ) - merged_files = self._merge_file_paths( - keyword_files=pure_tree_files, - dir_scan_files=[], - knowledge_hits=[], + async def _deep_decompose_question( + self, + query: str, + evidence: str, + context: "SearchContext", + ) -> "DeepDecomposition": + """Stage 2: Decompose the question into structured plan (1 LLM call). + + Classifies query type (lookup/calculation/comparison/synthesis), + extracts required data points, time periods, entities, and + calculation steps. 
+ """ + evidence_summary = evidence[:3000] if evidence else "(no evidence yet)" + prompt = DEEP_QUESTION_DECOMPOSE.format( + query=query, + evidence_summary=evidence_summary, + ) + try: + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, ) - await self._logger.info( - f"[Phase 3:PureTree] Merged {len(merged_files)} tree-only candidate files" + self.llm_usages.append(resp.usage) + if resp.usage and isinstance(resp.usage, dict): + context.add_llm_tokens( + resp.usage.get("total_tokens", 0), usage=resp.usage, + ) + context.increment_loop() + + raw = (resp.content or "").strip() + parsed = self._parse_fast_json(raw) + + return DeepDecomposition( + query_type=parsed.get("query_type", "lookup"), + sub_questions=parsed.get("sub_questions", [query]), + required_data=parsed.get("required_data", []), + time_periods=parsed.get("time_periods", []), + entities=parsed.get("entities", []), + calculation_steps=parsed.get("calculation_steps", []), ) - else: - merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, - dir_scan_files=dir_scan_files, - knowledge_hits=extra_knowledge_files, + except Exception as exc: + await self._logger.warning(f"[DEEP:S2] Decomposition failed: {exc}") + return DeepDecomposition(query_type="lookup", sub_questions=[query]) + + @staticmethod + def _deep_check_adequacy( + query: str, + evidence: str, + decomposition: "DeepDecomposition", + ) -> Tuple[bool, List[str]]: + """Stage 3: Rule-based evidence adequacy check (0 LLM calls). + + Checks whether the evidence contains the required data points + from the decomposition. Returns (adequate, gap_descriptions). 
+ """ + if not evidence or len(evidence.strip()) < 200: + return False, ["evidence too short"] + + evidence_lower = evidence.lower() + gaps: List[str] = [] + + for data_point in decomposition.required_data: + tokens = [t.lower() for t in re.findall(r"[A-Za-z0-9]+", data_point) if len(t) >= 3] + if tokens and not any(t in evidence_lower for t in tokens): + gaps.append(data_point) + + for period in decomposition.time_periods: + year_match = re.search(r"(\d{4})", period) + if year_match and year_match.group(1) not in evidence: + gaps.append(f"time period: {period}") + + if decomposition.query_type == "calculation": + numbers = re.findall(r'[\$€£]?\d[\d,]*\.?\d*', evidence) + if len(numbers) < 2: + gaps.append("insufficient numeric data for calculation") + + adequate = len(gaps) <= len(decomposition.required_data) * 0.3 + return adequate, gaps + + async def _deep_fill_evidence_gaps( + self, + query: str, + gaps: List[str], + retrieval: "DeepRetrieval", + artifacts: "CompileArtifacts", + context: "SearchContext", + ) -> str: + """Fill identified evidence gaps with targeted retrieval. + + Uses tree section selection for files with tree indices, or + keyword-based rga search for specific gap terms. 
+ """ + extra_parts: List[str] = [] + tree_paths = artifacts.tree_available_paths if artifacts else set() + indexer = self._get_tree_indexer() + + for fp in retrieval.file_paths[:3]: + if fp not in tree_paths or indexer is None: + continue + + tree = indexer.load_tree(fp) + if tree is None or tree.root is None: + continue + + section_map, sections_meta = self._build_section_map( + tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 1, ) - await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") + if not sections_meta: + continue - cluster: Optional[KnowledgeCluster] = None - if merged_files: - cluster = await self._build_cluster( - query=query, file_paths=merged_files, - query_keywords=query_keywords, top_k_files=top_k_files, + gap_query = f"{query} — specifically looking for: {'; '.join(gaps[:5])}" + selected = await self._select_evidence_sections( + gap_query, section_map, sections_meta, ) + context.increment_loop() - # ============================================================== - # Phase 3.5: Graph context enrichment (P5) - # Append related knowledge from graph neighbours to cluster content - # so the answer-generation LLM has richer context. 
- # ============================================================== - graph_ctx = "" - if cluster: - # Merge pre-navigated tree evidence into cluster content - if _pre_nav_evidence and cluster.content: - pre_nav_parts = [] - for fp, ev in _pre_nav_evidence.items(): - pre_nav_parts.append(f"[Tree evidence: {Path(fp).name}]\n{ev}") - if pre_nav_parts: - pre_nav_ctx = "\n\n".join(pre_nav_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" - - graph_ctx = await self._gather_graph_context(cluster) - if graph_ctx and cluster.content: - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n{graph_ctx}" + if selected: + ev = await self._extract_targeted_pages(fp, selected, query) + if ev and len(ev.strip()) > 100: + extra_parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") + context.mark_file_read(fp) - # ============================================================== - # Phase 4: Structured Reasoning → Cluster Summary fallback - # P0: DEEP mode always goes through full reasoning pipeline — - # no fast triage short-circuit. P4: query complexity determines - # whether the heavier section-map SR fires or we go straight to - # cluster synthesis. 
- # ============================================================== - context.increment_loop() - answer = "" - should_save = True + if extra_parts: + await self._logger.info( + f"[DEEP:S3] Gap-fill: {len(extra_parts)} additional evidence sources" + ) + return "\n\n---\n\n".join(extra_parts) - _query_complexity = self._classify_query_complexity(query) - await self._logger.info( - f"[Phase 4] Query complexity: {_query_complexity}" - ) + async def _deep_synthesize( + self, + query: str, + evidence: str, + decomposition: "DeepDecomposition", + artifacts: "CompileArtifacts", + retrieval: "DeepRetrieval", + context: "SearchContext", + ) -> Tuple[str, bool, bool]: + """Stage 4: Strategy-routed answer synthesis. - # Attempt structured reasoning for moderate/complex queries - _sr_files: List[str] = [] - if _query_complexity != "simple": - if tree_hits: - _sr_files = list(tree_hits[: self._DEEP_STRUCTURED_MAX_FILES]) - elif artifacts and artifacts.tree_available_paths: - _sr_files = list(artifacts.tree_available_paths)[ - : self._DEEP_STRUCTURED_MAX_FILES - ] + Routes to specialized prompts based on query_type: + - calculation: DEEP_CALCULATION_SYNTHESIS with explicit computation steps + - lookup/comparison/synthesis: ROI_RESULT_SUMMARY with document context + """ + doc_context: Optional[str] = None + if artifacts and artifacts.catalog_map: + ctx_parts = [ + self._build_answer_context(fp, artifacts) + for fp in retrieval.file_paths[:3] + ] + ctx_parts = [c for c in ctx_parts if c] + if ctx_parts: + doc_context = "\n".join(ctx_parts) - if _sr_files: - await self._logger.info( - f"[Phase 4] Launching structured reasoning for " - f"{len(_sr_files)} tree-indexed files" + if decomposition.query_type == "calculation" and decomposition.calculation_steps: + steps_text = "\n".join( + f"{i+1}. 
{s}" for i, s in enumerate(decomposition.calculation_steps) ) - sr_answer, sr_cluster, sr_evidence = await self._deep_structured_reasoning( - query, _sr_files, artifacts, context, + synth_prompt = DEEP_CALCULATION_SYNTHESIS.format( + calculation_steps=steps_text, + user_input=query, + text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], + ) + elif doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], + document_context=doc_context, + ) + else: + synth_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], ) - if sr_answer: - answer, should_save, should_answer = self._parse_summary_response( - sr_answer - ) - accepted, accept_reason = self._evaluate_evidence_acceptance( - query, sr_evidence or sr_answer, should_answer, - ) - await self._logger.info( - f"[Phase 4] Structured reasoning: " - f"accepted={accepted} ({accept_reason})" - ) - if accepted: - cluster = sr_cluster or cluster - else: - answer = "" + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + if resp.usage and isinstance(resp.usage, dict): + context.add_llm_tokens( + resp.usage.get("total_tokens", 0), usage=resp.usage, + ) + context.increment_loop() - # Fallback: cluster summary with ROI prompt or ReAct - if not answer: - if artifacts and artifacts.catalog_map and cluster and cluster.content: - _catalog_ctx_parts = [] - for fp in (cluster.search_results or merged_files)[:3]: - ctx = self._build_answer_context(fp, artifacts) - if ctx: - _catalog_ctx_parts.append(ctx) - if _catalog_ctx_parts: - _catalog_context = "\n".join(_catalog_ctx_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = ( - f"{cluster.content}\n\n" - 
f"[Document Context]\n{_catalog_context}" - ) + answer, should_save, should_answer = self._parse_summary_response( + resp.content or "" + ) - if cluster and cluster.content: - await self._logger.info( - "[Phase 4:Fallback] Generating summary from cluster" - ) - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) - ) - cluster_evidence = ( - str(cluster.content) if cluster.content else "" - ) - accepted, accept_reason = ( - self._evaluate_evidence_acceptance( - query, cluster_evidence, should_answer, - ) - ) - if not accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - # DEEP self-correction before giving up - sc_evidence = await self._deep_self_correct( - query, merged_files, query_keywords, context, - ) - if sc_evidence: - sc_cluster = self._make_answer_cluster( - query, sc_evidence[:5000], "DSC", - file_paths=list(merged_files)[:3], - ) - sc_cluster.content = sc_evidence - answer, should_save, should_answer = ( - await self._summarise_cluster(query, sc_cluster) - ) - sc_accepted, _ = self._evaluate_evidence_acceptance( - query, sc_evidence, should_answer, - ) - if sc_accepted: - cluster = sc_cluster - else: - return _NO_RESULTS_MESSAGE, None, context - else: - return _NO_RESULTS_MESSAGE, None, context - if not cluster.search_results: - cluster.search_results = list(merged_files) - elif llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - await self._logger.info( - "[Phase 4:Fallback] Launching ReAct refinement" - ) - # Seed ReAct with all available prior context so it - # doesn't start from scratch. 
- react_parts: List[str] = [] - if spec_context: - react_parts.append(spec_context) - if graph_ctx: - react_parts.append(graph_ctx) - if _pre_nav_evidence: - nav_seed = "\n\n".join( - f"[Pre-navigated: {Path(fp).name}]\n{ev}" - for fp, ev in _pre_nav_evidence.items() - ) - react_parts.append(nav_seed) - react_spec = "\n\n".join(react_parts) - react_answer, context = await self._react_refinement( - query=query, paths=paths, - initial_keywords=initial_keywords, - spec_context=react_spec, - enable_dir_scan=enable_dir_scan, - max_loops=max_loops, - max_token_budget=max_token_budget, - max_depth=max_depth, - include=include, exclude=exclude, - ) - if not cluster: - cluster = await self._build_cluster_from_context( - query=query, answer=react_answer, - context=context, - query_keywords=query_keywords, - top_k_files=top_k_files, - ) - elif react_answer and not cluster.content: - cluster.content = react_answer - if not cluster: - return _NO_RESULTS_MESSAGE, None, context - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) - ) - final_evidence = ( - str(cluster.content) if cluster.content else "" - ) - final_accepted, _ = self._evaluate_evidence_acceptance( - query, final_evidence, should_answer, - ) - if not final_accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - sc_evidence = await self._deep_self_correct( - query, merged_files, query_keywords, context, - ) - if sc_evidence: - sc_cluster = self._make_answer_cluster( - query, sc_evidence[:5000], "DSC", - file_paths=list(merged_files)[:3], - ) - sc_cluster.content = sc_evidence - answer, should_save, _ = ( - await self._summarise_cluster(query, sc_cluster) - ) - cluster = sc_cluster - else: - return _NO_RESULTS_MESSAGE, None, context + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, evidence, should_answer, + ) + await self._logger.info( + f"[DEEP:S4] Synthesis: accepted={accepted} 
({accept_reason}), " + f"type={decomposition.query_type}" + ) - # Sync LLM token accounting into context - new_usages = self.llm_usages[_llm_usage_start:] - for usage in new_usages: - if usage and isinstance(usage, dict): - total_tok = usage.get("total_tokens", 0) - if total_tok == 0: - total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) - context.add_llm_tokens(total_tok, usage=usage) + if not accepted: + return "", False, False + return answer, should_save, should_answer - # ============================================================== - # Phase 5: Persistence (quality-gated) - # Skipped when Phase 4 quality check says the answer is low-quality - # or when Phase 0 reused a cluster (early-returned above). - # ============================================================== - phase5_tasks = [] - if cluster and should_save: - self._add_query_to_cluster(cluster, query) - phase5_tasks.append(self._save_cluster_with_embedding(cluster)) - elif not should_save: - await self._logger.info("[Phase 5] Quality gate: low-quality answer, skipping cluster save") - cluster = None - phase5_tasks.append(self._save_spec_context(paths, context, scan_result=scan_result)) - results = await asyncio.gather(*phase5_tasks, return_exceptions=True) - for r in results: - if isinstance(r, Exception): - _loguru_logger.warning(f"[Phase 5] Persistence task failed: {r}") + @staticmethod + def _deep_verify_answer( + query: str, + answer: str, + evidence: str, + ) -> Tuple[str, bool]: + """Stage 5: Verify calculation answers with Python eval. - await self._logger.success(f"[search] Complete: {context.summary()}") - return answer, cluster, context + Extracts numeric expressions from the answer, attempts to + evaluate them, and flags discrepancies. Returns + (potentially_corrected_answer, verified). 
+ """ + num_pattern = re.compile( + r'[\$€£]?\s*([\d,]+\.?\d*)\s*(?:million|billion|%|percent)?', + re.IGNORECASE, + ) + answer_numbers = num_pattern.findall(answer[:500]) + if len(answer_numbers) < 1: + return answer, True + + calc_patterns = [ + re.compile(r'(\d[\d,]*\.?\d*)\s*[/÷]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), + re.compile(r'(\d[\d,]*\.?\d*)\s*[-−]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), + re.compile(r'\((\d[\d,]*\.?\d*)\s*[-−]\s*(\d[\d,]*\.?\d*)\)\s*[/÷]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), + ] + + for pat in calc_patterns: + matches = pat.findall(evidence + " " + answer) + for m in matches: + try: + nums = [float(n.replace(",", "")) for n in m if n] + if len(nums) >= 2 and nums[-1] != 0: + _ = nums[0] / nums[-1] + except (ValueError, ZeroDivisionError): + pass + + return answer, True # ------------------------------------------------------------------ # Phase 0a: Direct document analysis (intent-gated) @@ -5797,15 +6031,27 @@ async def _llm_select_from_trees( if Path(pool[idx].file_path).exists() ] - async def _probe_tree_index(self, query: str) -> List[str]: + async def _probe_tree_index( + self, + query: str, + *, + scope: Optional["_PathScope"] = None, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: """LLM-driven file discovery via compiled tree root summaries (PageIndex). - Loads all cached document trees, presents their root summaries to the - LLM, and asks it to select the most relevant documents. Returns file - paths of the most relevant documents. + Loads cached document trees, filters them by *scope* and/or + *artifacts.tree_available_paths*, presents root summaries to the + LLM, and asks it to select the most relevant documents. 
""" try: trees = self._load_cached_trees() + if not trees: + return [] + if artifacts and artifacts.tree_available_paths: + trees = [t for t in trees if t.file_path in artifacts.tree_available_paths] + if scope and not scope.is_empty: + trees = [t for t in trees if scope.contains(t.file_path)] if not trees: return [] result = await self._llm_select_from_trees( From f81b24fd9e89f5f0cbf23da33ec52025c940815b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 14 May 2026 23:20:19 +0800 Subject: [PATCH 62/70] enhance search deep --- src/sirchmunk/search.py | 377 +++++++++++++++++++++++++++++++++------- 1 file changed, 313 insertions(+), 64 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index b717aa4..4976b43 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1776,12 +1776,17 @@ async def _search_deep( # ==================== Stage 3: Adequacy check + gap-fill ==================== adequate, gaps = self._deep_check_adequacy(query, evidence, decomposition) - if not adequate and gaps: + needs_gap_fill = (not adequate and gaps) or ( + decomposition.query_type in ("calculation", "comparison") + ) + if needs_gap_fill: + fill_gaps = gaps if gaps else ["numeric data for calculation"] await self._logger.info( - f"[DEEP:S3] Gaps detected ({len(gaps)}): {gaps[:3]}" + f"[DEEP:S3] Gaps detected ({len(fill_gaps)}): {fill_gaps[:3]}" ) extra = await self._deep_fill_evidence_gaps( - query, gaps, retrieval, artifacts, context, + query, fill_gaps, retrieval, artifacts, context, + decomposition=decomposition, ) if extra: evidence = f"{evidence}\n\n---\n\n{extra}" @@ -1810,8 +1815,11 @@ async def _search_deep( else: return _NO_RESULTS_MESSAGE, None, context - # ==================== Stage 5: Verification ==================== - if answer and decomposition.query_type == "calculation": + # ==================== Stage 5: Self-consistency + Verification ==================== + if answer and decomposition.query_type in 
("calculation", "comparison"): + answer = await self._deep_self_consistency( + query, answer, evidence, decomposition, artifacts, retrieval, context, + ) answer, _ = self._deep_verify_answer(query, answer, evidence) # --- Token accounting --- @@ -2039,24 +2047,24 @@ async def _deep_gather_evidence( retrieval: "DeepRetrieval", artifacts: "CompileArtifacts", context: "SearchContext", - *, - max_chars: int = 80_000, ) -> str: """Stage 1: Evidence saturation for all retrieved files. - Gathers evidence from multiple sources per file, in priority order: + Gathers evidence from multiple sources per file: 1. Tree navigation (LLM-guided section targeting) - 2. Table digest (pre-compiled structured tables) - 3. Tree-guided sampling (section-level content) - 4. rga keyword sampling (grep-based snippets) + 2. Table digest (pre-compiled structured tables, expanded budget) + 3. Financial statement harvesting (rule-based section title match) + 4. Tree-guided sampling / rga fallback - Runs tree navigation in parallel across files for efficiency. + Uses DEEP-specific budgets (_DEEP_EVIDENCE_TOTAL_CHARS) that are + larger than FAST's to maximize evidence quality. 
""" if not retrieval.file_paths: return "" file_paths = retrieval.file_paths tree_paths = artifacts.tree_available_paths if artifacts else set() + max_chars = self._DEEP_EVIDENCE_TOTAL_CHARS async def _gather_for_file(fp: str) -> str: parts: List[str] = [] @@ -2082,8 +2090,8 @@ async def _gather_for_file(fp: str) -> str: tables = self._load_table_digest(self.work_path, fh) if tables: budget = ( - self._TABLE_EVIDENCE_NAV_OVERLAP_CHARS if nav_ev - else self._TABLE_EVIDENCE_DEFAULT_CHARS + self._DEEP_TABLE_BUDGET_WITH_NAV if nav_ev + else self._DEEP_TABLE_BUDGET ) table_ev = self._format_table_evidence( tables, max_chars=budget, query=query, @@ -2093,6 +2101,10 @@ async def _gather_for_file(fp: str) -> str: if table_ev: parts.append(f"[{fname} - Table Evidence]\n{table_ev}") + stmt_ev = await self._harvest_financial_statements(fp, query, artifacts) + if stmt_ev: + parts.append(f"[{fname} - Financial Statements]\n{stmt_ev}") + if not nav_ev and fp in tree_paths: try: tree_sample = await self._tree_guided_sample( @@ -2139,6 +2151,52 @@ async def _gather_for_file(fp: str) -> str: ) return combined + async def _harvest_financial_statements( + self, + file_path: str, + query: str, + artifacts: "CompileArtifacts", + ) -> str: + """Proactively extract financial statement sections via tree index. + + Scans tree section titles for income/balance/cashflow patterns and + extracts those pages. Complements tree navigation which may focus + on narrative sections instead of data-dense statements. 
+ """ + tree_paths = artifacts.tree_available_paths if artifacts else set() + if file_path not in tree_paths: + return "" + + indexer = self._get_tree_indexer() + if indexer is None: + return "" + + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return "" + + statement_sections: List[Dict[str, Any]] = [] + section_map, sections_meta = self._build_section_map( + tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH, + ) + + for sec in sections_meta: + title = (sec.get("title") or "").lower() + if any(pat.search(title) for pat in self._DEEP_STATEMENT_PATTERNS): + if sec.get("page_range"): + statement_sections.append(sec) + + if not statement_sections: + return "" + + try: + ev = await self._extract_targeted_pages( + file_path, statement_sections[:6], query, + ) + return ev or "" + except Exception: + return "" + async def _deep_decompose_question( self, query: str, @@ -2225,41 +2283,48 @@ async def _deep_fill_evidence_gaps( retrieval: "DeepRetrieval", artifacts: "CompileArtifacts", context: "SearchContext", + decomposition: Optional["DeepDecomposition"] = None, ) -> str: """Fill identified evidence gaps with targeted retrieval. - Uses tree section selection for files with tree indices, or - keyword-based rga search for specific gap terms. + Strategy per file: + 1. Tree section selection with expanded depth for gap-specific terms + 2. Table digest supplement for numeric gaps + 3. 
Keyword rga fallback for non-tree files """ extra_parts: List[str] = [] tree_paths = artifacts.tree_available_paths if artifacts else set() indexer = self._get_tree_indexer() + is_calc = decomposition and decomposition.query_type in ("calculation", "comparison") for fp in retrieval.file_paths[:3]: - if fp not in tree_paths or indexer is None: - continue - - tree = indexer.load_tree(fp) - if tree is None or tree.root is None: - continue - - section_map, sections_meta = self._build_section_map( - tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 1, - ) - if not sections_meta: - continue - gap_query = f"{query} — specifically looking for: {'; '.join(gaps[:5])}" - selected = await self._select_evidence_sections( - gap_query, section_map, sections_meta, - ) - context.increment_loop() - if selected: - ev = await self._extract_targeted_pages(fp, selected, query) - if ev and len(ev.strip()) > 100: - extra_parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") - context.mark_file_read(fp) + if fp in tree_paths and indexer is not None: + tree = indexer.load_tree(fp) + if tree is not None and tree.root is not None: + section_map, sections_meta = self._build_section_map( + tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 2, + ) + if sections_meta: + selected = await self._select_evidence_sections( + gap_query, section_map, sections_meta, + ) + context.increment_loop() + if selected: + ev = await self._extract_targeted_pages(fp, selected, query) + if ev and len(ev.strip()) > 100: + extra_parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") + context.mark_file_read(fp) + + if is_calc and any("numeric" in g or "data" in g for g in gaps): + table_index = (artifacts.table_index or {}).get(fp, []) + if table_index: + table_ev = self._format_table_evidence( + table_index, max_chars=self._DEEP_TABLE_BUDGET_WITH_NAV, query=query, + ) + if table_ev and len(table_ev.strip()) > 100: + extra_parts.append(f"[Table-supplement: {Path(fp).name}]\n{table_ev}") if extra_parts: await 
self._logger.info( @@ -2281,7 +2346,14 @@ async def _deep_synthesize( Routes to specialized prompts based on query_type: - calculation: DEEP_CALCULATION_SYNTHESIS with explicit computation steps - lookup/comparison/synthesis: ROI_RESULT_SUMMARY with document context + + Pre-processing: prunes evidence by entities/time_periods, and + appends format constraints derived from query semantics. """ + evidence = self._deep_prune_evidence(evidence, decomposition) + + format_hint = self._deep_format_constraint(query) + doc_context: Optional[str] = None if artifacts and artifacts.catalog_map: ctx_parts = [ @@ -2299,21 +2371,24 @@ async def _deep_synthesize( synth_prompt = DEEP_CALCULATION_SYNTHESIS.format( calculation_steps=steps_text, user_input=query, - text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], + text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], ) elif doc_context: from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( user_input=query, - text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], + text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], document_context=doc_context, ) else: synth_prompt = ROI_RESULT_SUMMARY.format( user_input=query, - text_content=evidence[:self._DEEP_STRUCTURED_MAX_CHARS * 2], + text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], ) + if format_hint: + synth_prompt = f"{synth_prompt}\n\n### Format Constraint\n{format_hint}" + resp = await self.llm.achat( messages=[{"role": "user", "content": synth_prompt}], stream=True, @@ -2332,6 +2407,14 @@ async def _deep_synthesize( accepted, accept_reason = self._evaluate_evidence_acceptance( query, evidence, should_answer, ) + + if not accepted and decomposition.query_type in ("calculation", "comparison"): + accepted = self._deep_relaxed_acceptance(query, evidence) + if accepted: + accept_reason = "deep_calc_relaxed" + should_answer = True + should_save = True + await self._logger.info( f"[DEEP:S4] 
Synthesis: accepted={accepted} ({accept_reason}), " f"type={decomposition.query_type}" @@ -2341,6 +2424,121 @@ async def _deep_synthesize( return "", False, False return answer, should_save, should_answer + @staticmethod + def _deep_relaxed_acceptance(query: str, evidence: str) -> bool: + """Relaxed acceptance for calculation/comparison queries. + + Accepts when evidence contains >=2 distinct numbers AND at least + one query-relevant keyword appears. This prevents false-negative + rejections on numeric data that the standard heuristic misses + due to low keyword coverage from formula-heavy queries. + """ + numbers = re.findall(r'[\$€£]?\d[\d,]*\.?\d*', evidence[:20000]) + if len(numbers) < 2: + return False + kw_coverage = AgenticSearch._compute_keyword_coverage(query, evidence) + return kw_coverage >= 0.3 + + @staticmethod + def _deep_prune_evidence( + evidence: str, + decomposition: "DeepDecomposition", + ) -> str: + """Prune evidence paragraphs not matching required time_periods/entities. + + Splits evidence into paragraph-level blocks. Retains a block if it + mentions ANY required time period or entity. Blocks that match + neither are discarded (unless fewer than 30% of blocks would remain, + in which case no pruning is applied to avoid over-filtering). 
+ """ + if not decomposition.time_periods and not decomposition.entities: + return evidence + + periods = {p.lower() for p in decomposition.time_periods} + year_patterns = {re.search(r"\d{4}", p).group() for p in periods if re.search(r"\d{4}", p)} + entities = {e.lower() for e in decomposition.entities} + + blocks = re.split(r'\n{2,}', evidence) + if len(blocks) <= 3: + return evidence + + kept: List[str] = [] + for block in blocks: + block_lower = block.lower() + has_period = any(y in block for y in year_patterns) if year_patterns else True + has_entity = any(e in block_lower for e in entities) if entities else True + if has_period or has_entity: + kept.append(block) + + if len(kept) < len(blocks) * 0.3: + return evidence + return "\n\n".join(kept) + + @staticmethod + def _deep_format_constraint(query: str) -> str: + """Derive answer format guidance from query semantics. + + Returns a short instruction string appended to the synthesis prompt + to steer PRECISE_ANSWER toward the expected format. + """ + q_lower = query.lower() + if re.search(r'\b(is|does|did|was|were|has|have|can|will|should)\b', q_lower) and "?" in query: + return "Answer with Yes or No first, then provide justification." + if re.search(r'(?:million|billion|mn|bn)\b', q_lower): + unit = "billion" if re.search(r'\b(billion|bn)\b', q_lower) else "million" + return f"Express the final answer in {unit}s (e.g. $X {unit})." + if re.search(r'\bratio\b', q_lower): + return "Express the final answer as a decimal ratio (e.g. 1.5x or 0.75)." + if re.search(r'\b(percentage|percent|%)\b', q_lower): + return "Express the final answer as a percentage (e.g. 25.3%)." + if re.search(r'\bgrowth\b.*\brate\b|\brate\b.*\bgrowth\b', q_lower): + return "Express the final answer as a percentage change (e.g. +12.5% or -3.2%)." 
+ return "" + + async def _deep_self_consistency( + self, + query: str, + first_answer: str, + evidence: str, + decomposition: "DeepDecomposition", + artifacts: "CompileArtifacts", + retrieval: "DeepRetrieval", + context: "SearchContext", + ) -> str: + """Run a second synthesis and pick the consistent answer. + + Compares PRECISE_ANSWER from both runs. If they match (within + numeric tolerance), returns the first. If they diverge, picks + the answer whose PRECISE_ANSWER contains a valid number. + """ + second_answer, _, _ = await self._deep_synthesize( + query, evidence, decomposition, artifacts, retrieval, context, + ) + if not second_answer: + return first_answer + + precise_1 = re.search(r'\*\*Answer:\s*(.+?)\*\*', first_answer) + precise_2 = re.search(r'\*\*Answer:\s*(.+?)\*\*', second_answer) + if not precise_1 or not precise_2: + return first_answer + + val_1 = re.sub(r'[^\d.\-]', '', precise_1.group(1).replace(',', '')) + val_2 = re.sub(r'[^\d.\-]', '', precise_2.group(1).replace(',', '')) + + try: + n1, n2 = float(val_1), float(val_2) + except (ValueError, TypeError): + return first_answer + + tolerance = max(abs(n1) * 0.05, 0.01) + if abs(n1 - n2) <= tolerance: + return first_answer + + await self._logger.info( + f"[DEEP:S5] Self-consistency divergence: {val_1} vs {val_2}, using second" + ) + return second_answer + @staticmethod def _deep_verify_answer( query: str, @@ -2349,35 +2547,70 @@ def _deep_verify_answer( ) -> Tuple[str, bool]: """Stage 5: Verify calculation answers with Python eval. - Extracts numeric expressions from the answer, attempts to - evaluate them, and flags discrepancies. Returns - (potentially_corrected_answer, verified). + Extracts the COMPUTATION block from the answer, parses variable + assignments (``var = expr``), evaluates them in a safe namespace, + and compares the final result against PRECISE_ANSWER. When a + discrepancy is found, replaces the PRECISE_ANSWER with the + recomputed value. 
+ + Returns (potentially_corrected_answer, verified). """ - num_pattern = re.compile( - r'[\$€£]?\s*([\d,]+\.?\d*)\s*(?:million|billion|%|percent)?', - re.IGNORECASE, + computation_match = re.search( + r'(.*?)', answer, re.DOTALL, ) - answer_numbers = num_pattern.findall(answer[:500]) - if len(answer_numbers) < 1: + precise_match = re.search( + r'(.*?)', answer, re.DOTALL, + ) + if not computation_match: return answer, True - calc_patterns = [ - re.compile(r'(\d[\d,]*\.?\d*)\s*[/÷]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), - re.compile(r'(\d[\d,]*\.?\d*)\s*[-−]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), - re.compile(r'\((\d[\d,]*\.?\d*)\s*[-−]\s*(\d[\d,]*\.?\d*)\)\s*[/÷]\s*(\d[\d,]*\.?\d*)', re.IGNORECASE), - ] + computation_text = computation_match.group(1) + assignments = re.findall( + r'(?:^|\n)\s*[\w\s]+?=\s*(.+?)(?:\s*\(|$|\n)', + computation_text, + ) + if not assignments: + return answer, True - for pat in calc_patterns: - matches = pat.findall(evidence + " " + answer) - for m in matches: - try: - nums = [float(n.replace(",", "")) for n in m if n] - if len(nums) >= 2 and nums[-1] != 0: - _ = nums[0] / nums[-1] - except (ValueError, ZeroDivisionError): - pass + safe_ns: Dict[str, float] = {} + last_result: Optional[float] = None + + for expr_raw in assignments: + expr = expr_raw.strip().rstrip("(") + expr = re.sub(r'[\$€£,]', '', expr) + expr = expr.replace('−', '-').replace('÷', '/').replace('×', '*') + expr = re.sub(r'[a-zA-Z%]+$', '', expr).strip() + if not expr or not re.search(r'\d', expr): + continue + try: + result = float(eval(expr, {"__builtins__": {}}, safe_ns)) # noqa: S307 + last_result = result + except Exception: + continue - return answer, True + if last_result is None or not precise_match: + return answer, True + + precise_text = precise_match.group(1).strip() + precise_num = re.sub(r'[^\d.\-]', '', precise_text.replace(',', '')) + try: + stated = float(precise_num) if precise_num else None + except ValueError: + stated = None + + if stated is None: + 
return answer, True + + tolerance = max(abs(stated) * 0.02, 0.01) + if abs(stated - last_result) <= tolerance: + return answer, True + + fmt = f"{last_result:.1f}" if abs(last_result) < 100 else f"{last_result:,.0f}" + corrected = answer.replace( + precise_match.group(0), + f"{fmt}", + ) + return corrected, False # ------------------------------------------------------------------ # Phase 0a: Direct document analysis (intent-gated) @@ -2766,6 +2999,22 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _DEEP_STRUCTURED_MAX_FILES: int = 3 """Maximum files to process through structured reasoning pipeline.""" + # --- DEEP v2 evidence budgets --- + _DEEP_EVIDENCE_TOTAL_CHARS: int = 120_000 + """Total evidence budget for DEEP mode (uses more context than FAST).""" + _DEEP_TABLE_BUDGET: int = 40_000 + """Table digest budget per file in DEEP mode (no tree nav overlap).""" + _DEEP_TABLE_BUDGET_WITH_NAV: int = 20_000 + """Table digest budget per file when tree nav also provides evidence.""" + _DEEP_SYNTHESIS_MAX_CHARS: int = 60_000 + """Maximum evidence chars sent to the synthesis LLM call.""" + _DEEP_STATEMENT_PATTERNS: Tuple[re.Pattern, ...] 
= ( + re.compile(r"(?:income|operations|earnings|profit.loss)", re.IGNORECASE), + re.compile(r"(?:balance.sheet|financial.position|assets.liab)", re.IGNORECASE), + re.compile(r"(?:cash.flow|cash.provided|financing.activities)", re.IGNORECASE), + ) + """Section title patterns for financial statement harvesting.""" + # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" From 7a0adf754a8e28cf992ba98b1b8be349368eb0cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 15 May 2026 01:11:37 +0800 Subject: [PATCH 63/70] fallback --- src/sirchmunk/llm/prompts.py | 72 -- src/sirchmunk/search.py | 1273 +++++++++++----------------------- 2 files changed, 389 insertions(+), 956 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index dcfbdc6..4b4c3bb 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -533,78 +533,6 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ -# --------------------------------------------------------------------------- -# DEEP mode: Question Decomposition -# --------------------------------------------------------------------------- - -DEEP_QUESTION_DECOMPOSE = """Analyze the user query and decompose it into a structured plan for document-based answering. 
- -### User Query -{query} - -### Available Evidence Summary -{evidence_summary} - -### Output -Return JSON only, no extra text: -{{ - "query_type": "lookup|calculation|comparison|synthesis", - "sub_questions": ["sub-question 1", "sub-question 2"], - "required_data": ["data point 1", "data point 2"], - "time_periods": ["FY2021", "FY2022"], - "entities": ["Company A"], - "calculation_steps": ["step 1: find X", "step 2: compute Y = X / Z"] -}} - -Rules: -- **query_type**: "lookup" for direct fact retrieval; "calculation" for queries needing arithmetic (ratios, growth rates, differences); "comparison" for year-over-year or entity-vs-entity; "synthesis" for multi-fact integration. -- **sub_questions**: Break compound queries into atomic retrievable questions. Single-fact queries get one sub-question. -- **required_data**: Specific data points needed from the document (e.g. "FY2022 total revenue", "FY2021 net income"). -- **time_periods**: Fiscal years, quarters, or date ranges mentioned or implied. -- **entities**: Company names, subsidiaries, product lines, or segments referenced. -- **calculation_steps**: For "calculation" type only — ordered steps. For other types, empty array. -""" - - -# --------------------------------------------------------------------------- -# DEEP mode: Calculation-Aware Synthesis -# --------------------------------------------------------------------------- - -DEEP_CALCULATION_SYNTHESIS = """ -### Task -Answer the user's question by performing precise calculations on the provided evidence. - -### Constraints -1. **Language Continuity**: Reply in the SAME language as the User Input. -2. **Computation-first**: Extract ALL required numbers from the evidence BEFORE computing. List each number with its source (page, table, section). -3. **Show work**: Write out each calculation step explicitly. Use the format: `variable = value (source)`. -4. **Unit consistency**: Verify all numbers use compatible units before computing. 
Convert if needed — state the conversion. -5. **Rounding**: Match the precision implied by the query. For percentages, use at most one decimal place. For dollar amounts, round to the nearest whole number in the stated unit. -6. **Cross-check**: After computing, verify the result by a different method or sanity check (e.g. "Revenue growth of 50% seems high — let me re-verify the base figures"). -7. **Best-effort**: Compute from whatever relevant data is available. Only refuse when evidence contains NO related numbers at all. - -### Calculation Plan -{calculation_steps} - -### Input Data -- **User Input**: {user_input} -- **Evidence**: {text_content} - -### Output Format - -[List all extracted values with sources, then show each calculation step] - - -[Concise Markdown summary of the analysis and result] - - -[Final numeric answer only, matching the query's expected format] - -true/false -true/false -""" - - # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 4976b43..1acc220 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -27,8 +27,6 @@ DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, DEEP_SECTION_SELECT, - DEEP_QUESTION_DECOMPOSE, - DEEP_CALCULATION_SYNTHESIS, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -184,29 +182,6 @@ class CompileHints: extra_keywords: List[str] -@dataclass -class DeepRetrieval: - """Structured output of DEEP Stage 0: file discovery and ranking.""" - - file_paths: List[str] - keywords: List[str] - keyword_idfs: Dict[str, float] - catalog_routed: List[str] - tree_probed: List[str] - - -@dataclass -class DeepDecomposition: - """Structured output of DEEP Stage 2: question analysis.""" - - query_type: str # "lookup", "calculation", "comparison", "synthesis" - sub_questions: 
List[str] = field(default_factory=list) - required_data: List[str] = field(default_factory=list) - time_periods: List[str] = field(default_factory=list) - entities: List[str] = field(default_factory=list) - calculation_steps: List[str] = field(default_factory=list) - - @dataclass class CompileArtifacts: """Compile artifact availability context for adaptive activation in FAST mode. @@ -1695,7 +1670,7 @@ async def search( return answer # ------------------------------------------------------------------ - # DEEP mode — staged evidence-first pipeline + # DEEP mode — parallel multi-path retrieval with ReAct fallback # ------------------------------------------------------------------ async def _search_deep( @@ -1713,15 +1688,7 @@ async def _search_deep( spec_stale_hours: float = 72.0, llm_fallback: bool = False, ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: - """Evidence-first DEEP pipeline: retrieve → saturate → decompose → synthesize. - - Stages: - 0. FAST-style retrieval (1 LLM: query analysis + catalog routing) - 1. Evidence saturation (tree nav + table digest + rga per file) - 2. Question decomposition (1 LLM: classify + plan) - 3. Evidence adequacy check (0 LLM: rule-based gap detection) - 4. Strategy-routed synthesis (1 LLM: answer generation) - 5. Persistence (quality-gated cluster save) + """Parallel multi-path retrieval pipeline (Phases 0a–5). Returns: ``(answer, cluster, context)`` tuple. 
@@ -1732,885 +1699,451 @@ async def _search_deep( ) _llm_usage_start = len(self.llm_usages) + # --- Adaptive compile artifact detection (shared with FAST) --- _scope = _PathScope(paths) artifacts = self._detect_compile_artifacts(paths) - self._tree_nav_cache = _TreeNavCache() - # --- Short-circuits (Phase 0a + Phase 0) --- + # ============================================================== + # Phase 0a: Direct document analysis (intent-gated short-circuit) + # ============================================================== direct = await self._try_direct_doc_analysis(query, paths) if direct is not None: return direct, self._make_answer_cluster(query, direct, "DQ", file_paths=paths), context + # ============================================================== + # Phase 0: Cluster reuse (instant short-circuit) + # When reuse_knowledge=True and a similar cluster is found, we + # return here — Phase 5 (Persistence) is not executed for that path. + # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: return self._enrich_reused_content(reused), reused, context - await self._logger.info(f"[DEEP] Starting evidence-first pipeline for: '{query[:80]}'") - - # ==================== Stage 0: Retrieval ==================== - retrieval = await self._deep_retrieve( - query, paths, artifacts, _scope, context, - top_k_files=top_k_files, enable_dir_scan=enable_dir_scan, - max_depth=max_depth, include=include, exclude=exclude, - ) - - if not retrieval.file_paths: - if llm_fallback: - answer, _ = await self._summarise_cluster_fallback(query) - return answer, None, context - return _NO_RESULTS_MESSAGE, None, context - - # ==================== Stage 1: Evidence saturation ==================== - evidence = await self._deep_gather_evidence( - query, retrieval, artifacts, context, - ) - - if not evidence or len(evidence.strip()) < 50: - if llm_fallback: - answer, _ = await 
self._summarise_cluster_fallback(query) - return answer, None, context - return _NO_RESULTS_MESSAGE, None, context - - # ==================== Stage 2: Question decomposition ==================== - decomposition = await self._deep_decompose_question(query, evidence, context) - - # ==================== Stage 3: Adequacy check + gap-fill ==================== - adequate, gaps = self._deep_check_adequacy(query, evidence, decomposition) - needs_gap_fill = (not adequate and gaps) or ( - decomposition.query_type in ("calculation", "comparison") - ) - if needs_gap_fill: - fill_gaps = gaps if gaps else ["numeric data for calculation"] - await self._logger.info( - f"[DEEP:S3] Gaps detected ({len(fill_gaps)}): {fill_gaps[:3]}" - ) - extra = await self._deep_fill_evidence_gaps( - query, fill_gaps, retrieval, artifacts, context, - decomposition=decomposition, - ) - if extra: - evidence = f"{evidence}\n\n---\n\n{extra}" - - # ==================== Stage 4: Synthesis ==================== - answer, should_save, should_answer = await self._deep_synthesize( - query, evidence, decomposition, artifacts, retrieval, context, - ) - - # Self-correction: if synthesis rejected, try expanded evidence - if not answer: - await self._logger.info("[DEEP:S4] First synthesis rejected, trying self-correction") - sc_evidence = await self._deep_self_correct( - query, retrieval.file_paths, retrieval.keyword_idfs, context, - ) - if sc_evidence: - evidence = sc_evidence - answer, should_save, should_answer = await self._deep_synthesize( - query, sc_evidence, decomposition, artifacts, retrieval, context, - ) - - # Final fallback - if not answer: - if llm_fallback: - answer, should_save = await self._summarise_cluster_fallback(query) - else: - return _NO_RESULTS_MESSAGE, None, context - - # ==================== Stage 5: Self-consistency + Verification ==================== - if answer and decomposition.query_type in ("calculation", "comparison"): - answer = await self._deep_self_consistency( - query, answer, 
evidence, decomposition, artifacts, retrieval, context, - ) - answer, _ = self._deep_verify_answer(query, answer, evidence) - - # --- Token accounting --- - new_usages = self.llm_usages[_llm_usage_start:] - for usage in new_usages: - if usage and isinstance(usage, dict): - total_tok = usage.get("total_tokens", 0) - if total_tok == 0: - total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) - context.add_llm_tokens(total_tok, usage=usage) - - # --- Persistence --- - cluster: Optional[KnowledgeCluster] = None - if should_save and answer: - cluster = self._make_answer_cluster( - query, evidence[:5000], "DEEP", - file_paths=retrieval.file_paths[:5], - ) - cluster.content = evidence[:10000] - self._add_query_to_cluster(cluster, query) - try: - await self._save_cluster_with_embedding(cluster) - except Exception as exc: - _loguru_logger.warning(f"[DEEP:S5] Cluster save failed: {exc}") - - await self._logger.success(f"[DEEP] Complete: {context.summary()}") - return answer, cluster, context - - # ------------------------------------------------------------------ - # DEEP v2: Staged pipeline methods - # ------------------------------------------------------------------ - - async def _deep_retrieve( - self, - query: str, - paths: List[str], - artifacts: "CompileArtifacts", - scope: "_PathScope", - context: "SearchContext", - *, - top_k_files: int = 5, - enable_dir_scan: bool = False, - max_depth: Optional[int] = 5, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ) -> "DeepRetrieval": - """Stage 0: FAST-style file discovery and ranking. - - Reuses the proven FAST retrieval pipeline: query analysis (1 LLM) - + keyword search (rga) + tree probe (scope-filtered) + catalog - routing + compile hints + summary index. Returns a structured - DeepRetrieval with ranked file paths. 
- """ - catalog = artifacts.catalog - catalog_routed_files: List[str] = [] + # P2: gradient reuse — extract hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) - tree_hints = "" - if artifacts and artifacts.tree_available_paths: - tree_hints = self._build_tree_root_hints(artifacts) + await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") - if catalog: - listing = self._build_enriched_catalog_listing(catalog) - prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( - user_input=query, document_listing=listing, - ) - else: - prompt = FAST_QUERY_ANALYSIS.format(user_input=query) - if tree_hints: - prompt = prompt + tree_hints - - llm_task = self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=False, - ) - compile_task = self._probe_compile_hints([query], scope=scope) - tree_task = self._probe_tree_index(query, scope=scope, artifacts=artifacts) - summary_task = self._probe_summary_index(query, artifacts, scope=scope) - catalog_deep_task = self._probe_catalog_for_deep(query, artifacts) + # ============================================================== + # Phase 1: Parallel probing — five paths fire concurrently + # ============================================================== + await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache + tree_index") + context.increment_loop() - results = await asyncio.gather( - llm_task, compile_task, tree_task, summary_task, catalog_deep_task, + phase1_results = await asyncio.gather( + self._probe_keywords(query), + self._probe_dir_scan(paths, enable_dir_scan), + self._probe_knowledge_cache(query), + self._load_spec_context(paths, stale_hours=spec_stale_hours), + self._probe_tree_index(query), + self._probe_compile_hints([query], scope=_scope), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts, scope=_scope), # GAP 2: zero-LLM BM25 + 
self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap return_exceptions=True, ) - resp = results[0] if not isinstance(results[0], Exception) else None - early_hints = results[1] if not isinstance(results[1], Exception) else CompileHints([], []) - tree_probed = results[2] if not isinstance(results[2], Exception) else [] - summary_hits = results[3] if not isinstance(results[3], Exception) else [] - catalog_deep_hits = results[4] if not isinstance(results[4], Exception) else [] + kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) + scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None + knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") + spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" + tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] + compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) + summary_index_hits = phase1_results[6] if not isinstance(phase1_results[6], Exception) else [] + catalog_deep_hits = phase1_results[7] if not isinstance(phase1_results[7], Exception) else [] - for i, label in enumerate(["llm", "compile", "tree", "summary", "catalog"]): - if isinstance(results[i], Exception): - await self._logger.warning(f"[DEEP:S0] {label} failed: {results[i]}") + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints", "summary_index", "catalog_deep"]): + if isinstance(phase1_results[i], Exception): + await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") - if resp and not isinstance(resp, Exception): - self.llm_usages.append(resp.usage) - if resp.usage and isinstance(resp.usage, dict): - context.add_llm_tokens( - resp.usage.get("total_tokens", 0), usage=resp.usage, - ) + # 
Backwards compat: knowledge_probe may be a plain list from old code paths + if isinstance(knowledge_probe, list): + knowledge_probe = KnowledgeProbeResult(file_paths=knowledge_probe, extra_keywords=[], background_context="") - analysis = self._parse_fast_json(resp.content if resp else "") - primary = analysis.get("primary", [])[:2] - fallback = analysis.get("fallback", [])[:3] - primary_alt = analysis.get("primary_alt", [])[:2] - fallback_alt = analysis.get("fallback_alt", [])[:3] - if primary_alt: - primary = primary + primary_alt - if fallback_alt: - fallback = fallback + fallback_alt - keyword_idfs: Dict[str, float] = analysis.get("idf", {}) - all_keywords = primary + fallback - - if catalog: - for idx in analysis.get("selected_docs", []): - if isinstance(idx, int) and 0 <= idx < len(catalog): - fp = catalog[idx]["path"] - if Path(fp).exists(): - catalog_routed_files.append(fp) + query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) - kw_hints = await self._probe_compile_hints(all_keywords, scope=scope) - compile_hints = self._merge_compile_hints(early_hints, kw_hints) - for kw in compile_hints.extra_keywords: - if kw not in all_keywords: - all_keywords.append(kw) - keyword_idfs.setdefault(kw, 0.5) + # P2: inject soft-hit patterns into keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in initial_keywords: + initial_keywords.append(p) + if p not in query_keywords: + query_keywords[p] = 0.6 + + # P3: inject extra keywords from structured knowledge probe + for kw in knowledge_probe.extra_keywords: + if kw not in initial_keywords: + initial_keywords.append(kw) + if kw not in query_keywords: + query_keywords[kw] = 0.5 + + # P2 + P3: append background context for Phase 4 LLM prompt + if soft_hit and soft_hit.context_summary: + spec_context = f"{spec_context}\n\n{soft_hit.context_summary}" if spec_context else soft_hit.context_summary + if knowledge_probe.background_context: + spec_context = 
f"{spec_context}\n\n{knowledge_probe.background_context}" if spec_context else knowledge_probe.background_context - context.increment_loop() await self._logger.info( - f"[DEEP:S0] keywords={len(all_keywords)}, " - f"catalog_routed={len(catalog_routed_files)}, " - f"tree_probed={len(tree_probed)}, " - f"summary={len(summary_hits)}, " - f"catalog_deep={len(catalog_deep_hits)}" + f"[Phase 1] Results: keywords={len(initial_keywords)}, " + f"dir_scan={'OK' if scan_result else 'N/A'}, " + f"knowledge_files={len(knowledge_probe.file_paths)}, " + f"tree_hits={len(tree_hits)}, " + f"compile_hints={len(compile_hints.file_paths)}, " + f"summary_index={len(summary_index_hits)}, " + f"catalog_deep={len(catalog_deep_hits)}, " + f"soft_hit={'YES' if soft_hit else 'NO'}, " + f"spec_cache={'YES' if spec_context else 'NO'}" ) - rga_kwargs = dict( - paths=paths, max_depth=max_depth, - include=list(include or []), exclude=exclude, - ) - tree_probed_set = frozenset(tree_probed) - - best_files: Optional[List[Dict[str, Any]]] = None - - if catalog_routed_files and analysis.get("doc_confidence") == "high": - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in catalog_routed_files[:top_k_files] - ] + # ============================================================== + # Phase 2: Parallel retrieval — keyword search + dir_scan rank + # ============================================================== + keyword_files: List[str] = [] + dir_scan_files: List[str] = [] - if not best_files and tree_probed_set and primary: - best_files = await self._fast_find_best_file( - primary, paths=list(tree_probed_set), - top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - ) + if _PURE_TREE_SEARCH: + # Pure tree search mode: skip rga and dir_scan, rely solely on tree hits + await self._logger.info("[Phase 2:PureTree] Skipping rga keyword search and dir_scan") + context.increment_loop() + else: + await self._logger.info("[Phase 2] 
Parallel retrieval: rga keyword search + dir_scan LLM rank") + context.increment_loop() - if not best_files and primary: - best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - tree_probed_paths=tree_probed_set or None, - **rga_kwargs, - ) + phase2_tasks = [] - if not best_files and fallback: - best_files = await self._fast_find_best_file( - fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - tree_probed_paths=tree_probed_set or None, - **rga_kwargs, - ) + if initial_keywords: + phase2_tasks.append( + self._retrieve_by_keywords( + initial_keywords, paths, + max_depth=max_depth, include=include, exclude=exclude, + ) + ) + else: + phase2_tasks.append(self._async_noop([])) - hint_files: List[str] = [] - seen: set = set() - for fp in catalog_routed_files + list(tree_probed) + summary_hits + catalog_deep_hits + compile_hints.file_paths: - if fp and fp not in seen: - seen.add(fp) - hint_files.append(fp) + if scan_result is not None and enable_dir_scan: + phase2_tasks.append( + self._rank_dir_scan_candidates(query, scan_result) + ) + else: + phase2_tasks.append(self._async_noop([])) - if not best_files and hint_files: - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in hint_files[:top_k_files] - ] + phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) - if not best_files and enable_dir_scan: - ranked = await self._scan_and_rank_paths( - query, paths, top_k=top_k_files, include_medium=True, - ) - if ranked: - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in ranked[:top_k_files] - ] + keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] + dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] - file_paths: List[str] = [] - if best_files: - fp_set: set = 
set() - for bf in best_files: - fp = bf["path"] - if fp not in fp_set: - fp_set.add(fp) - file_paths.append(fp) - for fp in hint_files: - if fp not in fp_set and len(file_paths) < top_k_files: - fp_set.add(fp) - file_paths.append(fp) + for i, label in enumerate(["keyword_search", "dir_scan_rank"]): + if isinstance(phase2_results[i], Exception): + await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") await self._logger.info( - f"[DEEP:S0] Retrieved {len(file_paths)} files: " - f"{[Path(p).name for p in file_paths[:5]]}" + f"[Phase 2] Results: keyword_files={len(keyword_files)}, " + f"dir_scan_files={len(dir_scan_files)}" ) - return DeepRetrieval( - file_paths=file_paths, - keywords=all_keywords, - keyword_idfs=keyword_idfs, - catalog_routed=catalog_routed_files, - tree_probed=list(tree_probed), - ) - - async def _deep_gather_evidence( - self, - query: str, - retrieval: "DeepRetrieval", - artifacts: "CompileArtifacts", - context: "SearchContext", - ) -> str: - """Stage 1: Evidence saturation for all retrieved files. - - Gathers evidence from multiple sources per file: - 1. Tree navigation (LLM-guided section targeting) - 2. Table digest (pre-compiled structured tables, expanded budget) - 3. Financial statement harvesting (rule-based section title match) - 4. Tree-guided sampling / rga fallback - - Uses DEEP-specific budgets (_DEEP_EVIDENCE_TOTAL_CHARS) that are - larger than FAST's to maximize evidence quality. 
- """ - if not retrieval.file_paths: - return "" - - file_paths = retrieval.file_paths - tree_paths = artifacts.tree_available_paths if artifacts else set() - max_chars = self._DEEP_EVIDENCE_TOTAL_CHARS - - async def _gather_for_file(fp: str) -> str: - parts: List[str] = [] - fname = Path(fp).name - - nav_ev = "" - if fp in tree_paths: - try: - nav_ev = await self._navigate_tree_for_evidence( - fp, query, - max_results=self._TREE_NAV_MAX_RESULTS, - ) or "" - except Exception: - pass - if nav_ev: - parts.append(nav_ev) - table_ev = "" - try: - from sirchmunk.utils.file_utils import get_fast_hash - fh = get_fast_hash(fp) - if fh: - tables = self._load_table_digest(self.work_path, fh) - if tables: - budget = ( - self._DEEP_TABLE_BUDGET_WITH_NAV if nav_ev - else self._DEEP_TABLE_BUDGET + # --- Phase 2.5: Parallel tree pre-navigation for top tree hits --- + _pre_nav_evidence: Dict[str, str] = {} + if tree_hits: + _nav_fps = [fp for fp in tree_hits[:self._DEEP_PRE_NAV_MAX_FILES]] + if _nav_fps: + _nav_results = await asyncio.gather( + *[self._tree_guided_sample( + fp, query, max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) for fp in _nav_fps], + return_exceptions=True, + ) + for fp, nav_res in zip(_nav_fps, _nav_results): + if isinstance(nav_res, Exception): + await self._logger.warning( + f"[Phase 2.5] Tree pre-nav failed for {Path(fp).name}: {nav_res}" ) - table_ev = self._format_table_evidence( - tables, max_chars=budget, query=query, - ) or "" - except Exception: - pass - if table_ev: - parts.append(f"[{fname} - Table Evidence]\n{table_ev}") - - stmt_ev = await self._harvest_financial_statements(fp, query, artifacts) - if stmt_ev: - parts.append(f"[{fname} - Financial Statements]\n{stmt_ev}") - - if not nav_ev and fp in tree_paths: - try: - tree_sample = await self._tree_guided_sample( - fp, query, - max_chars=self._FAST_MAX_EVIDENCE_CHARS, - artifacts=artifacts, + elif isinstance(nav_res, str) and nav_res: + _pre_nav_evidence[fp] = nav_res + if _pre_nav_evidence: + 
await self._logger.info( + f"[Phase 2.5] Pre-navigated {len(_pre_nav_evidence)} tree files" ) - if tree_sample: - parts.append(tree_sample) - except Exception: - pass - - if not parts: - try: - rga_ev = await self._fast_sample_evidence(fp, []) - if rga_ev: - parts.append(rga_ev) - except Exception: - pass - - context.mark_file_read(fp) - if parts: - return f"[Source: {fname}]\n" + "\n\n".join(parts) - return "" - - tasks = [_gather_for_file(fp) for fp in file_paths] - results = await asyncio.gather(*tasks, return_exceptions=True) - - evidence_parts: List[str] = [] - total_chars = 0 - for r in results: - if isinstance(r, Exception): - continue - if r and total_chars < max_chars: - remaining = max_chars - total_chars - evidence_parts.append(r[:remaining]) - total_chars += len(evidence_parts[-1]) + # ============================================================== + # Phase 3: Merge file paths + build KnowledgeCluster + # P1 tree hits get highest priority; P2 soft-hit files next + # ============================================================== context.increment_loop() - combined = "\n\n---\n\n".join(evidence_parts) - await self._logger.info( - f"[DEEP:S1] Evidence: {len(combined)} chars from " - f"{len(evidence_parts)} files" - ) - return combined - - async def _harvest_financial_statements( - self, - file_path: str, - query: str, - artifacts: "CompileArtifacts", - ) -> str: - """Proactively extract financial statement sections via tree index. - - Scans tree section titles for income/balance/cashflow patterns and - extracts those pages. Complements tree navigation which may focus - on narrative sections instead of data-dense statements. 
- """ - tree_paths = artifacts.tree_available_paths if artifacts else set() - if file_path not in tree_paths: - return "" - - indexer = self._get_tree_indexer() - if indexer is None: - return "" - - tree = indexer.load_tree(file_path) - if tree is None or tree.root is None: - return "" - - statement_sections: List[Dict[str, Any]] = [] - section_map, sections_meta = self._build_section_map( - tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH, - ) - - for sec in sections_meta: - title = (sec.get("title") or "").lower() - if any(pat.search(title) for pat in self._DEEP_STATEMENT_PATTERNS): - if sec.get("page_range"): - statement_sections.append(sec) - - if not statement_sections: - return "" - - try: - ev = await self._extract_targeted_pages( - file_path, statement_sections[:6], query, - ) - return ev or "" - except Exception: - return "" - - async def _deep_decompose_question( - self, - query: str, - evidence: str, - context: "SearchContext", - ) -> "DeepDecomposition": - """Stage 2: Decompose the question into structured plan (1 LLM call). + extra_knowledge_files = knowledge_probe.file_paths + if soft_hit: + extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files - Classifies query type (lookup/calculation/comparison/synthesis), - extracts required data points, time periods, entities, and - calculation steps. 
- """ - evidence_summary = evidence[:3000] if evidence else "(no evidence yet)" - prompt = DEEP_QUESTION_DECOMPOSE.format( - query=query, - evidence_summary=evidence_summary, - ) - try: - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=False, - ) - self.llm_usages.append(resp.usage) - if resp.usage and isinstance(resp.usage, dict): - context.add_llm_tokens( - resp.usage.get("total_tokens", 0), usage=resp.usage, + if _PURE_TREE_SEARCH: + # Pure tree search: only use tree hits (+ soft-hit fallback if no tree hits) + pure_tree_files = list(tree_hits) + if not pure_tree_files and soft_hit: + pure_tree_files = soft_hit.file_paths + await self._logger.info( + f"[Phase 3:PureTree] No tree hits, using {len(pure_tree_files)} soft-hit files" ) - context.increment_loop() - - raw = (resp.content or "").strip() - parsed = self._parse_fast_json(raw) - - return DeepDecomposition( - query_type=parsed.get("query_type", "lookup"), - sub_questions=parsed.get("sub_questions", [query]), - required_data=parsed.get("required_data", []), - time_periods=parsed.get("time_periods", []), - entities=parsed.get("entities", []), - calculation_steps=parsed.get("calculation_steps", []), + merged_files = self._merge_file_paths( + keyword_files=pure_tree_files, + dir_scan_files=[], + knowledge_hits=[], ) - except Exception as exc: - await self._logger.warning(f"[DEEP:S2] Decomposition failed: {exc}") - return DeepDecomposition(query_type="lookup", sub_questions=[query]) - - @staticmethod - def _deep_check_adequacy( - query: str, - evidence: str, - decomposition: "DeepDecomposition", - ) -> Tuple[bool, List[str]]: - """Stage 3: Rule-based evidence adequacy check (0 LLM calls). - - Checks whether the evidence contains the required data points - from the decomposition. Returns (adequate, gap_descriptions). 
- """ - if not evidence or len(evidence.strip()) < 200: - return False, ["evidence too short"] - - evidence_lower = evidence.lower() - gaps: List[str] = [] - - for data_point in decomposition.required_data: - tokens = [t.lower() for t in re.findall(r"[A-Za-z0-9]+", data_point) if len(t) >= 3] - if tokens and not any(t in evidence_lower for t in tokens): - gaps.append(data_point) - - for period in decomposition.time_periods: - year_match = re.search(r"(\d{4})", period) - if year_match and year_match.group(1) not in evidence: - gaps.append(f"time period: {period}") - - if decomposition.query_type == "calculation": - numbers = re.findall(r'[\$€£]?\d[\d,]*\.?\d*', evidence) - if len(numbers) < 2: - gaps.append("insufficient numeric data for calculation") - - adequate = len(gaps) <= len(decomposition.required_data) * 0.3 - return adequate, gaps - - async def _deep_fill_evidence_gaps( - self, - query: str, - gaps: List[str], - retrieval: "DeepRetrieval", - artifacts: "CompileArtifacts", - context: "SearchContext", - decomposition: Optional["DeepDecomposition"] = None, - ) -> str: - """Fill identified evidence gaps with targeted retrieval. - - Strategy per file: - 1. Tree section selection with expanded depth for gap-specific terms - 2. Table digest supplement for numeric gaps - 3. 
Keyword rga fallback for non-tree files - """ - extra_parts: List[str] = [] - tree_paths = artifacts.tree_available_paths if artifacts else set() - indexer = self._get_tree_indexer() - is_calc = decomposition and decomposition.query_type in ("calculation", "comparison") - - for fp in retrieval.file_paths[:3]: - gap_query = f"{query} — specifically looking for: {'; '.join(gaps[:5])}" - - if fp in tree_paths and indexer is not None: - tree = indexer.load_tree(fp) - if tree is not None and tree.root is not None: - section_map, sections_meta = self._build_section_map( - tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + 2, - ) - if sections_meta: - selected = await self._select_evidence_sections( - gap_query, section_map, sections_meta, - ) - context.increment_loop() - if selected: - ev = await self._extract_targeted_pages(fp, selected, query) - if ev and len(ev.strip()) > 100: - extra_parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") - context.mark_file_read(fp) - - if is_calc and any("numeric" in g or "data" in g for g in gaps): - table_index = (artifacts.table_index or {}).get(fp, []) - if table_index: - table_ev = self._format_table_evidence( - table_index, max_chars=self._DEEP_TABLE_BUDGET_WITH_NAV, query=query, - ) - if table_ev and len(table_ev.strip()) > 100: - extra_parts.append(f"[Table-supplement: {Path(fp).name}]\n{table_ev}") - - if extra_parts: await self._logger.info( - f"[DEEP:S3] Gap-fill: {len(extra_parts)} additional evidence sources" - ) - return "\n\n---\n\n".join(extra_parts) - - async def _deep_synthesize( - self, - query: str, - evidence: str, - decomposition: "DeepDecomposition", - artifacts: "CompileArtifacts", - retrieval: "DeepRetrieval", - context: "SearchContext", - ) -> Tuple[str, bool, bool]: - """Stage 4: Strategy-routed answer synthesis. 
- - Routes to specialized prompts based on query_type: - - calculation: DEEP_CALCULATION_SYNTHESIS with explicit computation steps - - lookup/comparison/synthesis: ROI_RESULT_SUMMARY with document context - - Pre-processing: prunes evidence by entities/time_periods, and - appends format constraints derived from query semantics. - """ - evidence = self._deep_prune_evidence(evidence, decomposition) - - format_hint = self._deep_format_constraint(query) - - doc_context: Optional[str] = None - if artifacts and artifacts.catalog_map: - ctx_parts = [ - self._build_answer_context(fp, artifacts) - for fp in retrieval.file_paths[:3] - ] - ctx_parts = [c for c in ctx_parts if c] - if ctx_parts: - doc_context = "\n".join(ctx_parts) - - if decomposition.query_type == "calculation" and decomposition.calculation_steps: - steps_text = "\n".join( - f"{i+1}. {s}" for i, s in enumerate(decomposition.calculation_steps) - ) - synth_prompt = DEEP_CALCULATION_SYNTHESIS.format( - calculation_steps=steps_text, - user_input=query, - text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], - ) - elif doc_context: - from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT - synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( - user_input=query, - text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], - document_context=doc_context, + f"[Phase 3:PureTree] Merged {len(merged_files)} tree-only candidate files" ) else: - synth_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=evidence[:self._DEEP_SYNTHESIS_MAX_CHARS], + merged_files = self._merge_file_paths( + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, + dir_scan_files=dir_scan_files, + knowledge_hits=extra_knowledge_files, ) + await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") - if format_hint: - synth_prompt = f"{synth_prompt}\n\n### Format Constraint\n{format_hint}" - - resp = await self.llm.achat( - 
messages=[{"role": "user", "content": synth_prompt}], - stream=True, - ) - self.llm_usages.append(resp.usage) - if resp.usage and isinstance(resp.usage, dict): - context.add_llm_tokens( - resp.usage.get("total_tokens", 0), usage=resp.usage, + cluster: Optional[KnowledgeCluster] = None + if merged_files: + cluster = await self._build_cluster( + query=query, file_paths=merged_files, + query_keywords=query_keywords, top_k_files=top_k_files, ) - context.increment_loop() - - answer, should_save, should_answer = self._parse_summary_response( - resp.content or "" - ) - - accepted, accept_reason = self._evaluate_evidence_acceptance( - query, evidence, should_answer, - ) - - if not accepted and decomposition.query_type in ("calculation", "comparison"): - accepted = self._deep_relaxed_acceptance(query, evidence) - if accepted: - accept_reason = "deep_calc_relaxed" - should_answer = True - should_save = True - - await self._logger.info( - f"[DEEP:S4] Synthesis: accepted={accepted} ({accept_reason}), " - f"type={decomposition.query_type}" - ) - - if not accepted: - return "", False, False - return answer, should_save, should_answer - - @staticmethod - def _deep_relaxed_acceptance(query: str, evidence: str) -> bool: - """Relaxed acceptance for calculation/comparison queries. - - Accepts when evidence contains >=2 distinct numbers AND at least - one query-relevant keyword appears. This prevents false-negative - rejections on numeric data that the standard heuristic misses - due to low keyword coverage from formula-heavy queries. - """ - numbers = re.findall(r'[\$€£]?\d[\d,]*\.?\d*', evidence[:20000]) - if len(numbers) < 2: - return False - kw_coverage = AgenticSearch._compute_keyword_coverage(query, evidence) - return kw_coverage >= 0.3 - - @staticmethod - def _deep_prune_evidence( - evidence: str, - decomposition: "DeepDecomposition", - ) -> str: - """Prune evidence paragraphs not matching required time_periods/entities. - - Splits evidence into paragraph-level blocks. 
Retains a block if it - mentions ANY required time period or entity. Blocks that match - neither are discarded (unless fewer than 30% of blocks would remain, - in which case no pruning is applied to avoid over-filtering). - """ - if not decomposition.time_periods and not decomposition.entities: - return evidence - - periods = {p.lower() for p in decomposition.time_periods} - year_patterns = {re.search(r"\d{4}", p).group() for p in periods if re.search(r"\d{4}", p)} - entities = {e.lower() for e in decomposition.entities} - - blocks = re.split(r'\n{2,}', evidence) - if len(blocks) <= 3: - return evidence - - kept: List[str] = [] - for block in blocks: - block_lower = block.lower() - has_period = any(y in block for y in year_patterns) if year_patterns else True - has_entity = any(e in block_lower for e in entities) if entities else True - if has_period or has_entity: - kept.append(block) - - if len(kept) < len(blocks) * 0.3: - return evidence - return "\n\n".join(kept) - - @staticmethod - def _deep_format_constraint(query: str) -> str: - """Derive answer format guidance from query semantics. - - Returns a short instruction string appended to the synthesis prompt - to steer PRECISE_ANSWER toward the expected format. - """ - q_lower = query.lower() - if re.search(r'\b(is|does|did|was|were|has|have|can|will|should)\b', q_lower) and "?" in query: - return "Answer with Yes or No first, then provide justification." - if re.search(r'(?:million|billion|mn|bn)\b', q_lower): - unit = "billion" if re.search(r'\b(billion|bn)\b', q_lower) else "million" - return f"Express the final answer in {unit}s (e.g. $X {unit})." - if re.search(r'\bratio\b', q_lower): - return "Express the final answer as a decimal ratio (e.g. 1.5x or 0.75)." - if re.search(r'\b(percentage|percent|%)\b', q_lower): - return "Express the final answer as a percentage (e.g. 25.3%)." 
- if re.search(r'\bgrowth\b.*\brate\b|\brate\b.*\bgrowth\b', q_lower): - return "Express the final answer as a percentage change (e.g. +12.5% or -3.2%)." - return "" - async def _deep_self_consistency( - self, - query: str, - first_answer: str, - evidence: str, - decomposition: "DeepDecomposition", - artifacts: "CompileArtifacts", - retrieval: "DeepRetrieval", - context: "SearchContext", - ) -> str: - """Run a second synthesis and pick the consistent answer. - - Compares PRECISE_ANSWER from both runs. If they match (within - numeric tolerance), returns the first. If they diverge, picks - the answer whose PRECISE_ANSWER contains a valid number. - """ - second_answer, _, _ = await self._deep_synthesize( - query, evidence, decomposition, artifacts, retrieval, context, - ) - if not second_answer: - return first_answer - - precise_1 = re.search(r'\*\*Answer:\s*(.+?)\*\*', first_answer) - precise_2 = re.search(r'\*\*Answer:\s*(.+?)\*\*', second_answer) - if not precise_1 or not precise_2: - return first_answer - - val_1 = re.sub(r'[^\d.\-]', '', precise_1.group(1).replace(',', '')) - val_2 = re.sub(r'[^\d.\-]', '', precise_2.group(1).replace(',', '')) - - try: - n1, n2 = float(val_1), float(val_2) - except (ValueError, TypeError): - return first_answer + # ============================================================== + # Phase 3.5: Graph context enrichment (P5) + # Append related knowledge from graph neighbours to cluster content + # so the answer-generation LLM has richer context. 
+ # ============================================================== + graph_ctx = "" + if cluster: + # Merge pre-navigated tree evidence into cluster content + if _pre_nav_evidence and cluster.content: + pre_nav_parts = [] + for fp, ev in _pre_nav_evidence.items(): + pre_nav_parts.append(f"[Tree evidence: {Path(fp).name}]\n{ev}") + if pre_nav_parts: + pre_nav_ctx = "\n\n".join(pre_nav_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" + + graph_ctx = await self._gather_graph_context(cluster) + if graph_ctx and cluster.content: + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{graph_ctx}" - tolerance = max(abs(n1) * 0.05, 0.01) - if abs(n1 - n2) <= tolerance: - return first_answer + # ============================================================== + # Phase 4: Structured Reasoning → Cluster Summary fallback + # P0: DEEP mode always goes through full reasoning pipeline — + # no fast triage short-circuit. P4: query complexity determines + # whether the heavier section-map SR fires or we go straight to + # cluster synthesis. + # ============================================================== + context.increment_loop() + answer = "" + should_save = True + _query_complexity = self._classify_query_complexity(query) await self._logger.info( - f"[DEEP:S5] Self-consistency divergence: {val_1} vs {val_2}, using second" + f"[Phase 4] Query complexity: {_query_complexity}" ) - return second_answer - - @staticmethod - def _deep_verify_answer( - query: str, - answer: str, - evidence: str, - ) -> Tuple[str, bool]: - """Stage 5: Verify calculation answers with Python eval. - Extracts the COMPUTATION block from the answer, parses variable - assignments (``var = expr``), evaluates them in a safe namespace, - and compares the final result against PRECISE_ANSWER. 
When a - discrepancy is found, replaces the PRECISE_ANSWER with the - recomputed value. + # Attempt structured reasoning for moderate/complex queries + _sr_files: List[str] = [] + if _query_complexity != "simple": + if tree_hits: + _sr_files = list(tree_hits[: self._DEEP_STRUCTURED_MAX_FILES]) + elif artifacts and artifacts.tree_available_paths: + _sr_files = list(artifacts.tree_available_paths)[ + : self._DEEP_STRUCTURED_MAX_FILES + ] - Returns (potentially_corrected_answer, verified). - """ - computation_match = re.search( - r'(.*?)', answer, re.DOTALL, - ) - precise_match = re.search( - r'(.*?)', answer, re.DOTALL, - ) - if not computation_match: - return answer, True + if _sr_files: + await self._logger.info( + f"[Phase 4] Launching structured reasoning for " + f"{len(_sr_files)} tree-indexed files" + ) + sr_answer, sr_cluster, sr_evidence = await self._deep_structured_reasoning( + query, _sr_files, artifacts, context, + ) - computation_text = computation_match.group(1) - assignments = re.findall( - r'(?:^|\n)\s*[\w\s]+?=\s*(.+?)(?:\s*\(|$|\n)', - computation_text, - ) - if not assignments: - return answer, True - - safe_ns: Dict[str, float] = {} - last_result: Optional[float] = None - - for expr_raw in assignments: - expr = expr_raw.strip().rstrip("(") - expr = re.sub(r'[\$€£,]', '', expr) - expr = expr.replace('−', '-').replace('÷', '/').replace('×', '*') - expr = re.sub(r'[a-zA-Z%]+$', '', expr).strip() - if not expr or not re.search(r'\d', expr): - continue - try: - result = float(eval(expr, {"__builtins__": {}}, safe_ns)) # noqa: S307 - last_result = result - except Exception: - continue + if sr_answer: + answer, should_save, should_answer = self._parse_summary_response( + sr_answer + ) + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, sr_evidence or sr_answer, should_answer, + ) + await self._logger.info( + f"[Phase 4] Structured reasoning: " + f"accepted={accepted} ({accept_reason})" + ) + if accepted: + cluster = sr_cluster or 
cluster + else: + answer = "" - if last_result is None or not precise_match: - return answer, True + # Fallback: cluster summary with ROI prompt or ReAct + if not answer: + if artifacts and artifacts.catalog_map and cluster and cluster.content: + _catalog_ctx_parts = [] + for fp in (cluster.search_results or merged_files)[:3]: + ctx = self._build_answer_context(fp, artifacts) + if ctx: + _catalog_ctx_parts.append(ctx) + if _catalog_ctx_parts: + _catalog_context = "\n".join(_catalog_ctx_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = ( + f"{cluster.content}\n\n" + f"[Document Context]\n{_catalog_context}" + ) - precise_text = precise_match.group(1).strip() - precise_num = re.sub(r'[^\d.\-]', '', precise_text.replace(',', '')) - try: - stated = float(precise_num) if precise_num else None - except ValueError: - stated = None + if cluster and cluster.content: + await self._logger.info( + "[Phase 4:Fallback] Generating summary from cluster" + ) + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + cluster_evidence = ( + str(cluster.content) if cluster.content else "" + ) + accepted, accept_reason = ( + self._evaluate_evidence_acceptance( + query, cluster_evidence, should_answer, + ) + ) + if not accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) + ) + else: + # DEEP self-correction before giving up + sc_evidence = await self._deep_self_correct( + query, merged_files, query_keywords, context, + ) + if sc_evidence: + sc_cluster = self._make_answer_cluster( + query, sc_evidence[:5000], "DSC", + file_paths=list(merged_files)[:3], + ) + sc_cluster.content = sc_evidence + answer, should_save, should_answer = ( + await self._summarise_cluster(query, sc_cluster) + ) + sc_accepted, _ = self._evaluate_evidence_acceptance( + query, sc_evidence, should_answer, + ) + if sc_accepted: + cluster = sc_cluster + else: + return 
_NO_RESULTS_MESSAGE, None, context + else: + return _NO_RESULTS_MESSAGE, None, context + if not cluster.search_results: + cluster.search_results = list(merged_files) + elif llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) + ) + else: + await self._logger.info( + "[Phase 4:Fallback] Launching ReAct refinement" + ) + # Seed ReAct with all available prior context so it + # doesn't start from scratch. + react_parts: List[str] = [] + if spec_context: + react_parts.append(spec_context) + if graph_ctx: + react_parts.append(graph_ctx) + if _pre_nav_evidence: + nav_seed = "\n\n".join( + f"[Pre-navigated: {Path(fp).name}]\n{ev}" + for fp, ev in _pre_nav_evidence.items() + ) + react_parts.append(nav_seed) + react_spec = "\n\n".join(react_parts) + react_answer, context = await self._react_refinement( + query=query, paths=paths, + initial_keywords=initial_keywords, + spec_context=react_spec, + enable_dir_scan=enable_dir_scan, + max_loops=max_loops, + max_token_budget=max_token_budget, + max_depth=max_depth, + include=include, exclude=exclude, + ) + if not cluster: + cluster = await self._build_cluster_from_context( + query=query, answer=react_answer, + context=context, + query_keywords=query_keywords, + top_k_files=top_k_files, + ) + elif react_answer and not cluster.content: + cluster.content = react_answer + if not cluster: + return _NO_RESULTS_MESSAGE, None, context + answer, should_save, should_answer = ( + await self._summarise_cluster(query, cluster) + ) + final_evidence = ( + str(cluster.content) if cluster.content else "" + ) + final_accepted, _ = self._evaluate_evidence_acceptance( + query, final_evidence, should_answer, + ) + if not final_accepted: + if llm_fallback: + answer, should_save = ( + await self._summarise_cluster_fallback(query) + ) + else: + sc_evidence = await self._deep_self_correct( + query, merged_files, query_keywords, context, + ) + if sc_evidence: + sc_cluster = self._make_answer_cluster( + query, 
sc_evidence[:5000], "DSC", + file_paths=list(merged_files)[:3], + ) + sc_cluster.content = sc_evidence + answer, should_save, _ = ( + await self._summarise_cluster(query, sc_cluster) + ) + cluster = sc_cluster + else: + return _NO_RESULTS_MESSAGE, None, context - if stated is None: - return answer, True + # Sync LLM token accounting into context + new_usages = self.llm_usages[_llm_usage_start:] + for usage in new_usages: + if usage and isinstance(usage, dict): + total_tok = usage.get("total_tokens", 0) + if total_tok == 0: + total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) + context.add_llm_tokens(total_tok, usage=usage) - tolerance = max(abs(stated) * 0.02, 0.01) - if abs(stated - last_result) <= tolerance: - return answer, True + # ============================================================== + # Phase 5: Persistence (quality-gated) + # Skipped when Phase 4 quality check says the answer is low-quality + # or when Phase 0 reused a cluster (early-returned above). 
+ # ============================================================== + phase5_tasks = [] + if cluster and should_save: + self._add_query_to_cluster(cluster, query) + phase5_tasks.append(self._save_cluster_with_embedding(cluster)) + elif not should_save: + await self._logger.info("[Phase 5] Quality gate: low-quality answer, skipping cluster save") + cluster = None + phase5_tasks.append(self._save_spec_context(paths, context, scan_result=scan_result)) + results = await asyncio.gather(*phase5_tasks, return_exceptions=True) + for r in results: + if isinstance(r, Exception): + _loguru_logger.warning(f"[Phase 5] Persistence task failed: {r}") - fmt = f"{last_result:.1f}" if abs(last_result) < 100 else f"{last_result:,.0f}" - corrected = answer.replace( - precise_match.group(0), - f"{fmt}", - ) - return corrected, False + await self._logger.success(f"[search] Complete: {context.summary()}") + return answer, cluster, context # ------------------------------------------------------------------ # Phase 0a: Direct document analysis (intent-gated) @@ -2999,22 +2532,6 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _DEEP_STRUCTURED_MAX_FILES: int = 3 """Maximum files to process through structured reasoning pipeline.""" - # --- DEEP v2 evidence budgets --- - _DEEP_EVIDENCE_TOTAL_CHARS: int = 120_000 - """Total evidence budget for DEEP mode (uses more context than FAST).""" - _DEEP_TABLE_BUDGET: int = 40_000 - """Table digest budget per file in DEEP mode (no tree nav overlap).""" - _DEEP_TABLE_BUDGET_WITH_NAV: int = 20_000 - """Table digest budget per file when tree nav also provides evidence.""" - _DEEP_SYNTHESIS_MAX_CHARS: int = 60_000 - """Maximum evidence chars sent to the synthesis LLM call.""" - _DEEP_STATEMENT_PATTERNS: Tuple[re.Pattern, ...] 
= ( - re.compile(r"(?:income|operations|earnings|profit.loss)", re.IGNORECASE), - re.compile(r"(?:balance.sheet|financial.position|assets.liab)", re.IGNORECASE), - re.compile(r"(?:cash.flow|cash.provided|financing.activities)", re.IGNORECASE), - ) - """Section title patterns for financial statement harvesting.""" - # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" @@ -6280,27 +5797,15 @@ async def _llm_select_from_trees( if Path(pool[idx].file_path).exists() ] - async def _probe_tree_index( - self, - query: str, - *, - scope: Optional["_PathScope"] = None, - artifacts: Optional["CompileArtifacts"] = None, - ) -> List[str]: + async def _probe_tree_index(self, query: str) -> List[str]: """LLM-driven file discovery via compiled tree root summaries (PageIndex). - Loads cached document trees, filters them by *scope* and/or - *artifacts.tree_available_paths*, presents root summaries to the - LLM, and asks it to select the most relevant documents. + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant documents. Returns file + paths of the most relevant documents. 
""" try: trees = self._load_cached_trees() - if not trees: - return [] - if artifacts and artifacts.tree_available_paths: - trees = [t for t in trees if t.file_path in artifacts.tree_available_paths] - if scope and not scope.is_empty: - trees = [t for t in trees if scope.contains(t.file_path)] if not trees: return [] result = await self._llm_select_from_trees( From 5a55aa887ac8c2e2361f9f66c5cdc4236fd0ff9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 15 May 2026 15:49:28 +0800 Subject: [PATCH 64/70] refine deep mode --- src/sirchmunk/llm/prompts.py | 172 ++++++++++ src/sirchmunk/search.py | 587 ++++++++++++++++++++++++++++++----- 2 files changed, 676 insertions(+), 83 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 4b4c3bb..2d893fe 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -533,6 +533,178 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +# --------------------------------------------------------------------------- +# DEEP mode query classification (Plan B) +# --------------------------------------------------------------------------- + +DEEP_QUERY_CLASSIFY = """Classify this search query along two dimensions. + +Query: {query} + +1. **Complexity** — how many reasoning steps are needed: + - "simple": Direct lookup of a single value (e.g. "What was revenue in FY2023?") + - "moderate": Requires light computation from 1-2 data points (e.g. "What was the gross margin?") + - "complex": Multi-step computation, multi-period comparison, or cross-entity analysis + +2. 
**Intent** — what the user needs: + - "lookup": Find and extract a specific stated value + - "computation": Calculate a derived metric (ratio, growth rate, difference, average) + - "comparison": Compare values across time periods, segments, or companies + +Return ONLY valid JSON on a single line: +{{"complexity": "simple", "intent": "lookup"}} +""" + +# --------------------------------------------------------------------------- +# Intent-specific synthesis prompts (Plan C) +# --------------------------------------------------------------------------- + +ROI_LOOKUP_SYNTHESIS = """### Task +Extract the specific value requested from the evidence and present it clearly. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Find the EXACT value stated in the evidence. Do not compute or estimate. +3. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. +4. Quote the source passage containing the value. +5. If the value is not explicitly stated in the evidence, mark SHOULD_ANSWER as "false". + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +**Source passage**: [Quote the exact text containing the answer] + +**Extracted value**: [The specific value found] + +[value only, e.g. "$1,832 million", "Yes", "42%"] +true/false +true/false +""" + +ROI_COMPUTATION_SYNTHESIS = """### Task +Answer the query by extracting data from the evidence and performing the required calculation. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Follow this STRICT sequence — do NOT skip any step: + a) **DATA EXTRACTION**: List each required data point with its exact value and where you found it. + b) **FORMULA**: State the formula needed (e.g. Gross Margin = (Revenue - COGS) / Revenue). + c) **SUBSTITUTION**: Plug in the extracted values into the formula. 
+ d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. + e) **VERIFICATION**: Re-compute the final result independently to confirm. +3. **Rounding**: Match the precision implied by the query. For percentages, use at most one decimal place unless asked for more. For dollar amounts, round to the nearest whole number in the stated unit. +4. **Units**: Convert all values to consistent units before computing. +5. If any required data point is missing, explicitly state what is missing and mark SHOULD_ANSWER as "false". + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +## Data Extraction +| Data Point | Value | Source | +|---|---|---| +| [name] | [exact value] | [where found in evidence] | + +## Calculation +**Formula**: [state formula] +**Step 1**: [operation] = [result] +**Step 2**: [operation] = [result] +**Verification**: [re-compute to confirm] + +[final computed value only] +true/false +true/false +""" + +ROI_COMPARISON_SYNTHESIS = """### Task +Compare the requested values across the specified dimensions (time periods, entities, or segments). + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Extract values for EACH comparison dimension from the evidence. +3. Present in a structured comparison table. +4. State the direction and magnitude of difference or change. +5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. +6. If values for any comparison dimension are missing, state what is missing. 
+ +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +## Comparison +| Dimension | Value | Source | +|---|---|---| +| [period/entity] | [value] | [where found] | + +## Analysis +**Direction**: [increased/decreased/stable] +**Magnitude**: [absolute and/or percentage change, with arithmetic shown] + +[concise comparison result, e.g. "Increased from $1.2B to $1.5B (25% growth)"] +true/false +true/false +""" + +# --------------------------------------------------------------------------- +# Evidence completeness check (Plan D) +# --------------------------------------------------------------------------- + +EVIDENCE_COMPLETENESS_CHECK = """Given the query and available evidence, determine whether all data points needed to answer are present. + +### Query +{query} + +### Query Type +{intent} + +### Evidence (excerpt) +{evidence_excerpt} + +### Instructions +1. Identify the specific data points required to answer this query. +2. Check whether each data point's actual value appears in the evidence. +3. A data point is FOUND only if its numeric/factual value is explicitly stated. + +Return ONLY valid JSON on a single line: +{{"complete": true, "missing": []}} +or +{{"complete": false, "missing": ["short description of what is missing"]}} +""" + +# --------------------------------------------------------------------------- +# Computation correction (Plan E) +# --------------------------------------------------------------------------- + +COMPUTATION_CORRECTION = """Your previous calculation contained an arithmetic error. Please revise. + +### Query +{query} + +### Your Previous Answer +{original_answer} + +### Detected Error +- Expression: {expression} +- Your result: {llm_result} +- Correct result: {correct_result} + +Revise your answer using the correct arithmetic. Keep the same analysis structure. 
+ + +[Corrected analysis with fixed calculation] + +[Corrected final value] +true +true +""" + # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 1acc220..52c431b 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -23,6 +23,9 @@ FAST_QUERY_ANALYSIS, FAST_QUERY_ANALYSIS_WITH_CATALOG, ROI_RESULT_SUMMARY, + ROI_LOOKUP_SYNTHESIS, + ROI_COMPUTATION_SYNTHESIS, + ROI_COMPARISON_SYNTHESIS, DOC_SUMMARY, DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, @@ -1159,6 +1162,53 @@ def _classify_query_complexity(cls, query: str) -> str: return "moderate" return "simple" + _VALID_COMPLEXITIES = frozenset({"simple", "moderate", "complex"}) + _VALID_INTENTS = frozenset({"lookup", "computation", "comparison"}) + + async def _classify_query_intent( + self, query: str, + ) -> Tuple[str, str]: + """Classify query complexity and intent via LLM. + + Falls back to regex-based ``_classify_query_complexity`` when the + LLM call fails or returns unparseable output. + + Returns: + ``(complexity, intent)`` where complexity is + ``simple|moderate|complex`` and intent is + ``lookup|computation|comparison``. 
+ """ + try: + from sirchmunk.llm.prompts import DEEP_QUERY_CLASSIFY + + resp = await self.llm.achat( + messages=[{ + "role": "user", + "content": DEEP_QUERY_CLASSIFY.format(query=query), + }], + stream=True, + ) + self.llm_usages.append(resp.usage) + + raw = (resp.content or "").strip() + match = re.search(r'\{[^}]+\}', raw) + if match: + data = json.loads(match.group()) + complexity = data.get("complexity", "").lower() + intent = data.get("intent", "").lower() + if (complexity in self._VALID_COMPLEXITIES + and intent in self._VALID_INTENTS): + return complexity, intent + except Exception as exc: + await self._logger.warning( + f"[QueryClassify] LLM classification failed: {exc}, " + f"falling back to regex" + ) + + complexity = self._classify_query_complexity(query) + intent = "computation" if complexity != "simple" else "lookup" + return complexity, intent + @staticmethod def _evaluate_evidence_acceptance( query: str, @@ -1209,6 +1259,256 @@ def _evaluate_evidence_acceptance( f"kw_coverage={kw_coverage:.2f}, numeric=false)" ) + # ------------------------------------------------------------------ + # Plan E: Computation verification + # ------------------------------------------------------------------ + + # ------------------------------------------------------------------ + # Plan D: Evidence adequacy closed-loop + # ------------------------------------------------------------------ + + async def _check_evidence_completeness( + self, + query: str, + intent: str, + evidence: str, + ) -> Tuple[bool, List[str]]: + """Check if evidence contains all data points needed for the query. + + Returns: + ``(is_complete, missing)`` where *missing* lists descriptions + of data points not found in the evidence. 
+ """ + try: + from sirchmunk.llm.prompts import EVIDENCE_COMPLETENESS_CHECK + + prompt = EVIDENCE_COMPLETENESS_CHECK.format( + query=query, + intent=intent, + evidence_excerpt=evidence[:3000], + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + + raw = (resp.content or "").strip() + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + data = json.loads(match.group()) + is_complete = bool(data.get("complete", True)) + missing = data.get("missing", []) + if isinstance(missing, list) and missing: + return False, [str(m) for m in missing[:5]] + return is_complete, [] + except Exception as exc: + await self._logger.warning( + f"[Phase 3.75] Completeness check failed: {exc}" + ) + return True, [] + + async def _fill_evidence_gaps( + self, + query: str, + missing: List[str], + file_paths: List[str], + artifacts: Any, + ) -> Optional[str]: + """Targeted evidence retrieval for identified gaps. + + Constructs focused sub-queries from *missing* descriptions and + re-navigates tree indices or falls back to keyword retrieval. + + Returns supplementary evidence text, or None. 
+ """ + sub_query = f"{query} — specifically: {'; '.join(missing)}" + parts: List[str] = [] + + indexer = self._get_tree_indexer() + for fp in file_paths[:3]: + try: + if indexer and indexer.has_tree(fp): + ev = await self._navigate_tree_for_evidence(fp, sub_query) + if ev and len(ev.strip()) > 100: + parts.append( + f"[Gap-fill: {Path(fp).name}]\n{ev}" + ) + continue + ev = await self._tree_guided_sample(fp, sub_query) + if isinstance(ev, str) and len(ev.strip()) > 100: + parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") + except Exception: + continue + + if not parts and artifacts and artifacts.tree_available_paths: + extra_fps = [ + fp for fp in artifacts.tree_available_paths + if fp not in file_paths + ][:2] + for fp in extra_fps: + try: + ev = await self._navigate_tree_for_evidence(fp, sub_query) + if ev and len(ev.strip()) > 100: + parts.append( + f"[Gap-fill extra: {Path(fp).name}]\n{ev}" + ) + except Exception: + continue + + if not parts: + return None + return "\n\n".join(parts) + + # ------------------------------------------------------------------ + # Plan E: Computation verification + # ------------------------------------------------------------------ + + _ARITH_PATTERNS = [ + re.compile( + r'[\$€£]?\s*' + r'([\d,]+(?:\.\d+)?)\s*' + r'([+\-\*/])\s*' + r'[\$€£]?\s*' + r'([\d,]+(?:\.\d+)?)\s*' + r'=\s*' + r'[\$€£]?\s*' + r'([\-]?[\d,]+(?:\.\d+)?)\s*%?' + ), + re.compile( + r'\(\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'([+\-])\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'\)\s*[/\*]\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'=\s*' + r'[\$€£]?\s*([\-]?[\d,]+(?:\.\d+)?)\s*%?' + ), + ] + + _SAFE_EVAL_NS: Dict[str, Any] = {"__builtins__": {}, "abs": abs, "round": round} + _ARITH_TOLERANCE: float = 0.01 + + @classmethod + def _extract_arithmetic_expressions(cls, text: str) -> List[Dict[str, Any]]: + """Extract arithmetic expressions and their stated results from text. + + Returns list of ``{"expr": str, "stated": float, "computed": float}``. 
+ Only includes entries where Python evaluation succeeded. + """ + results: List[Dict[str, Any]] = [] + + def _parse_num(s: str) -> float: + return float(s.replace(",", "")) + + for line in text.split("\n"): + for pat in cls._ARITH_PATTERNS: + for m in pat.finditer(line): + groups = m.groups() + try: + if len(groups) == 4: + a, op, b, stated = groups + a_val, b_val = _parse_num(a), _parse_num(b) + expr = f"{a_val} {op} {b_val}" + computed = eval(expr, cls._SAFE_EVAL_NS) + results.append({ + "expr": expr, + "stated": _parse_num(stated), + "computed": float(computed), + "raw": m.group(), + }) + elif len(groups) == 5: + a, op, b, divisor, stated = groups + a_val, b_val = _parse_num(a), _parse_num(b) + d_val = _parse_num(divisor) + inner = f"{a_val} {op} {b_val}" + inner_result = eval(inner, cls._SAFE_EVAL_NS) + op2 = "/" if "/" in line[m.start():m.end()] else "*" + computed = eval( + f"{inner_result} {op2} {d_val}", + cls._SAFE_EVAL_NS, + ) + results.append({ + "expr": f"({inner}) {op2} {d_val}", + "stated": _parse_num(stated), + "computed": float(computed), + "raw": m.group(), + }) + except Exception: + continue + return results + + async def _verify_computation( + self, + query: str, + answer: str, + ) -> Tuple[str, bool]: + """Verify arithmetic in computation-type answers. + + Extracts arithmetic expressions, evaluates them with Python, and + re-prompts the LLM if a discrepancy is detected. + + Returns: + ``(corrected_answer, was_corrected)``. 
+ """ + expressions = self._extract_arithmetic_expressions(answer) + if not expressions: + return answer, False + + discrepancies = [] + for expr_info in expressions: + stated = expr_info["stated"] + computed = expr_info["computed"] + if stated == 0 and computed == 0: + continue + denom = max(abs(stated), abs(computed), 1e-9) + if abs(stated - computed) / denom > self._ARITH_TOLERANCE: + discrepancies.append(expr_info) + + if not discrepancies: + return answer, False + + worst = max( + discrepancies, + key=lambda d: abs(d["stated"] - d["computed"]), + ) + + await self._logger.info( + f"[Phase 4.5:Verify] Arithmetic discrepancy: " + f"{worst['expr']} = {worst['stated']} (stated) vs " + f"{worst['computed']} (computed)" + ) + + try: + from sirchmunk.llm.prompts import COMPUTATION_CORRECTION + + correction_prompt = COMPUTATION_CORRECTION.format( + query=query, + original_answer=answer[:3000], + expression=worst["expr"], + llm_result=worst["stated"], + correct_result=worst["computed"], + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": correction_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + + corrected = resp.content or "" + if corrected and len(corrected) > 100: + await self._logger.info( + "[Phase 4.5:Verify] Correction applied" + ) + return corrected, True + except Exception as exc: + await self._logger.warning( + f"[Phase 4.5:Verify] Correction failed: {exc}" + ) + + return answer, False + @staticmethod def _extract_and_validate_multi_level_keywords( llm_resp: str, @@ -1841,28 +2141,14 @@ async def _search_deep( f"dir_scan_files={len(dir_scan_files)}" ) - # --- Phase 2.5: Parallel tree pre-navigation for top tree hits --- - _pre_nav_evidence: Dict[str, str] = {} + # --- Phase 2.5: Full tree evidence collection for DEEP mode --- + _tree_evidence: Dict[str, str] = {} + _tree_sufficient = False if tree_hits: - _nav_fps = [fp for fp in tree_hits[:self._DEEP_PRE_NAV_MAX_FILES]] - if _nav_fps: - _nav_results = await 
asyncio.gather( - *[self._tree_guided_sample( - fp, query, max_chars=self._FAST_MAX_EVIDENCE_CHARS, - ) for fp in _nav_fps], - return_exceptions=True, - ) - for fp, nav_res in zip(_nav_fps, _nav_results): - if isinstance(nav_res, Exception): - await self._logger.warning( - f"[Phase 2.5] Tree pre-nav failed for {Path(fp).name}: {nav_res}" - ) - elif isinstance(nav_res, str) and nav_res: - _pre_nav_evidence[fp] = nav_res - if _pre_nav_evidence: - await self._logger.info( - f"[Phase 2.5] Pre-navigated {len(_pre_nav_evidence)} tree files" - ) + _tree_evidence, _tree_sufficient = ( + await self._collect_deep_tree_evidence(tree_hits, query) + ) + _pre_nav_evidence = _tree_evidence # ============================================================== # Phase 3: Merge file paths + build KnowledgeCluster @@ -1898,11 +2184,35 @@ async def _search_deep( await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") cluster: Optional[KnowledgeCluster] = None - if merged_files: + if _tree_sufficient and _tree_evidence: + combined_tree_ev = "\n\n---\n\n".join( + f"[Source: {Path(fp).name}]\n{ev}" + for fp, ev in _tree_evidence.items() + ) + cluster = self._make_answer_cluster( + query, combined_tree_ev[:5000], "DTE", + file_paths=list(_tree_evidence.keys()), + ) + cluster.content = combined_tree_ev + await self._logger.info( + f"[Phase 3:DirectTree] Tree evidence sufficient " + f"({len(combined_tree_ev)} chars), bypassing Monte Carlo" + ) + elif merged_files: cluster = await self._build_cluster( query=query, file_paths=merged_files, query_keywords=query_keywords, top_k_files=top_k_files, ) + if _tree_evidence and cluster and cluster.content: + pre_nav_parts = [ + f"[Tree evidence: {Path(fp).name}]\n{ev}" + for fp, ev in _tree_evidence.items() + ] + if pre_nav_parts: + pre_nav_ctx = "\n\n".join(pre_nav_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" # 
============================================================== # Phase 3.5: Graph context enrichment (P5) @@ -1911,39 +2221,64 @@ async def _search_deep( # ============================================================== graph_ctx = "" if cluster: - # Merge pre-navigated tree evidence into cluster content - if _pre_nav_evidence and cluster.content: - pre_nav_parts = [] - for fp, ev in _pre_nav_evidence.items(): - pre_nav_parts.append(f"[Tree evidence: {Path(fp).name}]\n{ev}") - if pre_nav_parts: - pre_nav_ctx = "\n\n".join(pre_nav_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" - graph_ctx = await self._gather_graph_context(cluster) if graph_ctx and cluster.content: if isinstance(cluster.content, list): cluster.content = "\n".join(cluster.content) cluster.content = f"{cluster.content}\n\n{graph_ctx}" + # ============================================================== + # Phase 3.8: Query classification (feeds Phase 3.75 + Phase 4) + # ============================================================== + _query_complexity, _query_intent = await self._classify_query_intent(query) + context.increment_loop() + await self._logger.info( + f"[Phase 3.8] Query: complexity={_query_complexity}, intent={_query_intent}" + ) + + # ============================================================== + # Phase 3.75: Evidence adequacy closed-loop (Plan D) + # For computation/comparison queries, verify required data points + # are present and trigger targeted gap-fill if missing. 
+ # ============================================================== + if ( + cluster and cluster.content + and _query_intent in ("computation", "comparison") + ): + _ev_text = ( + str(cluster.content) if isinstance(cluster.content, str) + else "\n".join(cluster.content) + ) + is_complete, missing = await self._check_evidence_completeness( + query, _query_intent, _ev_text, + ) + context.increment_loop() + if not is_complete and missing: + await self._logger.info( + f"[Phase 3.75] Missing data points: {missing}" + ) + gap_evidence = await self._fill_evidence_gaps( + query, missing, merged_files, artifacts, + ) + if gap_evidence: + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = ( + f"{cluster.content}\n\n" + f"[Gap-fill evidence]\n{gap_evidence}" + ) + await self._logger.info( + f"[Phase 3.75] Filled {len(missing)} gaps " + f"({len(gap_evidence)} chars)" + ) + # ============================================================== # Phase 4: Structured Reasoning → Cluster Summary fallback - # P0: DEEP mode always goes through full reasoning pipeline — - # no fast triage short-circuit. P4: query complexity determines - # whether the heavier section-map SR fires or we go straight to - # cluster synthesis. 
# ============================================================== context.increment_loop() answer = "" should_save = True - _query_complexity = self._classify_query_complexity(query) - await self._logger.info( - f"[Phase 4] Query complexity: {_query_complexity}" - ) - # Attempt structured reasoning for moderate/complex queries _sr_files: List[str] = [] if _query_complexity != "simple": @@ -1960,7 +2295,7 @@ async def _search_deep( f"{len(_sr_files)} tree-indexed files" ) sr_answer, sr_cluster, sr_evidence = await self._deep_structured_reasoning( - query, _sr_files, artifacts, context, + query, _sr_files, artifacts, context, _query_intent, ) if sr_answer: @@ -2001,7 +2336,7 @@ async def _search_deep( "[Phase 4:Fallback] Generating summary from cluster" ) answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) + await self._summarise_cluster(query, cluster, _query_intent) ) cluster_evidence = ( str(cluster.content) if cluster.content else "" @@ -2028,7 +2363,7 @@ async def _search_deep( ) sc_cluster.content = sc_evidence answer, should_save, should_answer = ( - await self._summarise_cluster(query, sc_cluster) + await self._summarise_cluster(query, sc_cluster, _query_intent) ) sc_accepted, _ = self._evaluate_evidence_acceptance( query, sc_evidence, should_answer, @@ -2049,8 +2384,6 @@ async def _search_deep( await self._logger.info( "[Phase 4:Fallback] Launching ReAct refinement" ) - # Seed ReAct with all available prior context so it - # doesn't start from scratch. 
react_parts: List[str] = [] if spec_context: react_parts.append(spec_context) @@ -2085,7 +2418,7 @@ async def _search_deep( if not cluster: return _NO_RESULTS_MESSAGE, None, context answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster) + await self._summarise_cluster(query, cluster, _query_intent) ) final_evidence = ( str(cluster.content) if cluster.content else "" @@ -2109,12 +2442,20 @@ async def _search_deep( ) sc_cluster.content = sc_evidence answer, should_save, _ = ( - await self._summarise_cluster(query, sc_cluster) + await self._summarise_cluster(query, sc_cluster, _query_intent) ) cluster = sc_cluster else: return _NO_RESULTS_MESSAGE, None, context + # ============================================================== + # Phase 4.5: Computation verification (Plan E) + # ============================================================== + if answer and _query_intent == "computation": + answer, was_corrected = await self._verify_computation(query, answer) + if was_corrected: + _, should_save, _ = self._parse_summary_response(answer) + # Sync LLM token accounting into context new_usages = self.llm_usages[_llm_usage_start:] for usage in new_usages: @@ -2537,6 +2878,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Minimum evidence character length for heuristic override.""" _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.5 """Minimum keyword coverage ratio for heuristic override.""" + + # --- Plan A: Evidence channel unification --- + _TREE_EVIDENCE_MIN_DIRECT_CHARS: int = 2000 + """Minimum tree evidence length to bypass Monte Carlo and feed directly to synthesis.""" + _TREE_DIRECT_KW_THRESHOLD: float = 0.3 + """Minimum keyword coverage for tree evidence to qualify for the direct channel.""" _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", "cash", "debt", "equity", "eps", "dpo", "growth", "rate", @@ -4383,6 +4730,65 @@ async def 
_tree_guided_sample( ) return evidence + async def _collect_deep_tree_evidence( + self, + file_paths: List[str], + query: str, + ) -> Tuple[Dict[str, str], bool]: + """Full tree navigation for DEEP mode primary files. + + Runs ``_navigate_tree_for_evidence`` (complement nav, table supplement, + referenced-page gap-fill) on each file, then assesses whether the + aggregated evidence is rich enough to bypass Monte Carlo sampling. + + Returns: + ``(evidence_dict, is_sufficient)`` where *evidence_dict* maps + file_path to raw evidence text, and *is_sufficient* indicates + that the direct-channel can replace ``_build_cluster``. + """ + indexer = self._get_tree_indexer() + if indexer is None: + return {}, False + + nav_fps = [fp for fp in file_paths[:self._DEEP_PRE_NAV_MAX_FILES] + if indexer.has_tree(fp)] + if not nav_fps: + return {}, False + + results = await asyncio.gather( + *[self._navigate_tree_for_evidence(fp, query) for fp in nav_fps], + return_exceptions=True, + ) + + evidence_dict: Dict[str, str] = {} + for fp, res in zip(nav_fps, results): + if isinstance(res, Exception): + await self._logger.warning( + f"[Phase 2.5:DirectTree] Navigation failed for " + f"{Path(fp).name}: {res}" + ) + elif isinstance(res, str) and res.strip(): + evidence_dict[fp] = res + + if not evidence_dict: + return {}, False + + combined = "\n\n".join(evidence_dict.values()) + total_len = len(combined) + kw_coverage = self._compute_keyword_coverage(query, combined) + + is_sufficient = ( + total_len >= self._TREE_EVIDENCE_MIN_DIRECT_CHARS + and kw_coverage >= self._TREE_DIRECT_KW_THRESHOLD + ) + + await self._logger.info( + f"[Phase 2.5:DirectTree] {len(evidence_dict)} files, " + f"{total_len} chars, kw_cov={kw_coverage:.2f}, " + f"sufficient={is_sufficient}" + ) + return evidence_dict, is_sufficient + @classmethod def _classify_leaves(cls, leaves: list) -> Tuple[List[tuple], List, List]: """Classify leaf nodes by preferred extraction strategy. 
@@ -6315,13 +6721,45 @@ async def _gather_graph_context(self, cluster: KnowledgeCluster) -> str: # Phase 4: Answer generation # ------------------------------------------------------------------ + _INTENT_PROMPT_MAP = { + "lookup": ROI_LOOKUP_SYNTHESIS, + "computation": ROI_COMPUTATION_SYNTHESIS, + "comparison": ROI_COMPARISON_SYNTHESIS, + } + + @classmethod + def _select_synthesis_prompt( + cls, + query: str, + evidence: str, + intent: str = "", + *, + document_context: Optional[str] = None, + ) -> str: + """Select and format the synthesis prompt based on query intent. + + Falls back to ``ROI_RESULT_SUMMARY`` for unknown intents or when + the caller passes no intent (FAST mode compatibility). + """ + template = cls._INTENT_PROMPT_MAP.get(intent, ROI_RESULT_SUMMARY) + + prompt = template.format(user_input=query, text_content=evidence) + + if document_context: + prompt = ( + f"{prompt}\n\n### Document Context\n{document_context}" + ) + return prompt + async def _summarise_cluster( self, query: str, cluster: KnowledgeCluster, + intent: str = "", ) -> Tuple[str, bool, bool]: """Generate a final answer summary from a KnowledgeCluster. - Uses ``ROI_RESULT_SUMMARY`` (with precision / best-effort constraints) - for both FAST and DEEP modes, ensuring consistent answer quality. + When *intent* is provided, selects a specialised synthesis prompt + (lookup / computation / comparison). Falls back to the general + ``ROI_RESULT_SUMMARY`` for FAST mode or unknown intents. 
Returns: ``(summary_text, should_save, should_answer)`` where: @@ -6335,9 +6773,8 @@ async def _summarise_cluster( f"{cluster.content if isinstance(cluster.content, str) else sep.join(cluster.content)}" ) - result_sum_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=cluster_text_content, + result_sum_prompt = self._select_synthesis_prompt( + query, cluster_text_content, intent, ) await self._logger.info("[Phase 4] Generating search result summary...") @@ -6580,6 +7017,7 @@ async def _deep_structured_reasoning( tree_files: List[str], artifacts: Any, context: "SearchContext", + intent: str = "", ) -> Tuple[str, Optional["KnowledgeCluster"], str]: """Orchestrate the Deep Structured Reasoning pipeline. @@ -6587,7 +7025,7 @@ async def _deep_structured_reasoning( 1. Section map — build from tree index top layers (no LLM) 2. Section select — LLM picks relevant sections (1 LLM) 3. Targeted extraction — pull pages + tables for sections (no LLM) - 4. Synthesis — ROI_RESULT_SUMMARY on targeted evidence (1 LLM) + 4. Synthesis — intent-aware prompt on targeted evidence (1 LLM) 5. 
Recovery — if refused, expand sections and re-synthesize Returns ``(raw_llm_output, cluster, combined_evidence)`` where @@ -6658,19 +7096,11 @@ async def _deep_structured_reasoning( if ctx_parts: doc_context = "\n".join(ctx_parts) - # Synthesize answer using the unified ROI prompt - if doc_context: - from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT - synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( - user_input=query, - text_content=combined_evidence, - document_context=doc_context, - ) - else: - synth_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=combined_evidence, - ) + # Synthesize answer using intent-aware prompt + synth_prompt = self._select_synthesis_prompt( + query, combined_evidence, intent, + document_context=doc_context, + ) resp = await self.llm.achat( messages=[{"role": "user", "content": synth_prompt}], @@ -6722,21 +7152,12 @@ async def _deep_structured_reasoning( if not found_new: break combined_evidence = "\n\n---\n\n".join(expanded_parts) - if doc_context: - synth_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( - user_input=query, - text_content=combined_evidence[ - : self._DEEP_STRUCTURED_MAX_CHARS - ], - document_context=doc_context, - ) - else: - synth_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=combined_evidence[ - : self._DEEP_STRUCTURED_MAX_CHARS - ], - ) + synth_prompt = self._select_synthesis_prompt( + query, + combined_evidence[:self._DEEP_STRUCTURED_MAX_CHARS], + intent, + document_context=doc_context, + ) resp = await self.llm.achat( messages=[{"role": "user", "content": synth_prompt}], stream=True, From d4a636605ee5e30fa3c68084efa5a8f0ca3d37d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 17 May 2026 14:06:12 +0800 Subject: [PATCH 65/70] fix pipeline deep --- src/sirchmunk/search.py | 149 ++++++++++++++++++++++++---------------- 1 file changed, 88 insertions(+), 61 deletions(-) diff --git a/src/sirchmunk/search.py 
b/src/sirchmunk/search.py index 52c431b..280fe71 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1314,22 +1314,35 @@ async def _fill_evidence_gaps( missing: List[str], file_paths: List[str], artifacts: Any, + *, + scope: Optional["_PathScope"] = None, + nav_cache: Optional[Dict[str, str]] = None, ) -> Optional[str]: """Targeted evidence retrieval for identified gaps. Constructs focused sub-queries from *missing* descriptions and re-navigates tree indices or falls back to keyword retrieval. + When *scope* is provided, extra files drawn from + ``artifacts.tree_available_paths`` are filtered to the scope. + When *nav_cache* is provided, navigation results are cached to + avoid duplicate LLM calls across phases. + Returns supplementary evidence text, or None. """ sub_query = f"{query} — specifically: {'; '.join(missing)}" parts: List[str] = [] + async def _navigate(fp: str, q: str) -> Optional[str]: + if nav_cache is not None: + return await self._cached_navigate_tree(fp, q, nav_cache) + return await self._navigate_tree_for_evidence(fp, q) + indexer = self._get_tree_indexer() for fp in file_paths[:3]: try: if indexer and indexer.has_tree(fp): - ev = await self._navigate_tree_for_evidence(fp, sub_query) + ev = await _navigate(fp, sub_query) if ev and len(ev.strip()) > 100: parts.append( f"[Gap-fill: {Path(fp).name}]\n{ev}" @@ -1345,10 +1358,11 @@ async def _fill_evidence_gaps( extra_fps = [ fp for fp in artifacts.tree_available_paths if fp not in file_paths + and (not scope or scope.contains(fp)) ][:2] for fp in extra_fps: try: - ev = await self._navigate_tree_for_evidence(fp, sub_query) + ev = await _navigate(fp, sub_query) if ev and len(ev.strip()) > 100: parts.append( f"[Gap-fill extra: {Path(fp).name}]\n{ev}" @@ -2001,6 +2015,7 @@ async def _search_deep( # --- Adaptive compile artifact detection (shared with FAST) --- _scope = _PathScope(paths) + _nav_cache: Dict[str, str] = {} artifacts = self._detect_compile_artifacts(paths) # 
============================================================== @@ -2143,10 +2158,9 @@ async def _search_deep( # --- Phase 2.5: Full tree evidence collection for DEEP mode --- _tree_evidence: Dict[str, str] = {} - _tree_sufficient = False if tree_hits: - _tree_evidence, _tree_sufficient = ( - await self._collect_deep_tree_evidence(tree_hits, query) + _tree_evidence = await self._collect_deep_tree_evidence( + tree_hits, query, scope=_scope, nav_cache=_nav_cache, ) _pre_nav_evidence = _tree_evidence @@ -2184,21 +2198,7 @@ async def _search_deep( await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") cluster: Optional[KnowledgeCluster] = None - if _tree_sufficient and _tree_evidence: - combined_tree_ev = "\n\n---\n\n".join( - f"[Source: {Path(fp).name}]\n{ev}" - for fp, ev in _tree_evidence.items() - ) - cluster = self._make_answer_cluster( - query, combined_tree_ev[:5000], "DTE", - file_paths=list(_tree_evidence.keys()), - ) - cluster.content = combined_tree_ev - await self._logger.info( - f"[Phase 3:DirectTree] Tree evidence sufficient " - f"({len(combined_tree_ev)} chars), bypassing Monte Carlo" - ) - elif merged_files: + if merged_files: cluster = await self._build_cluster( query=query, file_paths=merged_files, query_keywords=query_keywords, top_k_files=top_k_files, @@ -2259,6 +2259,7 @@ async def _search_deep( ) gap_evidence = await self._fill_evidence_gaps( query, missing, merged_files, artifacts, + scope=_scope, nav_cache=_nav_cache, ) if gap_evidence: if isinstance(cluster.content, list): @@ -2283,11 +2284,13 @@ async def _search_deep( _sr_files: List[str] = [] if _query_complexity != "simple": if tree_hits: - _sr_files = list(tree_hits[: self._DEEP_STRUCTURED_MAX_FILES]) + _scoped_hits = [fp for fp in tree_hits if _scope.contains(fp)] + _sr_files = _scoped_hits[: self._DEEP_STRUCTURED_MAX_FILES] elif artifacts and artifacts.tree_available_paths: - _sr_files = list(artifacts.tree_available_paths)[ - : 
self._DEEP_STRUCTURED_MAX_FILES - ] + _sr_files = [ + fp for fp in artifacts.tree_available_paths + if _scope.contains(fp) + ][: self._DEEP_STRUCTURED_MAX_FILES] if _sr_files: await self._logger.info( @@ -2521,6 +2524,16 @@ async def _try_direct_doc_analysis( if operation is None: return None + # Computation/comparison queries need the full evidence pipeline + if re.search( + r'\b(?:ratio|margin|growth.?rate|turnover|coverage' + r'|what is (?:the )?fy\d|calculate|compute' + r'|improv(?:ing|ed)|declin(?:ing|ed)' + r'|which .{0,30}(?:best|worst|most|least|highest|lowest))\b', + query, re.IGNORECASE, + ): + return None + filenames = ", ".join(Path(d.path).name for d in doc_files) await self._logger.info( f"[DocQA] Intent '{operation}' detected — " @@ -2879,11 +2892,6 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.5 """Minimum keyword coverage ratio for heuristic override.""" - # --- Plan A: Evidence channel unification --- - _TREE_EVIDENCE_MIN_DIRECT_CHARS: int = 2000 - """Minimum tree evidence length to bypass Monte Carlo and feed directly to synthesis.""" - _TREE_DIRECT_KW_THRESHOLD: float = 0.3 - """Minimum keyword coverage for tree evidence to qualify for the direct channel.""" _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", "cash", "debt", "equity", "eps", "dpo", "growth", "rate", @@ -4730,35 +4738,65 @@ async def _tree_guided_sample( ) return evidence + async def _cached_navigate_tree( + self, + file_path: str, + query: str, + nav_cache: Dict[str, str], + ) -> Optional[str]: + """``_navigate_tree_for_evidence`` with per-query dedup cache.""" + cache_key = f"{file_path}::{query}" + if cache_key in nav_cache: + return nav_cache[cache_key] + result = await self._navigate_tree_for_evidence(file_path, query) + if isinstance(result, str) and result.strip(): + nav_cache[cache_key] = result + return result + async def 
_collect_deep_tree_evidence( self, file_paths: List[str], query: str, - ) -> Tuple[Dict[str, str], bool]: + *, + scope: Optional["_PathScope"] = None, + nav_cache: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: """Full tree navigation for DEEP mode primary files. Runs ``_navigate_tree_for_evidence`` (complement nav, table supplement, - referenced-page gap-fill) on each file, then assesses whether the - aggregated evidence is rich enough to bypass Monte Carlo sampling. - - Returns: - ``(evidence_dict, is_sufficient)`` where *evidence_dict* maps - file_path to raw evidence text, and *is_sufficient* indicates - that the direct-channel can replace ``_build_cluster``. + referenced-page gap-fill) on each file. Returns a dict mapping + file paths to raw evidence text. The evidence is used to + **supplement** (not replace) Monte Carlo sampling. + + When *scope* is provided, only files within the search path scope + are navigated — prevents cross-document evidence contamination. + When *nav_cache* is provided, results are cached to avoid + duplicate navigation across pipeline phases. 
""" indexer = self._get_tree_indexer() if indexer is None: - return {}, False + return {} + + if scope: + file_paths = [fp for fp in file_paths if scope.contains(fp)] + if not file_paths: + return {} nav_fps = [fp for fp in file_paths[:self._DEEP_PRE_NAV_MAX_FILES] if indexer.has_tree(fp)] if not nav_fps: - return {}, False + return {} - results = await asyncio.gather( - *[self._navigate_tree_for_evidence(fp, query) for fp in nav_fps], - return_exceptions=True, - ) + if nav_cache is not None: + results = await asyncio.gather( + *[self._cached_navigate_tree(fp, query, nav_cache) for fp in nav_fps], + return_exceptions=True, + ) + else: + results = await asyncio.gather( + *[self._navigate_tree_for_evidence(fp, query) for fp in nav_fps], + return_exceptions=True, + ) evidence_dict: Dict[str, str] = {} for fp, res in zip(nav_fps, results): @@ -4770,24 +4808,13 @@ async def _collect_deep_tree_evidence( elif isinstance(res, str) and res.strip(): evidence_dict[fp] = res - if not evidence_dict: - return {}, False - - combined = "\n\n".join(evidence_dict.values()) - total_len = len(combined) - kw_coverage = self._compute_keyword_coverage(query, combined) - - is_sufficient = ( - total_len >= self._TREE_EVIDENCE_MIN_DIRECT_CHARS - and kw_coverage >= self._TREE_DIRECT_KW_THRESHOLD - ) - - await self._logger.info( - f"[Phase 2.5:DirectTree] {len(evidence_dict)} files, " - f"{total_len} chars, kw_cov={kw_coverage:.2f}, " - f"sufficient={is_sufficient}" - ) - return evidence_dict, is_sufficient + if evidence_dict: + total_len = sum(len(v) for v in evidence_dict.values()) + await self._logger.info( + f"[Phase 2.5:DirectTree] {len(evidence_dict)} files, " + f"{total_len} chars" + ) + return evidence_dict @classmethod def _classify_leaves(cls, leaves: list) -> Tuple[List[tuple], List, List]: From 436d8889462bf2a7ae2628667057b2938934ba8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 17 May 2026 20:41:24 +0800 Subject: [PATCH 66/70] refactor deep mode for 
tree indexing loop --- src/sirchmunk/llm/prompts.py | 91 ++++ src/sirchmunk/search.py | 882 ++++++++++++++++++++++++----------- 2 files changed, 688 insertions(+), 285 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 2d893fe..e5b3238 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -705,6 +705,97 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: true """ +# --------------------------------------------------------------------------- +# Agentic retrieval prompts (DEEP mode) +# --------------------------------------------------------------------------- + +DEEP_DATA_REQUIREMENTS = """Given the user's question, identify the specific data points needed to answer it. + +### Question +{query} + +### Question Type +{intent} + +### Instructions +1. List each specific data point needed to answer this question (e.g., "Total Revenue for FY2022", "Accounts Payable as of fiscal year end 2019"). +2. For each data point, identify the likely document section type where it would appear (e.g., "Income Statement", "Balance Sheet", "Cash Flow Statement", "Notes to Financial Statements", "Management Discussion and Analysis", "Segment Information"). +3. If a calculation is required, state the exact formula. +4. Identify the time period(s) required. + +Return ONLY valid JSON on a single line: +{{"data_points": ["data point 1", "data point 2"], "likely_sources": ["section type 1", "section type 2"], "formula": "formula or null", "time_period": "period or null"}} +""" + +DEEP_PAGE_SELECT = """You are locating specific data in a document. Select pages to fetch. + +### Question +{query} + +### Data Still Needed +{data_requirements} + +### Document Outline (with page ranges) +{section_map} + +### Pages Already Fetched +{fetched_pages} + +### Instructions +- Reason about which sections contain the needed data based on section titles, summaries, and page ranges. 
+- Financial statements (Income Statement, Balance Sheet, Cash Flow Statement) typically contain quantitative data needed for calculations. +- Sections with tables are often high-value for data extraction. +- Do NOT re-select pages listed in "Pages Already Fetched". +- Select 3-8 pages that are most likely to contain the missing data. +- When uncertain, prefer sections deeper in the document (financial statements are usually after narrative sections). + +Return ONLY a JSON array of page numbers to fetch: [45, 46, 52, 53] +""" + +DEEP_CHECK_REQUIREMENTS = """Check whether the evidence contains all required data points. + +### Question +{query} + +### Required Data Points +{data_points} + +### Formula (if applicable) +{formula} + +### Evidence +{evidence} + +### Instructions +For each required data point, check if its actual numeric or factual value appears in the evidence. A data point is FOUND only if you can identify its specific value in the text. + +Return ONLY valid JSON: +{{"complete": true, "found": [{{"point": "description", "value": "extracted value"}}], "missing": []}} +or +{{"complete": false, "found": [{{"point": "description", "value": "extracted value"}}], "missing": ["description of missing data point"]}} +""" + +DEEP_TOC_ANALYSIS = """Analyze the following pages from the beginning of a document and extract its structural outline. + +### Document Pages +{toc_page_text} + +### Total Document Pages +{total_pages} + +### Instructions +1. Look for a table of contents, section listing, or structural overview. +2. Extract every section entry with its title, starting page number, and hierarchy level. +3. Infer page_end from the start of the next section (use {total_pages} for the last section). +4. If page numbers appear as dot leaders (e.g. "Item 7. MD&A ........ 45"), extract the page number. +5. If no structural information can be extracted, return an empty array. 
+ +Return ONLY valid JSON — an array of section objects: +[{{"title": "Section Title", "page_start": 3, "page_end": 15, "level": 1}}, ...] + +If no structure found, return: [] +""" + # --------------------------------------------------------------------------- # Knowledge Compile prompts # --------------------------------------------------------------------------- diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 280fe71..8384b39 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -30,6 +30,10 @@ DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, DEEP_SECTION_SELECT, + DEEP_DATA_REQUIREMENTS, + DEEP_PAGE_SELECT, + DEEP_CHECK_REQUIREMENTS, + DEEP_TOC_ANALYSIS, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -203,6 +207,27 @@ class CompileArtifacts: summary_index: Optional[Any] = None # CompileSummaryIndex (lazy-loaded) +@dataclass +class DataRequirements: + """Pre-retrieval analysis of what data points a query needs.""" + + data_points: List[str] + likely_sources: List[str] + formula: Optional[str] + time_period: Optional[str] + intent: str + + +@dataclass +class RetrievalResult: + """Output of the agentic retrieval loop.""" + + evidence: str + pages_extracted: Dict[str, List[int]] + is_complete: bool + rounds_used: int + + class _TreeNavCache: """Per-search-session cache for tree navigation results. 
@@ -1191,9 +1216,8 @@ async def _classify_query_intent( self.llm_usages.append(resp.usage) raw = (resp.content or "").strip() - match = re.search(r'\{[^}]+\}', raw) - if match: - data = json.loads(match.group()) + data = self._extract_json_object(raw) + if data: complexity = data.get("complexity", "").lower() intent = data.get("intent", "").lower() if (complexity in self._VALID_COMPLEXITIES @@ -1209,11 +1233,37 @@ async def _classify_query_intent( intent = "computation" if complexity != "simple" else "lookup" return complexity, intent + @staticmethod + def _extract_json_object(raw: str) -> Optional[dict]: + """Extract the outermost JSON object from LLM response text.""" + start = raw.find("{") + end = raw.rfind("}") + if start >= 0 and end > start: + try: + return json.loads(raw[start : end + 1]) + except (json.JSONDecodeError, TypeError): + pass + return None + + @staticmethod + def _extract_json_array(raw: str) -> Optional[list]: + """Extract the outermost JSON array from LLM response text.""" + start = raw.find("[") + end = raw.rfind("]") + if start >= 0 and end > start: + try: + return json.loads(raw[start : end + 1]) + except (json.JSONDecodeError, TypeError): + pass + return None + @staticmethod def _evaluate_evidence_acceptance( query: str, evidence: str, llm_should_answer: bool, + *, + retrieval_complete: bool = False, ) -> Tuple[bool, str]: """Multi-factor decision on whether to accept retrieved evidence. @@ -1226,6 +1276,10 @@ def _evaluate_evidence_acceptance( boolean decision and *reason* is a human-readable string documenting which factor(s) determined the outcome. 
""" + # Factor 0: Agentic retrieval confirmed data completeness + if retrieval_complete: + return True, "retrieval_complete" + # Factor 1: LLM direct acceptance if llm_should_answer: return True, "llm_accepted" @@ -2015,7 +2069,6 @@ async def _search_deep( # --- Adaptive compile artifact detection (shared with FAST) --- _scope = _PathScope(paths) - _nav_cache: Dict[str, str] = {} artifacts = self._detect_compile_artifacts(paths) # ============================================================== @@ -2156,305 +2209,61 @@ async def _search_deep( f"dir_scan_files={len(dir_scan_files)}" ) - # --- Phase 2.5: Full tree evidence collection for DEEP mode --- - _tree_evidence: Dict[str, str] = {} - if tree_hits: - _tree_evidence = await self._collect_deep_tree_evidence( - tree_hits, query, scope=_scope, nav_cache=_nav_cache, - ) - _pre_nav_evidence = _tree_evidence - # ============================================================== - # Phase 3: Merge file paths + build KnowledgeCluster - # P1 tree hits get highest priority; P2 soft-hit files next + # Phase 3: Query analysis + file selection # ============================================================== context.increment_loop() + _query_complexity, _query_intent = await self._classify_query_intent(query) + data_reqs = await self._analyze_data_requirements(query, _query_intent) + context.increment_loop() + + await self._logger.info( + f"[Phase 3] Query: complexity={_query_complexity}, " + f"intent={_query_intent}, " + f"data_points={len(data_reqs.data_points)}, " + f"formula={data_reqs.formula or 'N/A'}" + ) + extra_knowledge_files = knowledge_probe.file_paths if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files - if _PURE_TREE_SEARCH: - # Pure tree search: only use tree hits (+ soft-hit fallback if no tree hits) - pure_tree_files = list(tree_hits) - if not pure_tree_files and soft_hit: - pure_tree_files = soft_hit.file_paths - await self._logger.info( - f"[Phase 3:PureTree] No tree hits, using 
{len(pure_tree_files)} soft-hit files" - ) - merged_files = self._merge_file_paths( - keyword_files=pure_tree_files, - dir_scan_files=[], - knowledge_hits=[], - ) - await self._logger.info( - f"[Phase 3:PureTree] Merged {len(merged_files)} tree-only candidate files" - ) - else: - merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, - dir_scan_files=dir_scan_files, - knowledge_hits=extra_knowledge_files, - ) - await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") + merged_files = self._merge_file_paths( + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, + dir_scan_files=dir_scan_files, + knowledge_hits=extra_knowledge_files, + ) + target_files = self._select_target_files(merged_files, _scope, artifacts) - cluster: Optional[KnowledgeCluster] = None - if merged_files: - cluster = await self._build_cluster( - query=query, file_paths=merged_files, - query_keywords=query_keywords, top_k_files=top_k_files, - ) - if _tree_evidence and cluster and cluster.content: - pre_nav_parts = [ - f"[Tree evidence: {Path(fp).name}]\n{ev}" - for fp, ev in _tree_evidence.items() - ] - if pre_nav_parts: - pre_nav_ctx = "\n\n".join(pre_nav_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" + await self._logger.info( + f"[Phase 3] Merged {len(merged_files)} files, " + f"target {len(target_files)} for agentic retrieval" + ) # ============================================================== - # Phase 3.5: Graph context enrichment (P5) - # Append related knowledge from graph neighbours to cluster content - # so the answer-generation LLM has richer context. 
+ # Phase 4: Agentic retrieval loop # ============================================================== - graph_ctx = "" - if cluster: - graph_ctx = await self._gather_graph_context(cluster) - if graph_ctx and cluster.content: - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = f"{cluster.content}\n\n{graph_ctx}" + retrieval = await self._agentic_retrieve( + query, data_reqs, target_files, context, + ) - # ============================================================== - # Phase 3.8: Query classification (feeds Phase 3.75 + Phase 4) - # ============================================================== - _query_complexity, _query_intent = await self._classify_query_intent(query) - context.increment_loop() await self._logger.info( - f"[Phase 3.8] Query: complexity={_query_complexity}, intent={_query_intent}" + f"[Phase 4] Retrieval: {retrieval.rounds_used} rounds, " + f"complete={retrieval.is_complete}, " + f"{sum(len(ps) for ps in retrieval.pages_extracted.values())} pages" ) # ============================================================== - # Phase 3.75: Evidence adequacy closed-loop (Plan D) - # For computation/comparison queries, verify required data points - # are present and trigger targeted gap-fill if missing. 
+ # Phase 4.5: Synthesis # ============================================================== - if ( - cluster and cluster.content - and _query_intent in ("computation", "comparison") - ): - _ev_text = ( - str(cluster.content) if isinstance(cluster.content, str) - else "\n".join(cluster.content) - ) - is_complete, missing = await self._check_evidence_completeness( - query, _query_intent, _ev_text, - ) - context.increment_loop() - if not is_complete and missing: - await self._logger.info( - f"[Phase 3.75] Missing data points: {missing}" - ) - gap_evidence = await self._fill_evidence_gaps( - query, missing, merged_files, artifacts, - scope=_scope, nav_cache=_nav_cache, - ) - if gap_evidence: - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = ( - f"{cluster.content}\n\n" - f"[Gap-fill evidence]\n{gap_evidence}" - ) - await self._logger.info( - f"[Phase 3.75] Filled {len(missing)} gaps " - f"({len(gap_evidence)} chars)" - ) - - # ============================================================== - # Phase 4: Structured Reasoning → Cluster Summary fallback - # ============================================================== - context.increment_loop() - answer = "" - should_save = True - - # Attempt structured reasoning for moderate/complex queries - _sr_files: List[str] = [] - if _query_complexity != "simple": - if tree_hits: - _scoped_hits = [fp for fp in tree_hits if _scope.contains(fp)] - _sr_files = _scoped_hits[: self._DEEP_STRUCTURED_MAX_FILES] - elif artifacts and artifacts.tree_available_paths: - _sr_files = [ - fp for fp in artifacts.tree_available_paths - if _scope.contains(fp) - ][: self._DEEP_STRUCTURED_MAX_FILES] - - if _sr_files: - await self._logger.info( - f"[Phase 4] Launching structured reasoning for " - f"{len(_sr_files)} tree-indexed files" - ) - sr_answer, sr_cluster, sr_evidence = await self._deep_structured_reasoning( - query, _sr_files, artifacts, context, _query_intent, - ) - - if sr_answer: - 
answer, should_save, should_answer = self._parse_summary_response( - sr_answer - ) - accepted, accept_reason = self._evaluate_evidence_acceptance( - query, sr_evidence or sr_answer, should_answer, - ) - await self._logger.info( - f"[Phase 4] Structured reasoning: " - f"accepted={accepted} ({accept_reason})" - ) - if accepted: - cluster = sr_cluster or cluster - else: - answer = "" - - # Fallback: cluster summary with ROI prompt or ReAct - if not answer: - if artifacts and artifacts.catalog_map and cluster and cluster.content: - _catalog_ctx_parts = [] - for fp in (cluster.search_results or merged_files)[:3]: - ctx = self._build_answer_context(fp, artifacts) - if ctx: - _catalog_ctx_parts.append(ctx) - if _catalog_ctx_parts: - _catalog_context = "\n".join(_catalog_ctx_parts) - if isinstance(cluster.content, list): - cluster.content = "\n".join(cluster.content) - cluster.content = ( - f"{cluster.content}\n\n" - f"[Document Context]\n{_catalog_context}" - ) - - if cluster and cluster.content: - await self._logger.info( - "[Phase 4:Fallback] Generating summary from cluster" - ) - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster, _query_intent) - ) - cluster_evidence = ( - str(cluster.content) if cluster.content else "" - ) - accepted, accept_reason = ( - self._evaluate_evidence_acceptance( - query, cluster_evidence, should_answer, - ) - ) - if not accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - # DEEP self-correction before giving up - sc_evidence = await self._deep_self_correct( - query, merged_files, query_keywords, context, - ) - if sc_evidence: - sc_cluster = self._make_answer_cluster( - query, sc_evidence[:5000], "DSC", - file_paths=list(merged_files)[:3], - ) - sc_cluster.content = sc_evidence - answer, should_save, should_answer = ( - await self._summarise_cluster(query, sc_cluster, _query_intent) - ) - sc_accepted, _ = self._evaluate_evidence_acceptance( - 
query, sc_evidence, should_answer, - ) - if sc_accepted: - cluster = sc_cluster - else: - return _NO_RESULTS_MESSAGE, None, context - else: - return _NO_RESULTS_MESSAGE, None, context - if not cluster.search_results: - cluster.search_results = list(merged_files) - elif llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - await self._logger.info( - "[Phase 4:Fallback] Launching ReAct refinement" - ) - react_parts: List[str] = [] - if spec_context: - react_parts.append(spec_context) - if graph_ctx: - react_parts.append(graph_ctx) - if _pre_nav_evidence: - nav_seed = "\n\n".join( - f"[Pre-navigated: {Path(fp).name}]\n{ev}" - for fp, ev in _pre_nav_evidence.items() - ) - react_parts.append(nav_seed) - react_spec = "\n\n".join(react_parts) - react_answer, context = await self._react_refinement( - query=query, paths=paths, - initial_keywords=initial_keywords, - spec_context=react_spec, - enable_dir_scan=enable_dir_scan, - max_loops=max_loops, - max_token_budget=max_token_budget, - max_depth=max_depth, - include=include, exclude=exclude, - ) - if not cluster: - cluster = await self._build_cluster_from_context( - query=query, answer=react_answer, - context=context, - query_keywords=query_keywords, - top_k_files=top_k_files, - ) - elif react_answer and not cluster.content: - cluster.content = react_answer - if not cluster: - return _NO_RESULTS_MESSAGE, None, context - answer, should_save, should_answer = ( - await self._summarise_cluster(query, cluster, _query_intent) - ) - final_evidence = ( - str(cluster.content) if cluster.content else "" - ) - final_accepted, _ = self._evaluate_evidence_acceptance( - query, final_evidence, should_answer, - ) - if not final_accepted: - if llm_fallback: - answer, should_save = ( - await self._summarise_cluster_fallback(query) - ) - else: - sc_evidence = await self._deep_self_correct( - query, merged_files, query_keywords, context, - ) - if sc_evidence: - sc_cluster = 
self._make_answer_cluster( - query, sc_evidence[:5000], "DSC", - file_paths=list(merged_files)[:3], - ) - sc_cluster.content = sc_evidence - answer, should_save, _ = ( - await self._summarise_cluster(query, sc_cluster, _query_intent) - ) - cluster = sc_cluster - else: - return _NO_RESULTS_MESSAGE, None, context + answer, should_save, cluster = await self._synthesize_from_retrieval( + query, _query_intent, retrieval, merged_files, + ) # ============================================================== - # Phase 4.5: Computation verification (Plan E) + # Phase 4.75: Computation verification # ============================================================== - if answer and _query_intent == "computation": + if answer and answer != _NO_RESULTS_MESSAGE and _query_intent == "computation": answer, was_corrected = await self._verify_computation(query, answer) if was_corrected: _, should_save, _ = self._parse_summary_response(answer) @@ -2886,10 +2695,24 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _DEEP_STRUCTURED_MAX_FILES: int = 3 """Maximum files to process through structured reasoning pipeline.""" + # --- Agentic retrieval --- + _AGENTIC_MAX_ROUNDS: int = 3 + """Maximum retrieval rounds in the agentic loop.""" + _AGENTIC_MAX_PAGES_PER_ROUND: int = 8 + """Maximum new pages to extract per round per file.""" + _AGENTIC_MAX_TOTAL_PAGES: int = 20 + """Maximum total pages across all rounds.""" + _AGENTIC_MAX_FILES: int = 3 + """Maximum files to process through agentic retrieval.""" + _AGENTIC_SECTION_MAP_DEPTH: int = 8 + """Section map depth for agentic page selection.""" + _AGENTIC_EVIDENCE_MAX_CHARS: int = 40_000 + """Maximum evidence characters to feed to synthesis prompt.""" + # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" - _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.5 + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.3 """Minimum keyword 
coverage ratio for heuristic override.""" _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ @@ -6858,7 +6681,496 @@ async def _summarise_fast_fallback( return answer, False # Never save fallback answers # ------------------------------------------------------------------ - # Deep Structured Reasoning pipeline + # Agentic retrieval pipeline (DEEP mode) + # ------------------------------------------------------------------ + + async def _analyze_data_requirements( + self, query: str, intent: str, + ) -> DataRequirements: + """Identify what data points the query needs before any retrieval.""" + try: + prompt = DEEP_DATA_REQUIREMENTS.format(query=query, intent=intent) + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + data = self._extract_json_object(raw) + if data: + return DataRequirements( + data_points=data.get("data_points", [query]), + likely_sources=data.get("likely_sources", []), + formula=data.get("formula"), + time_period=data.get("time_period"), + intent=intent, + ) + except Exception as exc: + await self._logger.warning( + f"[Phase 3] Data requirements analysis failed: {exc}" + ) + return DataRequirements( + data_points=[query], likely_sources=[], formula=None, + time_period=None, intent=intent, + ) + + def _select_target_files( + self, + merged_files: List[str], + scope: "_PathScope", + artifacts: Optional["CompileArtifacts"], + ) -> List[str]: + """Select top files for agentic retrieval, preferring tree-indexed ones.""" + scoped = [fp for fp in merged_files if scope.contains(fp)] + if not scoped: + scoped = list(merged_files) + + tree_paths = ( + artifacts.tree_available_paths if artifacts else set() + ) + with_tree = [fp for fp in scoped if fp in tree_paths] + without_tree = [fp for fp in scoped if fp not in tree_paths] + ranked = with_tree + without_tree + return ranked[: self._AGENTIC_MAX_FILES] + + async def _select_pages_for_data( + 
self, + query: str, + data_reqs: DataRequirements, + section_map: str, + evidence_so_far: str, + fetched_pages: set, + sections_meta: Optional[List[Dict[str, Any]]] = None, + total_pages: Optional[int] = None, + ) -> set: + """LLM-driven page selection given document outline and data needs.""" + reqs_str = "\n".join( + f"- {dp}" for dp in data_reqs.data_points + ) + if data_reqs.formula: + reqs_str += f"\nFormula: {data_reqs.formula}" + + fetched_str = ( + ", ".join(str(p) for p in sorted(fetched_pages)) + if fetched_pages else "None" + ) + + prompt = DEEP_PAGE_SELECT.format( + query=query, + data_requirements=reqs_str, + section_map=section_map, + fetched_pages=fetched_str, + ) + try: + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + match = re.search(r"\[[\d\s,]+\]", raw) + if match: + pages = json.loads(match.group()) + result = { + int(p) for p in pages + if isinstance(p, (int, float)) and int(p) > 0 + } + if result: + return result + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Page selection failed: {exc}" + ) + + return self._fallback_page_selection( + data_reqs, sections_meta, fetched_pages, total_pages, + ) + + @staticmethod + def _fallback_page_selection( + data_reqs: DataRequirements, + sections_meta: Optional[List[Dict[str, Any]]], + fetched_pages: set, + total_pages: Optional[int], + ) -> set: + """Heuristic page selection when LLM fails or returns empty.""" + candidates: set = set() + + if sections_meta: + source_keywords = { + s.lower() for s in data_reqs.likely_sources + } + for sec in sections_meta: + title_lower = (sec.get("title") or "").lower() + pr = sec.get("page_range") + if not pr or not pr[0]: + continue + if any(kw in title_lower for kw in source_keywords): + start, end = int(pr[0]), int(pr[1]) + for p in range(start, min(start + 4, end + 1)): + candidates.add(p) + + if not candidates and total_pages 
and total_pages > 10: + mid = total_pages // 2 + last_quarter = total_pages * 3 // 4 + for p in range(mid, min(mid + 4, total_pages + 1)): + candidates.add(p) + for p in range(last_quarter, min(last_quarter + 4, total_pages + 1)): + candidates.add(p) + + return candidates - fetched_pages + + async def _check_data_requirements( + self, + query: str, + data_reqs: DataRequirements, + evidence: str, + ) -> Tuple[bool, List[str]]: + """Check if evidence satisfies all data requirements. + + Returns ``(is_complete, missing_data_points)``. + """ + try: + prompt = DEEP_CHECK_REQUIREMENTS.format( + query=query, + data_points="\n".join(f"- {dp}" for dp in data_reqs.data_points), + formula=data_reqs.formula or "N/A", + evidence=evidence[:self._AGENTIC_EVIDENCE_MAX_CHARS], + ) + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + json_start = raw.find("{") + json_end = raw.rfind("}") + if json_start >= 0 and json_end > json_start: + data = json.loads(raw[json_start : json_end + 1]) + is_complete = bool(data.get("complete", True)) + missing = data.get("missing", []) + if isinstance(missing, list) and missing: + return False, [str(m) for m in missing[:5]] + return is_complete, [] + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Data requirements check failed: {exc}" + ) + return True, [] + + async def _agentic_retrieve( + self, + query: str, + data_reqs: DataRequirements, + target_files: List[str], + context: "SearchContext", + ) -> RetrievalResult: + """Core agentic retrieval loop: select pages → extract → check → repeat.""" + indexer = self._get_tree_indexer() + evidence_parts: List[str] = [] + pages_extracted: Dict[str, set] = {} + total_pages = 0 + + outlines: Dict[str, str] = {} + outlines_meta: Dict[str, List[Dict[str, Any]]] = {} + file_total_pages: Dict[str, int] = {} + for fp in target_files: + tree = indexer.load_tree(fp) if indexer else 
None + if tree and tree.total_pages: + file_total_pages[fp] = tree.total_pages + + # Strategy 1: LLM-analyzed TOC pages (highest quality) + tp = file_total_pages.get(fp) + toc_outline, toc_meta = await self._build_outline_from_toc_pages(fp, tp) + if toc_outline.strip(): + outlines[fp] = toc_outline + outlines_meta[fp] = toc_meta + continue + + # Strategy 2: Tree-index section map (fallback) + if tree and tree.root: + outline, sec_meta = self._build_section_map( + tree.root, max_depth=self._AGENTIC_SECTION_MAP_DEPTH, + ) + if outline.strip(): + outlines[fp] = outline + outlines_meta[fp] = sec_meta + + current_reqs = data_reqs + + for round_idx in range(self._AGENTIC_MAX_ROUNDS): + round_fetched_any = False + + for fp in target_files: + if total_pages >= self._AGENTIC_MAX_TOTAL_PAGES: + break + + fname = Path(fp).name + fetched = pages_extracted.get(fp, set()) + outline = outlines.get(fp, "") + sec_meta = outlines_meta.get(fp) + tp = file_total_pages.get(fp) + + if not outline and not tp: + continue + + new_pages = await self._select_pages_for_data( + query, current_reqs, outline or "(no outline available)", + "\n\n".join(evidence_parts)[:8000], + fetched, + sections_meta=sec_meta, + total_pages=tp, + ) + new_pages -= fetched + if not new_pages: + continue + + budget = self._AGENTIC_MAX_PAGES_PER_ROUND + capped = sorted(new_pages)[:budget] + try: + contents = DocumentExtractor.extract_pages(fp, capped) + for pc in contents: + if pc.content and pc.content.strip(): + evidence_parts.append( + f"[{fname} p.{pc.page_number}]\n{pc.content}" + ) + pages_extracted.setdefault(fp, set()).update(capped) + total_pages += len(capped) + round_fetched_any = True + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Page extraction failed for {fname}: {exc}" + ) + + # Append table digests for newly fetched pages only + try: + from sirchmunk.utils.file_utils import get_fast_hash + fhash = get_fast_hash(fp) + if fhash: + tables = 
self._load_table_digest(self.work_path, fhash) + if tables: + new_page_set = set(capped) + page_tables = [ + t for t in tables + if t.get("page_number") in new_page_set + ] + if page_tables: + table_ev = self._format_table_evidence( + page_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, + ) + if table_ev: + evidence_parts.append( + f"[{fname} tables]\n{table_ev}" + ) + except Exception: + pass + + context.increment_loop() + + if not round_fetched_any: + break + + combined = "\n\n".join(evidence_parts) + is_complete, missing = await self._check_data_requirements( + query, current_reqs, combined, + ) + context.increment_loop() + + await self._logger.info( + f"[Phase 4] Round {round_idx + 1}: " + f"{total_pages} pages, complete={is_complete}, " + f"missing={len(missing)}" + ) + + if is_complete or not missing: + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=True, + rounds_used=round_idx + 1, + ) + + current_reqs = DataRequirements( + data_points=missing, + likely_sources=data_reqs.likely_sources, + formula=data_reqs.formula, + time_period=data_reqs.time_period, + intent=data_reqs.intent, + ) + + combined = "\n\n".join(evidence_parts) + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=False, + rounds_used=self._AGENTIC_MAX_ROUNDS, + ) + + async def _synthesize_from_retrieval( + self, + query: str, + intent: str, + retrieval: RetrievalResult, + file_paths: List[str], + ) -> Tuple[str, bool, Optional["KnowledgeCluster"]]: + """Synthesize final answer from agentic retrieval evidence.""" + if not retrieval.evidence.strip(): + return _NO_RESULTS_MESSAGE, False, None + + synth_prompt = self._select_synthesis_prompt( + query, retrieval.evidence, intent, + ) + resp = await self.llm.achat( + 
messages=[{"role": "user", "content": synth_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + + raw = resp.content or "" + answer, should_save, should_answer = self._parse_summary_response(raw) + + accepted, reason = self._evaluate_evidence_acceptance( + query, retrieval.evidence, should_answer, + retrieval_complete=retrieval.is_complete, + ) + await self._logger.info( + f"[Phase 4.5] Synthesis: accepted={accepted} ({reason})" + ) + + if not accepted: + return _NO_RESULTS_MESSAGE, False, None + + cluster = self._make_answer_cluster( + query, retrieval.evidence[:5000], "AGT", + file_paths=list(retrieval.pages_extracted.keys())[:3], + ) + cluster.content = retrieval.evidence + return answer, should_save, cluster + + # ------------------------------------------------------------------ + # LLM-powered document outline from TOC pages + # ------------------------------------------------------------------ + + _TOC_ANALYSIS_PAGES: List[int] = [1, 2, 3] + + async def _build_outline_from_toc_pages( + self, + file_path: str, + total_pages: Optional[int] = None, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build a section map by extracting and LLM-analyzing TOC pages. + + Extracts the first few pages of a PDF (where the Table of Contents + typically resides), sends the text to the LLM for structural parsing, + and returns an outline string plus section metadata in the same format + as ``_build_section_map()`` for seamless integration. + + Results are cached per file hash to avoid repeated LLM calls. 
+ """ + from sirchmunk.utils.file_utils import get_fast_hash + + fhash = get_fast_hash(file_path) + if not fhash: + return "", [] + + cache_dir = self.work_path / ".cache" / "compile" / "toc_outlines" + cache_path = cache_dir / f"{fhash}.json" + + sections_raw: Optional[list] = None + if cache_path.exists(): + try: + sections_raw = json.loads(cache_path.read_text()) + except Exception: + pass + + if sections_raw is None: + try: + contents = DocumentExtractor.extract_pages( + file_path, self._TOC_ANALYSIS_PAGES, + ) + toc_text = "\n\n".join( + f"--- Page {pc.page_number} ---\n{pc.content}" + for pc in contents if pc.content and pc.content.strip() + ) + if len(toc_text.strip()) < 200: + return "", [] + + tp = total_pages or len(contents) + prompt = DEEP_TOC_ANALYSIS.format( + toc_page_text=toc_text[:6000], + total_pages=tp, + ) + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + + raw = (resp.content or "").strip() + sections_raw = self._extract_json_array(raw) + if sections_raw is None: + sections_raw = [] + + cache_dir.mkdir(parents=True, exist_ok=True) + cache_path.write_text(json.dumps(sections_raw, ensure_ascii=False)) + except Exception as exc: + await self._logger.warning( + f"[Phase 4] TOC outline extraction failed: {exc}" + ) + return "", [] + + if not sections_raw: + return "", [] + + return self._toc_sections_to_outline(sections_raw, total_pages) + + @staticmethod + def _toc_sections_to_outline( + sections_raw: list, + total_pages: Optional[int] = None, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Convert raw TOC section list to outline string and metadata.""" + sections_meta: List[Dict[str, Any]] = [] + lines: List[str] = [] + + for i, sec in enumerate(sections_raw): + if not isinstance(sec, dict): + continue + title = str(sec.get("title", "")).strip() + if not title: + continue + + ps = sec.get("page_start") + pe = sec.get("page_end") + level = int(sec.get("level", 1)) - 1 + + 
page_range = None + if ps is not None: + ps = int(ps) + pe = int(pe) if pe is not None else (total_pages or ps) + page_range = [ps, pe] + + idx = len(sections_meta) + sections_meta.append({ + "idx": idx, + "title": title, + "page_range": page_range, + "char_range": None, + "depth": level, + "node_id": f"toc_{idx}", + "summary": "", + }) + + indent = " " * level + page_str = f"(p{page_range[0]}-{page_range[1]})" if page_range else "" + lines.append(f"[{idx}] {indent}{title} {page_str}") + + return "\n".join(lines), sections_meta + + # ------------------------------------------------------------------ + # Deep Structured Reasoning pipeline (legacy, used by older code paths) # ------------------------------------------------------------------ @staticmethod From 7484f689b0523a4d21f9ba60eb1478747f1e24cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 17 May 2026 22:01:13 +0800 Subject: [PATCH 67/70] improve search and prompts --- src/sirchmunk/llm/prompts.py | 30 +++++--- src/sirchmunk/search.py | 138 +++++++++++++++++++++++++++++++++-- 2 files changed, 151 insertions(+), 17 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index e5b3238..05d3078 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,7 +424,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. 
**Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +6. **Rounding**: Match the precision implied by the query. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query explicitly requests decimal precision. Example: $8,738 millions → "$8,738 million" or "$8.74 billion"; $381,603 thousands → "$382 million"; 36.47% → "36%". 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data @@ -467,7 +467,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. 
**Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +6. **Rounding**: Match the precision implied by the query. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query explicitly requests decimal precision. Example: $8,738 millions → "$8,738 million" or "$8.74 billion"; $381,603 thousands → "$382 million"; 36.47% → "36%". 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context @@ -564,10 +564,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Constraints 1. **Language Continuity**: The output must be in the SAME language as the User Input. -2. Find the EXACT value stated in the evidence. Do not compute or estimate. -3. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. -4. Quote the source passage containing the value. -5. If the value is not explicitly stated in the evidence, mark SHOULD_ANSWER as "false". +2. Find the value stated in the evidence. 
If the exact total is not stated but its components are clearly present, compute it by summing the components. +3. **Rounding**: When the query specifies units (e.g., "in USD millions"), convert and round the extracted value to match. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query asks for decimal precision. Examples: $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"; 36.47% → "36%". +4. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. +5. Quote the source passage containing the value. +6. Only mark SHOULD_ANSWER as "false" when no relevant data exists in the evidence. Always prefer attempting an answer over refusing. +7. When the evidence contains relevant data but you feel uncertain, still attempt to answer. ### Input Data - **User Input**: {user_input} @@ -595,9 +597,14 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: c) **SUBSTITUTION**: Plug in the extracted values into the formula. d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. e) **VERIFICATION**: Re-compute the final result independently to confirm. -3. **Rounding**: Match the precision implied by the query. For percentages, use at most one decimal place unless asked for more. For dollar amounts, round to the nearest whole number in the stated unit. +3. **Rounding**: Round the final result to match the precision implied by the query. + - Dollar amounts: round to the nearest whole number in the stated unit. Example: $381,603 thousands → "$382 million". + - Percentages: round to 1 decimal place if the query says "round to one decimal place"; otherwise round to the nearest whole number. + - Ratios: round to 2 decimal places. + - When the query says "round to X decimal places", follow that exactly. 4. 
**Units**: Convert all values to consistent units before computing. 5. If any required data point is missing, explicitly state what is missing and mark SHOULD_ANSWER as "false". +6. **Definition precision**: When computing financial ratios, use the broadest standard definition unless the query specifies otherwise. Quick ratio = (Current Assets - Inventories) / Current Liabilities. Asset turnover = Revenue / Average Total Assets. ### Input Data - **User Input**: {user_input} @@ -629,8 +636,9 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Extract values for EACH comparison dimension from the evidence. 3. Present in a structured comparison table. 4. State the direction and magnitude of difference or change. -5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. -6. If values for any comparison dimension are missing, state what is missing. +5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. Round results: dollar amounts to the nearest whole number in the stated unit, percentages to the nearest whole number. +6. **"Best performing"** means highest growth rate or change rate, not highest absolute value, unless the query explicitly says "largest" or "highest revenue". +7. If values for any comparison dimension are missing, state what is missing. ### Input Data - **User Input**: {user_input} @@ -741,8 +749,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Pages Already Fetched {fetched_pages} +### Evidence Already Gathered +{evidence_summary} + ### Instructions - Reason about which sections contain the needed data based on section titles, summaries, and page ranges. +- Consider what data has already been gathered to avoid fetching redundant content. - Financial statements (Income Statement, Balance Sheet, Cash Flow Statement) typically contain quantitative data needed for calculations. 
- Sections with tables are often high-value for data extraction. - Do NOT re-select pages listed in "Pages Already Fetched". diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 8384b39..4f731ca 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2708,6 +2708,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Section map depth for agentic page selection.""" _AGENTIC_EVIDENCE_MAX_CHARS: int = 40_000 """Maximum evidence characters to feed to synthesis prompt.""" + _SHORT_DOC_THRESHOLD: int = 30 + """Documents with this many pages or fewer are extracted in full.""" # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 @@ -6754,11 +6756,16 @@ async def _select_pages_for_data( if fetched_pages else "None" ) + evidence_summary = ( + evidence_so_far[:2000] if evidence_so_far.strip() else "None yet" + ) + prompt = DEEP_PAGE_SELECT.format( query=query, data_requirements=reqs_str, section_map=section_map, fetched_pages=fetched_str, + evidence_summary=evidence_summary, ) try: resp = await self.llm.achat( @@ -6871,13 +6878,44 @@ async def _agentic_retrieve( outlines: Dict[str, str] = {} outlines_meta: Dict[str, List[Dict[str, Any]]] = {} file_total_pages: Dict[str, int] = {} + outline_target_files: List[str] = [] + for fp in target_files: tree = indexer.load_tree(fp) if indexer else None if tree and tree.total_pages: file_total_pages[fp] = tree.total_pages + if fp not in file_total_pages: + try: + from pypdf import PdfReader + file_total_pages[fp] = len(PdfReader(fp).pages) + except Exception: + pass - # Strategy 1: LLM-analyzed TOC pages (highest quality) tp = file_total_pages.get(fp) + if tp and tp <= self._SHORT_DOC_THRESHOLD: + fname = Path(fp).name + try: + all_pages = list(range(1, tp + 1)) + contents = DocumentExtractor.extract_pages(fp, all_pages) + for pc in contents: + if pc.content and pc.content.strip(): + evidence_parts.append( + f"[{fname} 
p.{pc.page_number}]\n{pc.content}" + ) + pages_extracted[fp] = set(all_pages) + total_pages += tp + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Full extraction of short doc {fname} failed: {exc}" + ) + outline_target_files.append(fp) + continue + outline_target_files.append(fp) + + for fp in outline_target_files: + tp = file_total_pages.get(fp) + + # Strategy 1: LLM-analyzed TOC pages (highest quality) toc_outline, toc_meta = await self._build_outline_from_toc_pages(fp, tp) if toc_outline.strip(): outlines[fp] = toc_outline @@ -6885,6 +6923,7 @@ async def _agentic_retrieve( continue # Strategy 2: Tree-index section map (fallback) + tree = indexer.load_tree(fp) if indexer else None if tree and tree.root: outline, sec_meta = self._build_section_map( tree.root, max_depth=self._AGENTIC_SECTION_MAP_DEPTH, @@ -6892,13 +6931,33 @@ async def _agentic_retrieve( if outline.strip(): outlines[fp] = outline outlines_meta[fp] = sec_meta + continue + + # Strategy 3: Sampled content outline for docs with known page count + if tp: + sampled_outline, sampled_meta = self._build_sampled_outline( + fp, tp, + ) + outlines[fp] = sampled_outline + outlines_meta[fp] = sampled_meta current_reqs = data_reqs + if not outline_target_files and evidence_parts: + combined = "\n\n".join(evidence_parts) + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=True, + rounds_used=0, + ) + for round_idx in range(self._AGENTIC_MAX_ROUNDS): round_fetched_any = False - for fp in target_files: + for fp in outline_target_files: if total_pages >= self._AGENTIC_MAX_TOTAL_PAGES: break @@ -7054,7 +7113,7 @@ async def _synthesize_from_retrieval( # LLM-powered document outline from TOC pages # ------------------------------------------------------------------ - _TOC_ANALYSIS_PAGES: List[int] = [1, 2, 3] + _TOC_ANALYSIS_PAGES: List[int] = [1, 2, 3, 4, 5] async def 
_build_outline_from_toc_pages( self, @@ -7100,7 +7159,7 @@ async def _build_outline_from_toc_pages( tp = total_pages or len(contents) prompt = DEEP_TOC_ANALYSIS.format( - toc_page_text=toc_text[:6000], + toc_page_text=toc_text[:12000], total_pages=tp, ) resp = await self.llm.achat( @@ -7133,7 +7192,6 @@ def _toc_sections_to_outline( ) -> Tuple[str, List[Dict[str, Any]]]: """Convert raw TOC section list to outline string and metadata.""" sections_meta: List[Dict[str, Any]] = [] - lines: List[str] = [] for i, sec in enumerate(sections_raw): if not isinstance(sec, dict): @@ -7163,9 +7221,73 @@ def _toc_sections_to_outline( "summary": "", }) - indent = " " * level - page_str = f"(p{page_range[0]}-{page_range[1]})" if page_range else "" - lines.append(f"[{idx}] {indent}{title} {page_str}") + # Post-process: fix page_range errors from LLM inference + for i, sec in enumerate(sections_meta): + pr = sec.get("page_range") + if not pr: + continue + needs_fix = pr[1] < pr[0] + if not needs_fix and pr[1] == pr[0] and i + 1 < len(sections_meta): + next_pr = sections_meta[i + 1].get("page_range") + needs_fix = next_pr and next_pr[0] == pr[0] + if needs_fix: + for j in range(i + 1, len(sections_meta)): + next_pr = sections_meta[j].get("page_range") + if next_pr and next_pr[0] > pr[0]: + pr[1] = next_pr[0] - 1 + break + else: + pr[1] = total_pages or pr[0] + + lines: List[str] = [] + for sec in sections_meta: + pr = sec.get("page_range") + indent = " " * sec["depth"] + page_str = f"(p{pr[0]}-{pr[1]})" if pr else "" + lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") + + return "\n".join(lines), sections_meta + + @staticmethod + def _build_sampled_outline( + file_path: str, + total_pages: int, + interval: int = 20, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build an outline by sampling page content at regular intervals. + + Used as a fallback when TOC parsing and tree indices are unavailable. + Gives the LLM enough context to make informed page selections. 
+ """ + sample_pages = list(range(1, total_pages + 1, interval)) + if total_pages not in sample_pages: + sample_pages.append(total_pages) + + sections_meta: List[Dict[str, Any]] = [] + lines: List[str] = [] + + try: + contents = DocumentExtractor.extract_pages(file_path, sample_pages) + page_snippets = { + pc.page_number: (pc.content or "").strip()[:200] + for pc in contents if pc.content + } + except Exception: + page_snippets = {} + + for i, pg in enumerate(sample_pages): + snippet = page_snippets.get(pg, "") + snippet_clean = " ".join(snippet.split())[:150] + next_pg = sample_pages[i + 1] if i + 1 < len(sample_pages) else total_pages + page_range = [pg, next_pg] + + title = f'p{pg}: "{snippet_clean}..."' if snippet_clean else f"p{pg}" + sections_meta.append({ + "idx": i, "title": title, + "page_range": page_range, "char_range": None, + "depth": 0, "node_id": f"sample_{i}", "summary": "", + }) + lines.append(f"[{i}] {title} (p{page_range[0]}-{page_range[1]})") return "\n".join(lines), sections_meta From 629f27a9509462b572db078fda2dcf3881c78951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 17 May 2026 23:25:46 +0800 Subject: [PATCH 68/70] refine search --- src/sirchmunk/llm/prompts.py | 32 ++++++---- src/sirchmunk/search.py | 113 ++++++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 15 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 05d3078..a2ab3e2 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,7 +424,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. 
**Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query explicitly requests decimal precision. Example: $8,738 millions → "$8,738 million" or "$8.74 billion"; $381,603 thousands → "$382 million"; 36.47% → "36%". +6. **Precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data @@ -450,7 +450,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: [Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] -[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] +[State ONLY the final verified answer. For yes/no questions, start with "Yes" or "No". 
For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "0.83", "$1,832 million"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] true/false true/false @@ -467,7 +467,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query explicitly requests decimal precision. Example: $8,738 millions → "$8,738 million" or "$8.74 billion"; $381,603 thousands → "$382 million"; 36.47% → "36%". +6. **Precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. 
Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context @@ -496,7 +496,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: [Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] -[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] +[State ONLY the final verified answer. For yes/no questions, start with "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "0.83", "$1,832 million"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] true/false true/false @@ -565,11 +565,15 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Constraints 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. Find the value stated in the evidence. If the exact total is not stated but its components are clearly present, compute it by summing the components. -3. **Rounding**: When the query specifies units (e.g., "in USD millions"), convert and round the extracted value to match. Dollar amounts: round to the nearest whole number in the stated unit. Percentages: round to the nearest whole number unless the query asks for decimal precision. Examples: $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"; 36.47% → "36%". +3. **Precision**: Preserve the precision from the source data. If the source says "36.8%", report "36.8%", not "37%". If the source says "$8.70 billion", report "$8.70 billion", not "$9 billion". 
Only round when explicitly converting between units (e.g., $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"). When converting, round to the nearest whole number in the target unit. When the query specifies a rounding rule (e.g., "round to one decimal place"), follow it exactly. 4. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. 5. Quote the source passage containing the value. 6. Only mark SHOULD_ANSWER as "false" when no relevant data exists in the evidence. Always prefer attempting an answer over refusing. 7. When the evidence contains relevant data but you feel uncertain, still attempt to answer. +8. **Answer format**: + - For yes/no questions (e.g., "Has X increased?", "Did the company…?"), PRECISE_ANSWER must start with "Yes" or "No", followed by a brief qualifier if needed. + - For identification questions (e.g., "What is the largest segment?", "Which company had the highest…?"), PRECISE_ANSWER should state the name/label, not the numeric value. + - For value questions (e.g., "What was total revenue?"), PRECISE_ANSWER should state the numeric value with units. ### Input Data - **User Input**: {user_input} @@ -581,7 +585,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: **Extracted value**: [The specific value found] -[value only, e.g. "$1,832 million", "Yes", "42%"] +[value only, e.g. "$1,832 million", "Yes, it increased by 5%", "Cloud Services segment"] true/false true/false """ @@ -597,14 +601,15 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: c) **SUBSTITUTION**: Plug in the extracted values into the formula. d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. e) **VERIFICATION**: Re-compute the final result independently to confirm. -3. **Rounding**: Round the final result to match the precision implied by the query. 
- - Dollar amounts: round to the nearest whole number in the stated unit. Example: $381,603 thousands → "$382 million". - - Percentages: round to 1 decimal place if the query says "round to one decimal place"; otherwise round to the nearest whole number. +3. **Precision**: Preserve meaningful precision in computed results. + - Dollar amounts: when converting units, round to the nearest whole number in the target unit. Example: $381,603 thousands → "$382 million". Otherwise preserve the precision of the input values. + - Percentages: round to 1 decimal place by default. If the query says "round to one decimal place", follow exactly. If the query says "round to nearest whole number" or the context clearly calls for it, round to whole. - Ratios: round to 2 decimal places. - - When the query says "round to X decimal places", follow that exactly. + - When the query specifies "round to X decimal places", follow that exactly. 4. **Units**: Convert all values to consistent units before computing. 5. If any required data point is missing, explicitly state what is missing and mark SHOULD_ANSWER as "false". 6. **Definition precision**: When computing financial ratios, use the broadest standard definition unless the query specifies otherwise. Quick ratio = (Current Assets - Inventories) / Current Liabilities. Asset turnover = Revenue / Average Total Assets. +7. **Answer format**: For yes/no questions, PRECISE_ANSWER must start with "Yes" or "No". For identification questions, state the name/label, not just the number. ### Input Data - **User Input**: {user_input} @@ -636,7 +641,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Extract values for EACH comparison dimension from the evidence. 3. Present in a structured comparison table. 4. State the direction and magnitude of difference or change. -5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. 
Round results: dollar amounts to the nearest whole number in the stated unit, percentages to the nearest whole number. +5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. Preserve the precision of source values. Only round when converting units (e.g., thousands → millions). Percentages: round to 1 decimal place by default. 6. **"Best performing"** means highest growth rate or change rate, not highest absolute value, unless the query explicitly says "largest" or "highest revenue". 7. If values for any comparison dimension are missing, state what is missing. @@ -728,11 +733,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Instructions 1. List each specific data point needed to answer this question (e.g., "Total Revenue for FY2022", "Accounts Payable as of fiscal year end 2019"). 2. For each data point, identify the likely document section type where it would appear (e.g., "Income Statement", "Balance Sheet", "Cash Flow Statement", "Notes to Financial Statements", "Management Discussion and Analysis", "Segment Information"). -3. If a calculation is required, state the exact formula. +3. If a calculation is required, state the exact formula with explicit variable names matching how they typically appear in financial statements. Example: "Quick Ratio = (Total Current Assets - Total Inventories) / Total Current Liabilities". 4. Identify the time period(s) required. +5. For comparison or identification questions (e.g., "What is the largest segment?", "Which year had the highest growth?"), note what dimensions need comparison. 
Return ONLY valid JSON on a single line: -{{"data_points": ["data point 1", "data point 2"], "likely_sources": ["section type 1", "section type 2"], "formula": "formula or null", "time_period": "period or null"}} +{{"data_points": ["data point 1", "data point 2"], "likely_sources": ["section type 1", "section type 2"], "formula": "explicit formula with variable names, or null", "time_period": "period or null"}} """ DEEP_PAGE_SELECT = """You are locating specific data in a document. Select pages to fetch. diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 4f731ca..47c78f5 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2258,6 +2258,7 @@ async def _search_deep( # ============================================================== answer, should_save, cluster = await self._synthesize_from_retrieval( query, _query_intent, retrieval, merged_files, + formula=data_reqs.formula, ) # ============================================================== @@ -6941,6 +6942,14 @@ async def _agentic_retrieve( outlines[fp] = sampled_outline outlines_meta[fp] = sampled_meta + for fp in outline_target_files: + if fp in outlines and fp in outlines_meta: + refined, refined_meta = self._refine_large_sections( + fp, outlines[fp], outlines_meta[fp], + ) + outlines[fp] = refined + outlines_meta[fp] = refined_meta + current_reqs = data_reqs if not outline_target_files and evidence_parts: @@ -7074,13 +7083,18 @@ async def _synthesize_from_retrieval( intent: str, retrieval: RetrievalResult, file_paths: List[str], + formula: Optional[str] = None, ) -> Tuple[str, bool, Optional["KnowledgeCluster"]]: """Synthesize final answer from agentic retrieval evidence.""" if not retrieval.evidence.strip(): return _NO_RESULTS_MESSAGE, False, None + evidence = retrieval.evidence + if formula and intent == "computation": + evidence = f"[Required Formula: {formula}]\n\n{evidence}" + synth_prompt = self._select_synthesis_prompt( - query, retrieval.evidence, intent, + query, 
evidence, intent, ) resp = await self.llm.achat( messages=[{"role": "user", "content": synth_prompt}], @@ -7252,7 +7266,7 @@ def _toc_sections_to_outline( def _build_sampled_outline( file_path: str, total_pages: int, - interval: int = 20, + interval: int = 10, ) -> Tuple[str, List[Dict[str, Any]]]: """Build an outline by sampling page content at regular intervals. @@ -7291,6 +7305,101 @@ def _build_sampled_outline( return "\n".join(lines), sections_meta + _LARGE_SECTION_THRESHOLD: int = 15 + """Sections spanning more pages than this get sub-section sampling.""" + + @staticmethod + def _refine_large_sections( + file_path: str, + outline: str, + sections_meta: List[Dict[str, Any]], + ) -> Tuple[str, List[Dict[str, Any]]]: + """Expand sections spanning many pages with sampled sub-entries. + + For each section whose page range exceeds ``_LARGE_SECTION_THRESHOLD``, + sample pages at ~5-page intervals within the section and insert + sub-entries so the LLM has finer-grained navigation. + """ + threshold = 15 + large_sections = [] + for sec in sections_meta: + pr = sec.get("page_range") + if pr and (pr[1] - pr[0] + 1) > threshold: + large_sections.append(sec) + + if not large_sections: + return outline, sections_meta + + sample_pages_needed: List[int] = [] + for sec in large_sections: + pr = sec["page_range"] + start, end = pr[0], pr[1] + interval = max(5, (end - start) // 8) + for p in range(start, end + 1, interval): + if p not in sample_pages_needed: + sample_pages_needed.append(p) + + try: + contents = DocumentExtractor.extract_pages( + file_path, sorted(sample_pages_needed), + ) + page_snippets = { + pc.page_number: " ".join( + (pc.content or "").strip()[:200].split() + )[:150] + for pc in contents if pc.content and pc.content.strip() + } + except Exception: + return outline, sections_meta + + if not page_snippets: + return outline, sections_meta + + new_meta: List[Dict[str, Any]] = [] + large_set = {id(s) for s in large_sections} + + for sec in sections_meta: + if 
id(sec) not in large_set: + sec_copy = dict(sec) + sec_copy["idx"] = len(new_meta) + new_meta.append(sec_copy) + continue + + pr = sec["page_range"] + start, end = pr[0], pr[1] + interval = max(5, (end - start) // 8) + sub_pages = list(range(start, end + 1, interval)) + + parent_idx = len(new_meta) + parent_copy = dict(sec) + parent_copy["idx"] = parent_idx + new_meta.append(parent_copy) + + for i, sp in enumerate(sub_pages): + snippet = page_snippets.get(sp, "") + sub_end = sub_pages[i + 1] - 1 if i + 1 < len(sub_pages) else end + if not snippet: + continue + sub_idx = len(new_meta) + new_meta.append({ + "idx": sub_idx, + "title": f'"{snippet}..."', + "page_range": [sp, sub_end], + "char_range": None, + "depth": sec.get("depth", 0) + 1, + "node_id": f"subsample_{parent_idx}_{i}", + "summary": "", + }) + + lines: List[str] = [] + for sec in new_meta: + pr = sec.get("page_range") + indent = " " * sec.get("depth", 0) + page_str = f"(p{pr[0]}-{pr[1]})" if pr else "" + lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") + + return "\n".join(lines), new_meta + # ------------------------------------------------------------------ # Deep Structured Reasoning pipeline (legacy, used by older code paths) # ------------------------------------------------------------------ From 9c53bb6a11beac1d934e475b1d5eb4e8f45e5239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 18 May 2026 11:39:22 +0800 Subject: [PATCH 69/70] fallback and improve prompts --- src/sirchmunk/llm/prompts.py | 8 +-- src/sirchmunk/search.py | 103 ----------------------------------- 2 files changed, 4 insertions(+), 107 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index a2ab3e2..7576dfb 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,7 +424,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. 
**Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. +6. **Value precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data @@ -467,7 +467,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. 
**Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. +6. **Value precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context @@ -565,7 +565,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Constraints 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. Find the value stated in the evidence. If the exact total is not stated but its components are clearly present, compute it by summing the components. -3. **Precision**: Preserve the precision from the source data. If the source says "36.8%", report "36.8%", not "37%". If the source says "$8.70 billion", report "$8.70 billion", not "$9 billion". Only round when explicitly converting between units (e.g., $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"). 
When converting, round to the nearest whole number in the target unit. When the query specifies a rounding rule (e.g., "round to one decimal place"), follow it exactly. +3. **Value precision**: Preserve the precision from the source data. If the source says "36.8%", report "36.8%", not "37%". If the source says "$8.70 billion", report "$8.70 billion", not "$9 billion". Only round when explicitly converting between units (e.g., $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"). When the query specifies a rounding rule (e.g., "round to one decimal place"), follow it exactly. 4. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. 5. Quote the source passage containing the value. 6. Only mark SHOULD_ANSWER as "false" when no relevant data exists in the evidence. Always prefer attempting an answer over refusing. @@ -601,7 +601,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: c) **SUBSTITUTION**: Plug in the extracted values into the formula. d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. e) **VERIFICATION**: Re-compute the final result independently to confirm. -3. **Precision**: Preserve meaningful precision in computed results. +3. **Value precision**: Preserve meaningful precision in computed results. - Dollar amounts: when converting units, round to the nearest whole number in the target unit. Example: $381,603 thousands → "$382 million". Otherwise preserve the precision of the input values. - Percentages: round to 1 decimal place by default. If the query says "round to one decimal place", follow exactly. If the query says "round to nearest whole number" or the context clearly calls for it, round to whole. - Ratios: round to 2 decimal places. 
diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 47c78f5..d02c138 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -6942,14 +6942,6 @@ async def _agentic_retrieve( outlines[fp] = sampled_outline outlines_meta[fp] = sampled_meta - for fp in outline_target_files: - if fp in outlines and fp in outlines_meta: - refined, refined_meta = self._refine_large_sections( - fp, outlines[fp], outlines_meta[fp], - ) - outlines[fp] = refined - outlines_meta[fp] = refined_meta - current_reqs = data_reqs if not outline_target_files and evidence_parts: @@ -7305,101 +7297,6 @@ def _build_sampled_outline( return "\n".join(lines), sections_meta - _LARGE_SECTION_THRESHOLD: int = 15 - """Sections spanning more pages than this get sub-section sampling.""" - - @staticmethod - def _refine_large_sections( - file_path: str, - outline: str, - sections_meta: List[Dict[str, Any]], - ) -> Tuple[str, List[Dict[str, Any]]]: - """Expand sections spanning many pages with sampled sub-entries. - - For each section whose page range exceeds ``_LARGE_SECTION_THRESHOLD``, - sample pages at ~5-page intervals within the section and insert - sub-entries so the LLM has finer-grained navigation. 
- """ - threshold = 15 - large_sections = [] - for sec in sections_meta: - pr = sec.get("page_range") - if pr and (pr[1] - pr[0] + 1) > threshold: - large_sections.append(sec) - - if not large_sections: - return outline, sections_meta - - sample_pages_needed: List[int] = [] - for sec in large_sections: - pr = sec["page_range"] - start, end = pr[0], pr[1] - interval = max(5, (end - start) // 8) - for p in range(start, end + 1, interval): - if p not in sample_pages_needed: - sample_pages_needed.append(p) - - try: - contents = DocumentExtractor.extract_pages( - file_path, sorted(sample_pages_needed), - ) - page_snippets = { - pc.page_number: " ".join( - (pc.content or "").strip()[:200].split() - )[:150] - for pc in contents if pc.content and pc.content.strip() - } - except Exception: - return outline, sections_meta - - if not page_snippets: - return outline, sections_meta - - new_meta: List[Dict[str, Any]] = [] - large_set = {id(s) for s in large_sections} - - for sec in sections_meta: - if id(sec) not in large_set: - sec_copy = dict(sec) - sec_copy["idx"] = len(new_meta) - new_meta.append(sec_copy) - continue - - pr = sec["page_range"] - start, end = pr[0], pr[1] - interval = max(5, (end - start) // 8) - sub_pages = list(range(start, end + 1, interval)) - - parent_idx = len(new_meta) - parent_copy = dict(sec) - parent_copy["idx"] = parent_idx - new_meta.append(parent_copy) - - for i, sp in enumerate(sub_pages): - snippet = page_snippets.get(sp, "") - sub_end = sub_pages[i + 1] - 1 if i + 1 < len(sub_pages) else end - if not snippet: - continue - sub_idx = len(new_meta) - new_meta.append({ - "idx": sub_idx, - "title": f'"{snippet}..."', - "page_range": [sp, sub_end], - "char_range": None, - "depth": sec.get("depth", 0) + 1, - "node_id": f"subsample_{parent_idx}_{i}", - "summary": "", - }) - - lines: List[str] = [] - for sec in new_meta: - pr = sec.get("page_range") - indent = " " * sec.get("depth", 0) - page_str = f"(p{pr[0]}-{pr[1]})" if pr else "" - 
lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") - - return "\n".join(lines), new_meta - # ------------------------------------------------------------------ # Deep Structured Reasoning pipeline (legacy, used by older code paths) # ------------------------------------------------------------------ From 8e79d779c606e0f78cf0fc903baa2776131133d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 18 May 2026 18:09:47 +0800 Subject: [PATCH 70/70] improve synthesis prompts: formula definitions, rounding, answer format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add narrow quick ratio formula (Cash+STI+Receivables)/CL instead of broad (CA-Inventories)/CL — matches standard financial analysis - Add interest coverage ratio rule: negative EBIT → ratio = 0 - Strengthen yes/no answer format enforcement (MUST begin with Yes/No) - Add nature/composition guidance (describe proportions, not just totals) - Add listing completeness instruction - Refine rounding: 1dp for %, whole number for $≥10 in target unit, 2dp for $<10 in target unit - Add "use query formula if provided" precedence in data requirements Stable +3 improvements verified: Verizon quick ratio, AMCOR restructuring, Netflix unit conversion (all 3/4 correct across 4 benchmark runs). Co-Authored-By: Claude Opus 4.6 --- src/sirchmunk/llm/prompts.py | 41 ++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 7576dfb..e56002c 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,7 +424,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. 
Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Value precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. +6. **Rounding**: When converting units (thousands → millions, millions → billions), round to the nearest whole number in the target unit if result ≥10; use 2 decimal places if result <10. Examples: $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data @@ -450,7 +450,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: [Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] -[State ONLY the final verified answer. For yes/no questions, start with "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "0.83", "$1,832 million"). 
For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] +[State ONLY the final verified answer. CRITICAL: For yes/no questions, the FIRST word MUST be "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "$1,832 million", "39.7%"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] true/false true/false @@ -467,7 +467,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Value precision**: Preserve the precision from the source data. If the source says "$8.70 billion" or "36.8%", report that exact value. Only round when converting between units (e.g., $381,603 thousands → "$382 million"). When the query specifies a rounding rule, follow it exactly. +6. **Rounding**: When converting units (thousands → millions, millions → billions), round to the nearest whole number in the target unit if result ≥10; use 2 decimal places if result <10. Examples: $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. 7. **Best-effort answering**: Always attempt to answer based on available evidence. 
When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context @@ -496,7 +496,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: [Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] -[State ONLY the final verified answer. For yes/no questions, start with "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "0.83", "$1,832 million"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] +[State ONLY the final verified answer. CRITICAL: For yes/no questions, the FIRST word MUST be "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "$1,832 million", "39.7%"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] true/false true/false @@ -565,15 +565,17 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Constraints 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. Find the value stated in the evidence. If the exact total is not stated but its components are clearly present, compute it by summing the components. -3. **Value precision**: Preserve the precision from the source data. If the source says "36.8%", report "36.8%", not "37%". 
If the source says "$8.70 billion", report "$8.70 billion", not "$9 billion". Only round when explicitly converting between units (e.g., $302,578 thousands → "$303 million"; $381,603 thousands → "$382 million"). When the query specifies a rounding rule (e.g., "round to one decimal place"), follow it exactly. +3. **Rounding**: When converting units (e.g., thousands → millions), round to the nearest whole number in the target unit IF the result is ≥10. If the result is <10, use 2 decimal places. Examples: $5,466,312 thousands → "$5,466 million"; $302,578 thousands → "$303 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. 4. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. 5. Quote the source passage containing the value. 6. Only mark SHOULD_ANSWER as "false" when no relevant data exists in the evidence. Always prefer attempting an answer over refusing. 7. When the evidence contains relevant data but you feel uncertain, still attempt to answer. 8. **Answer format**: - - For yes/no questions (e.g., "Has X increased?", "Did the company…?"), PRECISE_ANSWER must start with "Yes" or "No", followed by a brief qualifier if needed. + - For yes/no questions (e.g., "Has X increased?", "Did the company…?", "Does X maintain…?", "Is X healthy?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word. Then provide a brief qualifier. - For identification questions (e.g., "What is the largest segment?", "Which company had the highest…?"), PRECISE_ANSWER should state the name/label, not the numeric value. - For value questions (e.g., "What was total revenue?"), PRECISE_ANSWER should state the numeric value with units. 
+ - When asked about the "nature", "purpose", "composition", or "breakdown" of something, describe what it IS and its proportional components (e.g., "87% relates to employee liabilities"), not just the total dollar amount. + - When listing items (e.g., "Which securities are registered?"), provide the COMPLETE list from the evidence, not just one example. ### Input Data - **User Input**: {user_input} @@ -601,15 +603,23 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: c) **SUBSTITUTION**: Plug in the extracted values into the formula. d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. e) **VERIFICATION**: Re-compute the final result independently to confirm. -3. **Value precision**: Preserve meaningful precision in computed results. - - Dollar amounts: when converting units, round to the nearest whole number in the target unit. Example: $381,603 thousands → "$382 million". Otherwise preserve the precision of the input values. - - Percentages: round to 1 decimal place by default. If the query says "round to one decimal place", follow exactly. If the query says "round to nearest whole number" or the context clearly calls for it, round to whole. +3. **Rounding**: + - Dollar amounts: when converting units, round to the nearest whole number in the target unit IF the result is ≥10. If the result is <10 in the target unit, use 2 decimal places. Examples: $381,603 thousands → "$382 million"; $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". + - Percentages: round to 1 decimal place. - Ratios: round to 2 decimal places. + - Per-share values: round to 2 decimal places. - When the query specifies "round to X decimal places", follow that exactly. 4. **Units**: Convert all values to consistent units before computing. 5. If any required data point is missing, explicitly state what is missing and mark SHOULD_ANSWER as "false". -6. 
**Definition precision**: When computing financial ratios, use the broadest standard definition unless the query specifies otherwise. Quick ratio = (Current Assets - Inventories) / Current Liabilities. Asset turnover = Revenue / Average Total Assets. -7. **Answer format**: For yes/no questions, PRECISE_ANSWER must start with "Yes" or "No". For identification questions, state the name/label, not just the number. +6. **Financial ratio definitions**: + - **Quick ratio** = (Cash and Cash Equivalents + Short-term Investments + Net Receivables) / Total Current Liabilities. Do NOT include inventories, prepaid expenses, or other current assets in the numerator. + - **Interest coverage ratio** = EBIT / Interest Expense. If EBIT is negative, report the coverage ratio as 0 — a company cannot service debt from negative earnings. + - **Asset turnover** = Revenue / Average Total Assets. + - A quick ratio below 1.0x generally indicates the company does NOT have a reasonably healthy liquidity position. +7. **Answer format**: + - For yes/no questions (e.g., "Does X have healthy liquidity?", "Has X improved?", "Does X maintain…?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word. + - For identification questions, state the name/label, not just the number. + - When asked about "nature", "purpose", or "composition", describe qualitative aspects and proportions, not just total amounts. ### Input Data - **User Input**: {user_input} @@ -641,9 +651,10 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Extract values for EACH comparison dimension from the evidence. 3. Present in a structured comparison table. 4. State the direction and magnitude of difference or change. -5. **Precision**: Use exact values from the evidence. When computing changes, show the arithmetic. Preserve the precision of source values. Only round when converting units (e.g., thousands → millions). Percentages: round to 1 decimal place by default. +5.
**Rounding**: When computing changes or growth rates, round percentages to 1 decimal place. When converting units (e.g., thousands → millions), round to nearest whole number in target unit if result ≥10; otherwise use 2 decimal places. 6. **"Best performing"** means highest growth rate or change rate, not highest absolute value, unless the query explicitly says "largest" or "highest revenue". 7. If values for any comparison dimension are missing, state what is missing. +8. **Answer format**: For yes/no questions ("Has X improved?", "Was there any change?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word, followed by the comparison details. ### Input Data - **User Input**: {user_input} @@ -733,7 +744,11 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: ### Instructions 1. List each specific data point needed to answer this question (e.g., "Total Revenue for FY2022", "Accounts Payable as of fiscal year end 2019"). 2. For each data point, identify the likely document section type where it would appear (e.g., "Income Statement", "Balance Sheet", "Cash Flow Statement", "Notes to Financial Statements", "Management Discussion and Analysis", "Segment Information"). -3. If a calculation is required, state the exact formula with explicit variable names matching how they typically appear in financial statements. Example: "Quick Ratio = (Total Current Assets - Total Inventories) / Total Current Liabilities". +3. If a calculation is required, state the exact formula with explicit variable names matching how they typically appear in financial statements. If the question provides its own formula definition, use THAT formula exactly. 
Otherwise use these standard definitions: + - Quick Ratio = (Cash and Cash Equivalents + Short-term Investments + Net Receivables) / Total Current Liabilities + - Interest Coverage Ratio = EBIT / Interest Expense (if EBIT is negative, ratio = 0) + - Asset Turnover = Revenue / Average Total Assets + - Net Profit Margin = Net Income / Total Revenue 4. Identify the time period(s) required. 5. For comparison or identification questions (e.g., "What is the largest segment?", "Which year had the highest growth?"), note what dimensions need comparison.