Dev Raj (23121) · Rishitha Pamu (23271) · Aaradhya Pathak (23005)
- Overview
- Key Contributions
- Task Definition
- Dataset
- Stage 1 — Supervised Fine-Tuning (SFT)
- Stage 2 — Reinforcement Learning via GRPO
- Stage 3 — Unlabeled Post-Training (MM-UPT)
- Mechanistic Interpretability
- Results
- Conclusion
- References
This repository replicates and extends the training pipeline from Cheng et al. (2025) on Qwen-2.5 model families (0.5B and 1.5B parameters). The project validates that:
- Supervised Fine-Tuning (SFT) on atomic skills (parametric recall + contextual reading), followed by
- Group Relative Policy Optimization (GRPO) on composite tasks,
enables models to generalize to unseen relational compositions — a capability called Complementary Reasoning (COMP).
Beyond behavioral evaluation, this work introduces a circuit-level mechanistic interpretability study that causally identifies a sparse, two-stage "Synthesis Circuit" localized strictly to layers 20–28 of the Qwen-2.5-1.5B model. The circuit explains why atomic skill decomposition is a necessary prerequisite for RL-driven generalization.
| # | Contribution | Description |
|---|---|---|
| 1 | Pipeline Replication Engineering | Full replication of the SFT→RL pipeline on Qwen-2.5 (0.5B, 1.5B), with documented failure modes and fixes |
| 2 | Novel CoT Entity Reward | A custom Chain-of-Thought reward for GRPO that provides denser gradient signal than sparse binary matching |
| 3 | MM-UPT | Self-supervised majority-voting post-training for zero-shot generalization without human-labeled targets |
| 4 | Synthesis Circuit | First causal, component-level mechanistic account of why atomic skill decomposition enables RL generalization |
| Type | Description | Example |
|---|---|---|
| MEM (Parametric) | Answer requires recalling facts from model weights; no context provided | "What book did the classmate of Jennifer Lang write?" |
| CTX (Contextual) | Answer is fully derivable from the provided biography paragraph | "What did Marie Pope major in?" (with paragraph) |
| COMP (Complementary) | Requires both sources simultaneously | "Who is the spouse of Charles Keith's advisor?" (advisor from CTX → spouse from MEM) |
| Level | Description |
|---|---|
| IID | Exact relational path appeared in training data |
| Composition | Individual relations were seen, but this combination is novel |
| Zero-shot | Entire relation type is unseen; requires out-of-distribution generalization |
Synthetic human biographies with 39 strictly defined relations, including:
- 8 symmetric relations (e.g., spouse, sibling)
- 8 pairs of inverse relations (e.g., husband/wife)
Generated using the Faker library to avoid conflicts with model pre-training data. Question-answer pairs are constructed by traversing relational chains of varying lengths from a synthetic knowledge graph.
Biography entries with empty question/answer fields are used for parametric memory injection during SFT and must be filtered before RL training.
| Group | Training | I.I.D. | Composition | Zero-shot |
|---|---|---|---|---|
| Parametric | 88,031 | 1,921 | 1,141 | 782 |
| Contextual | 2,651 | 1,910 | 1,320 | 453 |
| Complementary | 180,919 | 2,135 | 1,415 | 918 |
Train on MEM + CTX to teach two atomic skills independently:
- Parametric fact recall (from biography injection entries)
- Contextual reading comprehension (from QA entries with provided context)
An SFT(COMP) baseline is also trained for comparison.
Completion-only cross-entropy loss, masking all system and user prompt tokens:
where
Every token in the reasoning chain contributes to the gradient — not just the final answer token.
| Parameter | Value |
|---|---|
| Model | Qwen/Qwen2.5-0.5B |
| Effective batch size | 128 (per device=4, grad accum=32) |
| Learning rate | 3e-4 |
| LR scheduler | Cosine with min LR (3e-5) |
| Epochs | 16 (no early stopping) |
| Weight decay | 0.01 |
| Precision | bfloat16 |
| Optimizer | AdamW (fused) |
| Max sequence length | 2048 |
Checkpointing: Loss-threshold snapshots saved at train loss values of 0.30, 0.15, 0.05, and 0.01 for post-hoc selection.
Starting from the converged SFT(MEM+CTX) checkpoint, Group Relative Policy Optimization (GRPO) trains on COMP to synthesize atomic skills into complementary reasoning.
Core Claim: RL cannot create skills from scratch — it can only combine existing ones. Synthesis is only effective when atomic skills are already reliably present.
Unlike PPO, GRPO eliminates a separate value network by estimating baselines from a group of sampled outputs. For a prompt
where:
-
$\rho_i = \dfrac{\pi_\theta(o_i|q)}{\pi_{\theta_\text{old}}(o_i|q)}$ is the probability ratio -
$\hat{A}_i$ is the group-relative advantage -
$\beta$ is the KL penalty coefficient
A sparse reward providing no partial credit for intermediate reasoning:
where "So, the answer is:" in the generated text, and normalization lowercases and strips punctuation.
A hierarchically composed reward tailored for GRPO on COMP tasks. It explicitly rewards correct intermediate entity retrieval while strictly punishing hallucinated steps.
The three components are:
1. Annealed Answer Gate
Scales from 0 to 1 over training step
2. Count Alignment
A Gaussian kernel that symmetrically penalizes both over-generation and under-generation of entities. The hyperparameter
3. Entity Recall
Measures the proportion of ground-truth entities
| Parameter | Value |
|---|---|
| Base model |
SFT(MEM+CTX) at loss 0.05 |
| Rollouts per prompt |
8 |
| Per-device train batch size | 32 |
| Max completion length | 128 |
| Learning rate | 1e-5 |
| LR scheduler | Cosine with min LR (5e-7) |
| KL coefficient |
0.001 |
| Epochs | 2 |
Standard GRPO requires ground-truth labels to compute rewards, limiting reinforcement to annotated datasets. UPT (Majority-vote based Unlabeled Post-Training) allows the model to act as its own teacher, enabling continued training on any unlabeled complementary data — including distributions mimicking the test set.
In standard GRPO, the binary reward requires a known ground truth
Under UPT,
The updated reward function is:
By rewarding agreement with the group majority, the model reinforces its own internal consistency and reasoning paths, bootstrapping performance on entirely novel entity relations without human-annotated targets.
The Context-to-Memory (CTX→MEM) Handoff is the core operation under study: resolving a bridge entity from context and using it to retrieve a relational fact from parametric memory.
Example prompt:
"Tyler Carter works under Taylor Jackson. Who is the classmate of the boss of Tyler Carter?"
Answering requires:
- Reading from context: Taylor Jackson is Tyler Carter's boss (CTX hop)
- Retrieving from parametric memory: who Taylor Jackson's classmate is (MEM hop)
Hypothesis: SFT(MEM+CTX)→RL training produces a sparse, interpretable Synthesis Circuit implementing this handoff; SFT(COMP) produces diffuse, shortcut-based representations that fail on novel compositions.
The logit lens projects the residual stream at each transformer layer through the final LayerNorm and unembedding matrix, obtaining a layer-wise probability distribution over the vocabulary. The rank and probability of the correct answer token are tracked at the final sequence position across 500 discovery-set prompts.
| Model | Final Prob. | Peak Layer | Median Emergence | Never Top-1 |
|---|---|---|---|---|
SFT(MEM+CTX)→RL (Model A) |
0.0240 | 28 | 23.0 | 480/500 |
SFT(COMP) (Model B) |
0.0174 | 28 | 23.0 | 484/500 |
Model A achieves a 38% relative improvement in final-layer mean probability on the correct answer token.
- Layers 0–20: Both models follow nearly identical rank trajectories — domain-general linguistic processing, unaffected by training regime. Both reach a local minimum around rank 800 at layer 20.
- Layers 21–25: A transient rank spike (representational reorganization) occurs in both models.
- Layers 25–28: Model A's descent is sharper and more abrupt, consistent with a discrete late-layer circuit performing the CTX→MEM handoff.
-
Probability divergence
$\text{Prob}(A) - \text{Prob}(B)$ : Near-zero and alternating for layers 0–19; consistently positive and growing from layer 20 onward, peaking at layers 27 and 28. This establishes layers 20–28 as the primary region of mechanistic interest. - Bridge position analysis: Uniformly flat signal at the bridge entity's token position for both models across all layers. This rules out a two-stage pre-assembly hypothesis. The CTX→MEM synthesis occurs directly at the final token position, with the bridge entity serving as an attended-to key rather than an intermediate storage location.
Guided by the logit lens, activation patching is performed across layers 20–28. For each attention head
A component is classified as critical if its recovery fraction
Prompt partitioning (500 total prompts):
- Gap subset (309 prompts): Model B's rank is worse than Model A's
- Parity subset (191 prompts): Models perform similarly
| Layer | MLP Recovery Fraction |
|---|---|
| L20 | 0.218 |
| L21 | ~0 |
| L22 | 1.000 |
| L23 | 1.000 |
| L24 | 1.000 |
| L25 | ~0.05 |
| L26 | 0.868 |
| L27 | 0.744 |
| L28 | 0 (silent) |
The monotonically increasing pattern from L20 through L26 indicates a cascade of MLP transformations, each building on the previous. The final answer commitment is distributed across this range rather than concentrated in a single layer.
Critical heads are sparse and scattered; no single head dominates:
| Head | Recovery Fraction |
|---|---|
| L22 H10 | 0.34 |
| L25 H4 | 0.31 |
| L25 H3 | 0.26 |
| L21 H6 | 0.24 |
| L23 H0 | 0.22 |
Individual head recovery fractions are substantially lower than MLP recovery fractions at the same layers.
| Stage | Layer Range | Mechanism |
|---|---|---|
| Routing | L20–L22 | Critical attention heads identify the bridge entity in context and attend to relevant relational structure, preparing the residual stream for retrieval |
| Retrieval | L22–L27 | Cascade of MLP transformations progressively surfaces the correct parametric association — the entity related to the bridge entity by the target relation |
The crossover at layer 22 reveals a clean functional separation: attention-driven routing gives way to MLP-driven parametric memory retrieval.
| Component | Layer Range | Max Recovery |
|---|---|---|
| MLP (parametric retrieval) | 22–27 | 1.000 (L22, L23, L24) |
| MLP (early preparation) | 20 | 0.218 |
| Attention heads (routing) | 20–27 | 0.340 (L22 H10) |
Total: 25 critical components — 19 attention heads + 6 MLP layers, concentrated in layers 20–27 (29% of network depth). Layer 28 is entirely silent — the circuit completes before the final layer.
SFT(COMP) never learns to decompose the task atomically. Without independent exposure to contextual and parametric reasoning as separate skills:
- Attention heads in layers 20–22 fail to route to the genuine bridge entity, instead latching onto surface-level lexical patterns
- The downstream MLP cascade receives an imprecise residual stream and cannot reliably retrieve the correct association, even if individual MLP weights are otherwise capable
Evaluation on GSM8K mathematical word problems (without retraining) revealed complete catastrophic forgetting: Model A produced biographical relation-chain outputs in response to arithmetic problems. This confirms the Synthesis Circuit is a domain-specific mechanism shaped by the training distribution, not a general compositional reasoning module.
| Task Type | Model Size | Overall | I.I.D. | Composition | Zero-shot |
|---|---|---|---|---|---|
| Parametric (MEM) | 1.5B | 66.78% | 87.56% | 75.46% | 3.07% |
| Parametric (MEM) | 0.5B | 51.46% | 65.64% | 60.91% | 2.81% |
| Contextual (CTX) | 1.5B | 75.07% | 84.66% | 72.35% | 42.60% |
| Contextual (CTX) | 0.5B | 75.21% | 87.02% | 73.71% | 29.80% |
| Complementary (COMP) | 1.5B | 13.09% | 10.59% | 14.49% | 16.78% |
| Complementary (COMP) | 0.5B | 13.00% | 10.87% | 12.86% | 18.19% |
Exact Match accuracy. SFT(MEM+CTX) model evaluated on the COMP test set.
| Training Setup | Overall | I.I.D. | Composition | Zero-shot |
|---|---|---|---|---|
Base: SFT(MEM+CTX) at loss 0.05 |
13.09% | 10.59% | 14.49% | 16.78% |
| RL(COMP) CoT reward (12.8k samples) | 22.83% | 21.92% | 24.45% | 22.44% |
| RL(COMP) Binary reward (36k samples) | 38.70% | 42.90% | 38.52% | 29.19% |
| RL(COMP) CoT reward (36k samples) | 39.64% | 44.31% | 38.45% | 30.61% |
| RL(COMP) CoT (36k) + UPT (12.8k) | 40.04% | 45.34% | 38.30% | 30.39% |
SFT(COMP) (180k) — upper bound |
51.01% | 62.67% | 47.63% | 29.08% |
Key observations:
- The CoT entity reward consistently outperforms binary reward
- UPT provides a marginal additional gain, especially on I.I.D. accuracy
SFT(COMP)achieves higher I.I.D. accuracy but comparable zero-shot accuracy to the RL pipeline, consistent with shortcut memorization
This work provides the first causal, component-level explanation for why atomic skill decomposition is a necessary prerequisite for RL-driven generalization in complementary reasoning:
-
The first 20 layers are computationally identical between
SFT(MEM+CTX)→RLandSFT(COMP)models — behavioral divergence is entirely localized to layers 20–28. -
A sparse Synthesis Circuit of 19 attention heads and 6 MLP layers in layers 20–27 implements the CTX→MEM handoff in two stages:
- Routing (L20–L22): Attention heads identify the bridge entity from context
- Retrieval (L22–L27): MLP cascade performs parametric memory retrieval
-
MLP dominance (layers 22–24 each achieving full recovery independently) is consistent with MLPs as implicit key-value stores for relational knowledge.
-
Synthesis occurs entirely at the final output token position — the bridge entity serves as an attention key, not an intermediate storage location.
-
The circuit is domain-specific: the fine-tuned model undergoes catastrophic forgetting of pre-training capabilities (evidenced by GSM8K failure).
- S. Cheng et al., "From Atomic to Composite: Reinforcement Learning Enables Generalization in Complementary Reasoning," arXiv:2512.01970v2, 2025.
- Lai Wei et al., "First SFT, Second RL, Third UPT: Continual Improving Multi-Modal LLM Reasoning via Unsupervised Post-Training," NeurIPS 2025.
- Z. Shao et al., "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models," arXiv:2402.03300, 2024.
- K. Meng et al., "Locating and Editing Factual Associations in GPT," NeurIPS, 2022.
- Z. Allen-Zhu and Y. Li, "Physics of Language Models: Part 3.1, Knowledge Storage and Extraction," arXiv:2309.14316, 2023.
- M. Geva et al., "Transformer Feed-Forward Layers Are Key-Value Memories," arXiv:2012.14913, 2021.
- K. Cobbe et al., "Training Verifiers to Solve Math Word Problems," arXiv:2110.14168, 2021.