该仓库提供了可直接运行的 ARTO(Adaptive Reasoning Topology Optimization) 实现骨架,按论文中的核心模块拆分:
- 难度分层(Natural Differentiation)与原型增强难度分类器
- 容量-拓扑联合选择(Capacity-Topology Joint Selection)
- DC2T(Dynamic Chain-to-Tree)
- UT/OT 实例级诊断与阈值校准
- 三层级回退策略(L1/L2/L3)
arto/difficulty.py:冻结编码器 + MLP + prototype 对齐训练、离散/连续难度输出arto/natural_differentiation.py:按最小可解模型定义 tier(Eq.2)arto/router.py:容量函数(Eq.10)与配置选择(Eq.11)arto/dc2t.py:四阶段树构造、Conf/Novel/Q(Eq.14-16)、阈值与聚合(Eq.17-18)arto/diagnostics.py:UT/OT(Eq.ut/ot)与 ε_UT 的 F1 校准arto/fallback.py:L1/L2/L3 回退arto/pipeline.py:端到端推理闭环scripts/train_difficulty.py:训练难度分类器scripts/run_inference.py:最小可运行推理示例tests/test_arto.py:单元测试
- 难度分类器:384 维输入,MLP
384→256→128→5,GELU,dropout=0.1 - 损失:
L = CE + λ * L_align,λ=0.15 - 优化器:Adam,
lr=4e-5,batch=8,epochs=80(脚本默认) - 推理采样参数(论文建议):
temperature=0.7,top-p=0.9,max_tokens=4096 - UT 阈值:在验证集用 F1 标定(
calibrate_ut_threshold)
python -m pip install -e .
pytest
python scripts/run_inference.py "If x+3=10, what is x?"训练数据格式(jsonl):
{"question": "...", "tier": 1}运行:
python scripts/train_difficulty.py --train data/train.jsonl --out checkpoints/difficulty.pt为保证仓库在离线环境可运行,默认提供 HashingEncoder 作为可替代冻结语义编码器;如需严格对齐论文中的 all-MiniLM-L6-v2,可将 FrozenEncoder 接口替换为 sentence-transformers 实例。