
ChTauchmann/reasoning-pretraining-alignment


🧠 Reasoning-Aware Pretraining Alignment Study

Python 3.10+ | PyTorch 2.2 | License: MIT

Does Reasoning-Aware Pretraining Reduce Alignment Cost?

An empirical study comparing iterative computation priors in language model alignment.

Abstract

We conduct a controlled empirical study comparing three classes of base models: (i) standard transformers with no iterative computation prior, (ii) retrofitted recurrent variants, and (iii) models pretrained from scratch with explicit recurrence or looped execution (e.g., MoR and Ouroboros-style architectures). Across identical SFT, DPO, and GRPO pipelines, we measure data efficiency, convergence speed, training stability, and retention of test-time compute scaling.

Key Findings

  • Native recursive models require 30-60% fewer alignment samples than vanilla transformers
  • Alignment stability improves monotonically as the iterative computation prior strengthens, from vanilla to retrofitted to native recursive models
  • Test-time scaling is preserved in native recursive models after alignment
  • Preference noise robustness increases with reasoning-aware pretraining

Repository Structure

reasoning-pretraining-alignment/
├── src/
│   ├── models/         # Model implementations (vanilla, retrofitted, native)
│   ├── alignment/      # SFT, DPO, GRPO trainers
│   ├── data/           # Dataset loaders
│   ├── evaluation/     # Metrics and benchmarks
│   └── utils/          # Helper utilities
├── configs/            # Hydra configuration files
├── scripts/            # Training and analysis scripts
│   ├── slurm/          # HPC job scripts
│   └── analysis/       # Result visualization
├── containers/         # Apptainer definitions
└── experiments/        # Experiment outputs

Installation

1. Build Container

cd containers
apptainer build alignment-study.sif alignment-study.def

2. Set Up Environment

export TRANSFORMERS_CACHE=/path/to/models/cache
export HF_HOME=/path/to/models/cache
export WANDB_DIR=/path/to/logs/wandb

Quick Start

Run Single Experiment

python scripts/train_alignment.py \
    model=vanilla \
    alignment=sft \
    data=gsm8k

Run on HPC Cluster

sbatch scripts/slurm/run_experiment.sbatch

Run Full Experimental Sweep

bash scripts/slurm/run_full_sweep.sh
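A full sweep crosses every model class with every alignment method. The contents of run_full_sweep.sh are not shown here, so the following is only a minimal sketch of enumerating such a grid as Hydra-style command lines. The option names `retrofitted`, `native`, `dpo`, and `grpo` are assumptions (only `vanilla`, `sft`, and `gsm8k` appear in this README), and `sweep_commands` is a hypothetical helper.

```python
from itertools import product

# Hypothetical config-group values: only "vanilla", "sft", and "gsm8k" appear
# in this README; the other option names are illustrative assumptions.
MODELS = ["vanilla", "retrofitted", "native"]
ALIGNMENT = ["sft", "dpo", "grpo"]
DATA = ["gsm8k"]


def sweep_commands(script: str = "scripts/train_alignment.py") -> list:
    """Return one Hydra-style command line per (model, alignment, data) cell."""
    return [
        f"python {script} model={m} alignment={a} data={d}"
        for m, a, d in product(MODELS, ALIGNMENT, DATA)
    ]


if __name__ == "__main__":
    for cmd in sweep_commands():
        print(cmd)
```

Hydra also offers a built-in `--multirun` (`-m`) mode that accepts comma-separated values such as `model=vanilla,retrofitted`, which may be simpler for local runs.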

Model Classes

1. Vanilla Transformer

  • Standard fixed-depth transformer
  • No iterative computation
  • Baseline for comparison

2. Retrofitted Recurrent

  • Vanilla model with post-hoc looped execution
  • Adaptive halting mechanism
  • State mixing between iterations
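The retrofitted loop can be pictured as re-applying one shared block while a learned halting probability decides when to stop. The sketch below is a toy, pure-Python illustration that assumes an ACT-style rule (Adaptive Computation Time: halt once the cumulative halting probability reaches a threshold, assigning the remaining probability mass to the final step) and linear state mixing. `looped_forward`, `mix`, and `threshold` are hypothetical names, and states are plain lists rather than tensors.

```python
import math
from typing import Callable, List

Vector = List[float]


def looped_forward(
    h: Vector,
    block: Callable[[Vector], Vector],
    halt_logit: Callable[[Vector], float],
    max_iterations: int = 8,
    mix: float = 0.5,
    threshold: float = 0.99,
) -> tuple[Vector, int]:
    """Re-apply a shared block with ACT-style halting (assumed rule).

    Returns the halting-weighted average of the per-step states and the
    number of iterations actually executed.
    """
    cumulative = 0.0
    output = [0.0] * len(h)
    steps = 0
    for step in range(max_iterations):
        # State mixing: interpolate the block output with the previous state.
        update = block(h)
        h = [mix * u + (1.0 - mix) * x for u, x in zip(update, h)]
        p = 1.0 / (1.0 + math.exp(-halt_logit(h)))  # halting probability
        steps = step + 1
        last = (steps == max_iterations) or (cumulative + p >= threshold)
        # The final step receives the remainder so the weights sum to 1.
        weight = (1.0 - cumulative) if last else p
        output = [o + weight * x for o, x in zip(output, h)]
        cumulative += weight
        if last:
            break
    return output, steps
```

Because the halting weights sum to one, the output is a convex combination of the intermediate states; a block that leaves the state unchanged therefore returns the input unchanged, whatever the halting schedule.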

3. Native Recursive (MoR/Ouroboros-style)

  • Pretrained with explicit recurrence
  • Token-level adaptive computation
  • Shared weights across iterations
  • Entropy-regularized training
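To make token-level adaptive computation concrete, here is a toy sketch in which a router scores each token for 1..K recursion steps and the same block is applied that many times. Greedy depth selection, scalar token states, and the mean routing entropy as the regularized quantity are all assumptions for illustration: real routers are trained end to end, and the argmax used here is not differentiable. `recursive_tokens` is hypothetical.

```python
import math
from typing import Callable, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def recursive_tokens(
    tokens: List[float],
    router_logits: List[List[float]],
    shared_block: Callable[[float], float],
) -> tuple[List[float], float]:
    """Apply one shared block a per-token number of times.

    router_logits[i][k] scores running token i for k + 1 recursion steps.
    Returns the updated tokens and the mean routing entropy, which a
    training loop could add to or subtract from the task loss.
    """
    outputs, entropies = [], []
    for x, logits in zip(tokens, router_logits):
        probs = softmax(logits)
        depth = probs.index(max(probs)) + 1  # greedy depth choice
        for _ in range(depth):
            x = shared_block(x)  # same weights at every iteration
        outputs.append(x)
        entropies.append(entropy(probs))
    return outputs, sum(entropies) / len(entropies)
```

Whether the entropy term should encourage confident routing (penalize entropy) or balanced exploration (reward entropy) depends on the training recipe, which this README does not specify.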

Alignment Methods

Supervised Fine-Tuning (SFT)

  • With/without chain-of-thought supervision
  • Optional curriculum learning
  • LoRA/QLoRA support
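SFT with and without chain-of-thought supervision differs mainly in what the target sequence contains. A minimal sketch, assuming token IDs come from a tokenizer already and that prompt tokens are excluded from the loss with -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`); `build_sft_example` is a hypothetical helper, not the repository's data loader.

```python
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss by default


def build_sft_example(
    prompt_ids: List[int],
    answer_ids: List[int],
    rationale_ids: Optional[List[int]] = None,
    use_cot: bool = True,
) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and target; mask prompt tokens out of the loss."""
    if use_cot:
        # Supervise the reasoning trace followed by the final answer.
        target = (rationale_ids or []) + answer_ids
    else:
        # Supervise the final answer only.
        target = answer_ids
    input_ids = prompt_ids + target
    labels = [IGNORE_INDEX] * len(prompt_ids) + target
    return input_ids, labels
```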

Direct Preference Optimization (DPO)

  • Bradley-Terry preference model
  • KL-regularized objective
  • Iterative variant available
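The DPO loss is the negative log-likelihood of a Bradley-Terry model whose implicit reward is β times the log-ratio between the policy and the frozen reference model; β is the same coefficient that weights the KL term in the underlying objective. A minimal per-pair sketch from summed response log-probabilities follows. The default `beta=0.1` is an assumption (the override example in this README uses 0.2), and `dpo_loss` is a hypothetical helper.

```python
import math


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one preference pair, given summed log-probs per response."""
    # Implicit reward margin: how much more the policy (relative to the
    # reference) prefers the chosen response over the rejected one.
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return neg_log_sigmoid(beta * margin)
```

When the policy equals the reference the margin is zero and the loss is log 2; it decreases as the policy shifts probability toward the chosen response.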

Group Relative Policy Optimization (GRPO)

  • Multi-rollout generation
  • Group normalization
  • 2-GRPO minimal variant
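Group normalization in GRPO scores each rollout relative to the other rollouts sampled for the same prompt. A minimal sketch, assuming one scalar reward per rollout and the population standard deviation (some implementations use the sample standard deviation instead); `group_advantages` is hypothetical.

```python
import statistics
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize each rollout's reward against its group's mean and spread."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # eps avoids division by zero when every rollout gets the same reward.
    return [(r - mean) / (std + eps) for r in rewards]
```

With only two rollouts whose rewards differ, the normalized advantages are about +1 and -1 regardless of the size of the reward gap; identical rewards give zero advantage and therefore no learning signal for that prompt.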

Evaluation Metrics

Alignment Cost

  • Samples needed to reach a threshold performance
  • Convergence steps
  • Training time
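Samples-to-threshold can be read off a data-efficiency curve. A minimal sketch, assuming accuracy was measured at a discrete set of training-set sizes and taking the first size that reaches the threshold, with no interpolation between sizes; `samples_to_threshold` is a hypothetical helper.

```python
from typing import Optional, Sequence


def samples_to_threshold(
    sample_counts: Sequence[int],
    accuracies: Sequence[float],
    threshold: float,
) -> Optional[int]:
    """Smallest training-set size whose accuracy reaches `threshold`.

    Returns None if no measured size reaches the threshold.
    """
    for n, acc in sorted(zip(sample_counts, accuracies)):
        if acc >= threshold:
            return n
    return None
```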

Stability

  • Loss variance
  • Gradient norm variance
  • Preference collapse frequency
  • Noise robustness
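The variance-based stability metrics can be computed from logged training curves. A minimal sketch, assuming metrics are logged once per step and variance is taken over a trailing window (the window of 50 steps is an arbitrary assumption). The criterion for declaring a run collapsed is not specified here, so `collapse_frequency` takes caller-supplied flags; both helpers are hypothetical.

```python
import statistics
from typing import Sequence


def windowed_variance(values: Sequence[float], window: int = 50) -> float:
    """Population variance of the last `window` logged values.

    Applies equally to per-step losses and per-step gradient norms.
    """
    tail = list(values)[-window:]
    return statistics.pvariance(tail)


def collapse_frequency(collapsed: Sequence[bool]) -> float:
    """Fraction of runs flagged as collapsed (criterion chosen by the caller)."""
    return sum(collapsed) / len(collapsed)
```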

Reasoning Preservation

  • Test-time compute scaling
  • Improvement per additional inference step
  • Reasoning depth correlation

Configuration

The project uses Hydra for configuration management. Key config directories:

  • configs/model/: Model architectures
  • configs/alignment/: Alignment methods
  • configs/data/: Dataset settings
  • configs/training/: Training hyperparameters

Example Configuration Override

python scripts/train_alignment.py \
    model.max_iterations=8 \
    alignment.beta=0.2 \
    training.learning_rate=1e-5
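Each dotted override addresses a key in the nested config tree. The pure-Python sketch below only illustrates that mapping and is not Hydra's implementation: real Hydra also validates values against the config, rejects keys that do not already exist unless they are prefixed with `+`, and supports sweeps. `apply_overrides` is hypothetical.

```python
import ast
from typing import Any, Dict, List


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Write `a.b.c=value` strings into a nested dict (illustration only)."""
    for item in overrides:
        key, raw = item.split("=", 1)
        try:
            value = ast.literal_eval(raw)  # "8" -> 8, "1e-5" -> 1e-05
        except (ValueError, SyntaxError):
            value = raw  # non-literals such as "gsm8k" stay strings
        node = cfg
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})  # walk or create intermediate groups
        node[leaf] = value
    return cfg
```

Applied to the overrides above, `model.max_iterations=8` sets `cfg["model"]["max_iterations"]` to the integer 8 and `training.learning_rate=1e-5` becomes the float 1e-05.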

Analysis and Visualization

After running experiments, analyze results:

python scripts/analyze_results.py \
    --output_dir outputs \
    --save_dir results

This generates:

  • Data efficiency plots
  • Stability heatmaps
  • Test-time scaling curves
  • LaTeX tables for the paper
  • Summary report
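As an example of the table-export step, a metric-by-model dictionary can be rendered as a LaTeX `tabular`. A minimal sketch, assuming values are already formatted as strings; `latex_table` is hypothetical and does not reproduce whatever `analyze_results.py` actually emits. Characters such as `%` are escaped because LaTeX would otherwise treat the rest of the line as a comment.

```python
from typing import Dict, List


def _escape(text: str) -> str:
    """Escape characters that LaTeX treats specially inside table cells."""
    for char in ("&", "%", "_", "#"):
        text = text.replace(char, "\\" + char)
    return text


def latex_table(rows: Dict[str, List[str]], columns: List[str]) -> str:
    """Render a metric-by-model table as a LaTeX tabular environment."""
    lines = [
        "\\begin{tabular}{l" + "r" * len(columns) + "}",
        "\\hline",
        " & ".join(["Metric"] + [_escape(c) for c in columns]) + " \\\\",
        "\\hline",
    ]
    for metric, values in rows.items():
        cells = [_escape(metric)] + [_escape(v) for v in values]
        lines.append(" & ".join(cells) + " \\\\")
    lines += ["\\hline", "\\end{tabular}"]
    return "\n".join(lines)
```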

Experimental Protocol

1. Data Efficiency Test

Train with varying data sizes (100, 500, 1k, 5k, 10k, 50k samples) and measure accuracy curves.
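For the data-efficiency sweep, drawing nested subsets (each smaller subset is a prefix of one fixed shuffle) keeps the comparison about dataset size rather than about which examples happened to be sampled. A minimal sketch, assuming examples are addressed by index and a fixed seed; `nested_subsets` is hypothetical.

```python
import random
from typing import Dict, List, Sequence

SIZES = [100, 500, 1_000, 5_000, 10_000, 50_000]


def nested_subsets(
    num_examples: int, sizes: Sequence[int], seed: int = 0
) -> Dict[int, List[int]]:
    """Index subsets where every smaller subset is contained in the larger ones."""
    order = list(range(num_examples))
    random.Random(seed).shuffle(order)  # one shuffle shared by all sizes
    # Sizes larger than the dataset are skipped rather than padded.
    return {n: order[:n] for n in sizes if n <= num_examples}
```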

2. Preference Noise Robustness

Inject controlled noise (10%, 20%, 30%, 40%) into preference labels and measure degradation.
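Preference-label noise can be injected by swapping the chosen and rejected responses on a fixed fraction of pairs. A minimal sketch, assuming each pair is a dict with `prompt`, `chosen`, and `rejected` keys (the layout TRL's DPO trainer expects; whether this repository uses it is an assumption) and flipping an exact count of pairs rather than sampling each pair independently, so that a 20% setting corrupts exactly 20% of the data. `inject_preference_noise` is hypothetical.

```python
import random
from typing import Dict, List


def inject_preference_noise(
    pairs: List[Dict[str, str]], rate: float, seed: int = 0
) -> List[Dict[str, str]]:
    """Swap chosen/rejected on round(rate * len(pairs)) randomly picked pairs."""
    rng = random.Random(seed)
    flip = set(rng.sample(range(len(pairs)), k=round(rate * len(pairs))))
    noisy = []
    for i, pair in enumerate(pairs):
        if i in flip:
            # Build a new dict so the clean dataset is left untouched.
            pair = {**pair, "chosen": pair["rejected"], "rejected": pair["chosen"]}
        noisy.append(pair)
    return noisy
```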

3. Test-Time Scaling

Evaluate at different inference depths (1, 2, 4, 8, 16 steps) and check whether performance increases monotonically with depth.
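The monotonicity check reduces to comparing performance at consecutive depths. A minimal sketch with an optional tolerance for evaluation noise (the tolerance is an assumption, not part of the protocol); `is_monotonic` is hypothetical.

```python
from typing import Sequence


def is_monotonic(accuracies: Sequence[float], tolerance: float = 0.0) -> bool:
    """True if accuracy never drops by more than `tolerance` as depth increases.

    `accuracies` must be ordered by increasing depth, e.g. depths 1, 2, 4, 8, 16.
    """
    return all(b >= a - tolerance for a, b in zip(accuracies, accuracies[1:]))
```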

Key Results

Metric              Vanilla   Retrofitted   Native Recursive
Samples to 80%      10,000    6,000         4,000
Loss Variance       0.082     0.054         0.031
Collapse Rate       0.15      0.08          0.02
Noise Robustness    0.65      0.78          0.91
Scaling Preserved   No        Partial       Yes
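As a worked check on the "Samples to 80%" row, the relative reduction against the vanilla baseline is 1 - n / 10,000. `relative_reduction` is a hypothetical helper; the numbers are taken from the table.

```python
def relative_reduction(baseline: int, n: int) -> float:
    """Fraction of baseline samples saved by a model that needs `n` samples."""
    return 1 - n / baseline


# Values from the "Samples to 80%" row of the Key Results table.
vanilla = 10_000
for name, n in {"Retrofitted": 6_000, "Native Recursive": 4_000}.items():
    print(f"{name}: {relative_reduction(vanilla, n):.0%} fewer samples than Vanilla")
```

This prints a 40% reduction for the retrofitted model and a 60% reduction for the native recursive model.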

Citation

If you use this code in your research, please cite:

@inproceedings{reasoning-alignment-2025,
  title={Does Reasoning-Aware Pretraining Reduce Alignment Cost?},
  author={Anonymous},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2025}
}

Requirements

  • PyTorch 2.0+
  • Transformers 4.45+
  • TRL 0.8+
  • CUDA 12.1+
  • 4x NVIDIA H100 GPUs (minimum)
  • 500 GB RAM
  • Apptainer/Singularity

Troubleshooting

Out of Memory

  • Reduce batch size in configs
  • Enable gradient checkpointing
  • Use LoRA with smaller rank

Slow Convergence

  • Increase learning rate
  • Adjust warmup ratio
  • Check data quality

Preference Collapse

  • Increase beta (KL penalty strength) to keep the policy closer to the reference model
  • Use smaller group size
  • Add regularization

Contact

For questions or issues, please open a GitHub issue or contact the authors.

License

This project is licensed under the MIT License; see the LICENSE file for details.
