Skip to content

gdw0217181/Hello-world

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

ARTO Reproduction (Code)

该仓库提供了可直接运行的 ARTO(Adaptive Reasoning Topology Optimization) 实现骨架,按论文中的核心模块拆分:

  • 难度分层(Natural Differentiation)与原型增强难度分类器
  • 容量-拓扑联合选择(Capacity-Topology Joint Selection)
  • DC2T(Dynamic Chain-to-Tree)
  • UT/OT 实例级诊断与阈值校准
  • 三层级回退策略(L1/L2/L3)

目录结构

  • arto/difficulty.py:冻结编码器 + MLP + prototype 对齐训练、离散/连续难度输出
  • arto/natural_differentiation.py:按最小可解模型定义 tier(Eq.2)
  • arto/router.py:容量函数(Eq.10)与配置选择(Eq.11)
  • arto/dc2t.py:四阶段树构造、Conf/Novel/Q(Eq.14-16)、阈值与聚合(Eq.17-18)
  • arto/diagnostics.py:UT/OT(Eq.ut/ot)与 ε_UT 的 F1 校准
  • arto/fallback.py:L1/L2/L3 回退
  • arto/pipeline.py:端到端推理闭环
  • scripts/train_difficulty.py:训练难度分类器
  • scripts/run_inference.py:最小可运行推理示例
  • tests/test_arto.py:单元测试

实验设定(按论文参数)

  • 难度分类器:384 维输入,MLP 384→256→128→5,GELU,dropout=0.1
  • 损失:L = CE + λ * L_alignλ=0.15
  • 优化器:Adam,lr=4e-5batch=8epochs=80(脚本默认)
  • 推理采样参数(论文建议):temperature=0.7top-p=0.9max_tokens=4096
  • UT 阈值:在验证集用 F1 标定(calibrate_ut_threshold

快速开始

python -m pip install -e .
pytest
python scripts/run_inference.py "If x+3=10, what is x?"

训练难度分类器

训练数据格式(jsonl):

{"question": "...", "tier": 1}

运行:

python scripts/train_difficulty.py --train data/train.jsonl --out checkpoints/difficulty.pt

说明

为保证仓库在离线环境可运行,默认提供 HashingEncoder 作为可替代冻结语义编码器;如需严格对齐论文中的 all-MiniLM-L6-v2,可将 FrozenEncoder 接口替换为 sentence-transformers 实例。

About

initialized introduction

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages