Skip to content

Commit 3449da9

Browse files
committed
feat: Add advanced filters and character-level tracing to RAG CLI
- Add --investor, --source-type, and --kind filters
- Implement start_index tracing for precise document localization
- Add --chunk-size, --chunk-overlap, and --format json support
- Enhance interactive mode with filter context
1 parent 844c06f commit 3449da9

1 file changed

Lines changed: 83 additions & 10 deletions

File tree

examples/rag_langchain.py

Lines changed: 83 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
python rag_langchain.py --interactive
1414
python rag_langchain.py --persist ./vectorstore "护城河分析"
1515
python rag_langchain.py --load ./vectorstore "巴菲特如何选股?"
16+
python rag_langchain.py "止损" --kind risk_management
17+
python rag_langchain.py "护城河" --investor warren_buffett
1618
"""
1719

1820
import argparse
@@ -105,6 +107,7 @@ def split_investor_documents(documents, chunk_size: int = 900, chunk_overlap: in
105107
106108
- 保留原 metadata(source/investor_id 等)
107109
- 增加 chunk_index/chunk_id/title_hint/source_type
110+
- 记录 start_index 用于精确溯源
108111
"""
109112
from langchain.text_splitter import RecursiveCharacterTextSplitter
110113
import re
@@ -113,6 +116,7 @@ def split_investor_documents(documents, chunk_size: int = 900, chunk_overlap: in
113116
chunk_size=chunk_size,
114117
chunk_overlap=chunk_overlap,
115118
separators=["\n## ", "\n### ", "\n#### ", "\n\n", "\n", " "],
119+
add_start_index=True,
116120
)
117121

118122
split_docs = []
@@ -225,9 +229,14 @@ def load_vectorstore(persist_dir: str):
225229
)
226230

227231

228-
def query_vectorstore(vectorstore, query: str, k: int = 5):
229-
"""查询向量存储"""
230-
results = vectorstore.similarity_search_with_score(query, k=k)
232+
def query_vectorstore(vectorstore, query: str, k: int = 5, filter_dict: dict = None):
233+
"""查询向量存储,支持元数据过滤"""
234+
# Chroma 过滤语法:{"metadata_key": "value"} 或 {"$and": [...]}
235+
results = vectorstore.similarity_search_with_score(
236+
query,
237+
k=k,
238+
filter=filter_dict
239+
)
231240
return results
232241

233242

@@ -243,6 +252,7 @@ def format_results(results):
243252
rule_id = doc.metadata.get("rule_id", "")
244253
chunk_id = doc.metadata.get("chunk_id", "")
245254
title_hint = doc.metadata.get("title_hint", "")
255+
start_index = doc.metadata.get("start_index", 0)
246256

247257
# 引用:优先 rule_id,其次 chunk_id
248258
citation = rule_id or chunk_id or "N/A"
@@ -252,6 +262,8 @@ def format_results(results):
252262
output.append(f" 投资者: {investor_name} ({investor_id})")
253263
if title_hint:
254264
output.append(f" 章节: {title_hint}")
265+
if source_type == "investor_doc":
266+
output.append(f" 位置: 字符偏移 {start_index}")
255267
output.append(f" 引用: {citation}")
256268
output.append("-" * 60)
257269

@@ -260,15 +272,17 @@ def format_results(results):
260272
if len(doc.page_content) > 500:
261273
content += "..."
262274
output.append(content)
263-
output.append(f"\n📌 可溯源引用: {source} -> {citation}")
275+
output.append(f"\n📌 可溯源引用: {source} -> {citation} (offset: {start_index})")
264276

265277
return "\n".join(output)
266278

267279

268-
def interactive_mode(vectorstore):
280+
def interactive_mode(vectorstore, filter_dict=None):
269281
"""交互模式"""
270282
print("\n" + "=" * 60)
271283
print("投资大师知识库 - 交互查询模式")
284+
if filter_dict:
285+
print(f"活动过滤器: {filter_dict}")
272286
print("输入问题进行查询,输入 'quit' 退出")
273287
print("=" * 60)
274288

@@ -286,7 +300,7 @@ def interactive_mode(vectorstore):
286300
if not query:
287301
continue
288302

289-
results = query_vectorstore(vectorstore, query)
303+
results = query_vectorstore(vectorstore, query, filter_dict=filter_dict)
290304
print(format_results(results))
291305

292306

@@ -333,12 +347,56 @@ def main():
333347
action="store_true",
334348
help="仅加载决策规则(更快)"
335349
)
350+
parser.add_argument(
351+
"--investor", "-inv",
352+
help="按投资者 ID 过滤 (例如: warren_buffett)"
353+
)
354+
parser.add_argument(
355+
"--source-type", "-t",
356+
choices=["investor_doc", "rule"],
357+
help="按来源类型过滤"
358+
)
359+
parser.add_argument(
360+
"--kind", "-knd",
361+
choices=["entry", "exit", "risk_management", "other"],
362+
help="按规则类型过滤 (仅对 rule 类型有效)"
363+
)
364+
parser.add_argument(
365+
"--chunk-size",
366+
type=int,
367+
default=900,
368+
help="投资者文档分块大小 (默认: 900)"
369+
)
370+
parser.add_argument(
371+
"--chunk-overlap",
372+
type=int,
373+
default=200,
374+
help="分块重叠大小 (默认: 200)"
375+
)
376+
parser.add_argument(
377+
"--format",
378+
choices=["text", "json"],
379+
default="text",
380+
help="输出格式 (默认: text)"
381+
)
336382

337383
args = parser.parse_args()
338384

339385
# 检查依赖
340386
check_dependencies()
341387

388+
# 构建过滤器
389+
filter_dict = {}
390+
if args.investor:
391+
filter_dict["investor_id"] = args.investor
392+
if args.source_type:
393+
filter_dict["source_type"] = args.source_type
394+
if args.kind:
395+
filter_dict["kind"] = args.kind
396+
397+
if not filter_dict:
398+
filter_dict = None
399+
342400
# 加载或创建向量存储
343401
if args.load:
344402
load_dir = Path(args.load)
@@ -359,7 +417,11 @@ def main():
359417
print(f"已加载 {len(documents)} 条决策规则")
360418
else:
361419
investor_docs = load_investor_documents()
362-
investor_docs = split_investor_documents(investor_docs)
420+
investor_docs = split_investor_documents(
421+
investor_docs,
422+
chunk_size=args.chunk_size,
423+
chunk_overlap=args.chunk_overlap
424+
)
363425
rule_docs = load_decision_rules()
364426
documents = investor_docs + rule_docs
365427
print(f"已加载 {len(investor_docs)} 个投资者文档分块 + {len(rule_docs)} 条决策规则")
@@ -370,10 +432,21 @@ def main():
370432

371433
# 执行查询
372434
if args.interactive:
373-
interactive_mode(vectorstore)
435+
interactive_mode(vectorstore, filter_dict=filter_dict)
374436
elif args.query:
375-
results = query_vectorstore(vectorstore, args.query, args.top_k)
376-
print(format_results(results))
437+
results = query_vectorstore(vectorstore, args.query, args.top_k, filter_dict=filter_dict)
438+
439+
if args.format == "json":
440+
import json
441+
json_results = []
442+
for doc, score in results:
443+
res = doc.metadata.copy()
444+
res["content"] = doc.page_content
445+
res["similarity_estimate"] = round(1 - score, 4)
446+
json_results.append(res)
447+
print(json.dumps(json_results, ensure_ascii=False, indent=2))
448+
else:
449+
print(format_results(results))
377450
else:
378451
parser.print_help()
379452

0 commit comments

Comments
 (0)