Skip to content

Commit de36f40

Browse files
committed
[Record] 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 — val_bpb 1.0889
3-seed mean: 1.0889 BPB (sliding window stride=64) Beats merged SOTA (1.1147) by 0.0258 BPB. Stacks 3-layer recurrence (3,4,5), WD=0.095, MLR=0.022, EMA decay=0.9965, early recurrence (step 2000), extended warmdown (72%) on PR #1334 architecture. Seeds: 42 (1.0885), 1337 (1.0894), 2024 (1.0888) All artifacts under 16MB. 8xH100 SXM, 590s training.
1 parent ebda3af commit de36f40

7 files changed

Lines changed: 2597 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
## Record: 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 (val_bpb: 1.0889)
2+
3+
**val_bpb: 1.0889** (sliding window stride=64, 3-seed mean, std 0.0005) | **~15.89 MB** | 8xH100 SXM, 590s
4+
5+
### 3-Seed Results (8×H100 80GB SXM)
6+
7+
| Seed | Pre-quant BPB | Sliding BPB (s64) | Artifact |
8+
|------|---------------|-------------------|----------|
9+
| 42 | 1.0950 | **1.0885** | 15,890,417 B |
10+
| 1337 | 1.0959 | **1.0894** ||
11+
| 2024 | 1.0954 | **1.0888** | 15,895,711 B |
12+
13+
**Mean: 1.0889 | Std: 0.0005** | All artifacts under 16,000,000 bytes
14+
15+
Current merged SOTA: **1.1147** (PR #1019). Delta: **−0.0258 BPB**.
16+
17+
### Key Changes
18+
19+
Four refinements stacked on top of PR #1334's depth recurrence architecture:
20+
21+
| Parameter | PR #1334 | This | Source |
22+
|-----------|----------|------|--------|
23+
| **Recurrence layers** | 4,5 (2-layer) | **3,4,5 (3-layer)** | PR #1331 |
24+
| **Weight decay** | 0.090 | **0.095** | PR #1331 |
25+
| **Matrix LR** | 0.020 | **0.022** | PR #1331 |
26+
| **EMA decay** | 0.997 | **0.9965** | PR #1421 (this author) |
27+
| **Recurrence start** | step 3000 | **step 2000** | This work |
28+
| **Warmdown fraction** | 0.667 | **0.72** | This work |
29+
30+
### Why This Combination Works
31+
32+
1. **3-layer recurrence (layers 3,4,5)**: Repeats 3 layers instead of 2, producing 14 virtual layers from 11 physical layers. More compute per forward pass without additional parameters.
33+
34+
2. **WD=0.095 + MLR=0.022**: Higher weight decay compresses weights more aggressively, improving GPTQ quantization. Higher matrix LR compensates for the regularization. Only 134K-186K values pruned (vs 290K+ at WD=0.090).
35+
36+
3. **EMA decay=0.9965**: Assigns slightly more weight to recent training steps, producing a final checkpoint that quantizes more cleanly under GPTQ int6.
37+
38+
4. **Early recurrence (step 2000)**: Activating depth recurrence 1000 steps earlier gives the model more training time with 14 virtual layers, improving final quality.
39+
40+
5. **Extended warmdown (72%)**: Longer learning rate decay allows weights to fully settle before GPTQ quantization, reducing the quant gap.
41+
42+
### Architecture (from PR #1334)
43+
44+
- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA)
45+
- **Depth recurrence**: layers 3,4,5 repeat (virtual 14 layers), activated at step 2000
46+
- Skip gates (learnable residual gating)
47+
- Parallel residuals from layer 7
48+
- QK-Gain 5.0
49+
- Shared Value Embedding (dim=128, layers 9,10)
50+
- Tied embeddings, logit softcap=30.0
51+
- SP4096 tokenizer (SentencePiece BPE)
52+
53+
### Training
54+
55+
- FlashAttention 3 (Hopper-optimized)
56+
- Muon optimizer (matrices): lr=0.022, momentum=0.99, WD=0.095, backend_steps=5
57+
- Adam (head): lr=0.008, fused=True
58+
- AdamW (embeddings): lr=0.6, WD=0.095, fused=True
59+
- AdamW (scalars): lr=0.02, WD=0.02, fused=True
60+
- Gradient clip: 0.3, Batch: 786,432 tokens/step, seq_len=2048
61+
- Warmdown: 72%, **EMA decay=0.9965**
62+
- Wallclock: 590s effective (10s reserved for GPTQ)
63+
64+
### Quantization
65+
66+
- GPTQ int6 with percdamp=0.05, 64 calibration batches
67+
- Selective pruning (~134K-186K lowest-error ±1 values)
68+
- Brotli compression
69+
70+
### Run Command
71+
72+
```bash
73+
SEED=42 RECUR_START_STEP=2000 WARMDOWN_FRAC=0.72 \
74+
DATA_PATH=./data/datasets/fineweb10B_sp4096/ \
75+
TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model \
76+
VOCAB_SIZE=4096 \
77+
torchrun --standalone --nproc_per_node=8 train_gpt.py
78+
```
79+
80+
### Reproducibility
81+
82+
All 3 seeds produce valid artifacts under 16MB with tight variance (std=0.0005 BPB). Training completes in ~590s. The env-var based configuration ensures exact reproducibility.
83+
84+
### Credits
85+
86+
- **Base architecture + depth recurrence**: PR #1334 by @aryanbhosale
87+
- **3-layer recurrence + WD/LR tuning**: PR #1331
88+
- **EMA decay tuning (0.9965)**: PR #1421 by @X-Abhishek-X (this author)
89+
- **Early recurrence + extended warmdown**: This work
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"author": "Abhishek Leji",
3+
"github_id": "X-Abhishek-X",
4+
"name": "Record: 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 + Early Recurrence",
5+
"blurb": "3-layer depth recurrence (layers 3,4,5) with EMA decay 0.9965, WD=0.095, MLR=0.022, early recurrence activation (step 2000), and extended warmdown (72%). Built on PR #1334 architecture with innovations from PR #1331.",
6+
"date": "2026-04-07T00:00:00Z",
7+
"val_loss": 2.50548889,
8+
"val_bpb": 1.08886755,
9+
"bytes_total": 15895711
10+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
W0407 16:22:23.785000 48806 torch/distributed/run.py:803]
2+
W0407 16:22:23.785000 48806 torch/distributed/run.py:803] *****************************************
3+
W0407 16:22:23.785000 48806 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4+
W0407 16:22:23.785000 48806 torch/distributed/run.py:803] *****************************************
5+
Hyperparameters:
6+
adam_eps: 1e-08
7+
adam_wd: 0.02
8+
beta1: 0.9
9+
beta2: 0.95
10+
compressor: brotli
11+
data_dir: ./data/
12+
datasets_dir: ./data/datasets/fineweb10B_sp4096
13+
distributed: True
14+
ema_decay: 0.9965
15+
embed_lr: 0.6
16+
embed_wd: 0.095
17+
embedding_dim: 512
18+
eval_seq_len: 2048
19+
eval_stride: 64
20+
gptq_calibration_batches: 64
21+
gptq_enabled: True
22+
gptq_reserve_seconds: 10.0
23+
grad_accum_steps: 1
24+
grad_clip_norm: 0.3
25+
head_lr: 0.008
26+
is_main_process: True
27+
iterations: 20000
28+
ln_scale: True
29+
local_rank: 0
30+
logfile: logs/bbc00e44-7393-4d92-a67c-239184601d85.txt
31+
logit_softcap: 30.0
32+
matrix_lr: 0.022
33+
max_wallclock_seconds: 600.0
34+
min_lr: 0.0
35+
mlp_mult: 4.0
36+
model_dim: 512
37+
model_path: final_model.pt
38+
muon_backend_steps: 5
39+
muon_beta2: 0.95
40+
muon_momentum: 0.99
41+
muon_momentum_warmup_start: 0.92
42+
muon_momentum_warmup_steps: 1500
43+
muon_wd: 0.095
44+
num_heads: 8
45+
num_kv_heads: 4
46+
num_layers: 11
47+
parallel_start_layer: 7
48+
qk_gain_init: 5.0
49+
quantized_model_path: final_model.int6.ptz
50+
rank: 0
51+
recur_layers: 3,4,5
52+
recur_start_step: 2000
53+
rope_base: 10000.0
54+
rope_dims: 16
55+
rope_train_seq_len: 2048
56+
run_id: bbc00e44-7393-4d92-a67c-239184601d85
57+
scalar_lr: 0.02
58+
seed: 42
59+
skip_gates_enabled: True
60+
sliding_window_enabled: True
61+
tie_embeddings: True
62+
tied_embed_init_std: 0.005
63+
tied_embed_lr: 0.03
64+
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
65+
train_batch_tokens: 786432
66+
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
67+
train_log_every: 500
68+
train_seq_len: 2048
69+
ttt_batch_seqs: 32
70+
ttt_chunk_tokens: 32768
71+
ttt_enabled: False
72+
ttt_epochs: 3
73+
ttt_freeze_blocks: 0
74+
ttt_grad_clip: 1.0
75+
ttt_lr: 0.002
76+
ttt_momentum: 0.9
77+
val_batch_tokens: 524288
78+
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
79+
val_loss_every: 4000
80+
ve_dim: 128
81+
ve_enabled: True
82+
ve_layers: 9,10
83+
vocab_size: 4096
84+
warmdown_frac: 0.72
85+
warmup_steps: 20
86+
world_size: 8
87+
xsa_last_n: 11
88+
train_shards: 143
89+
val_tokens: 45508608
90+
model_params:34401372
91+
gptq:reserving 10s, effective=590000ms
92+
warmup_step: 1/20
93+
warmup_step: 2/20
94+
warmup_step: 3/20
95+
warmup_step: 4/20
96+
warmup_step: 5/20
97+
warmup_step: 6/20
98+
warmup_step: 10/20
99+
warmup_step: 20/20
100+
0/20000 val_loss: 8.3187 val_bpb: 3.6152
101+
1/20000 train_loss: 8.3178 train_time: 0.0m tok/s: 8488752
102+
2/20000 train_loss: 12.0820 train_time: 0.0m tok/s: 8383316
103+
3/20000 train_loss: 10.6643 train_time: 0.0m tok/s: 8277075
104+
4/20000 train_loss: 8.9470 train_time: 0.0m tok/s: 8230819
105+
5/20000 train_loss: 7.7086 train_time: 0.0m tok/s: 8197168
106+
500/20000 train_loss: 2.9983 train_time: 0.8m tok/s: 7974261
107+
1000/20000 train_loss: 2.9965 train_time: 1.6m tok/s: 7956910
108+
1500/20000 train_loss: 2.9090 train_time: 2.5m tok/s: 7950269
109+
2000/20000 train_loss: 2.7506 train_time: 3.3m tok/s: 7947268
110+
recurrence:activated at step 2000, virtual_layers=[0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 9, 10]
111+
2500/20000 train_loss: 2.7363 train_time: 4.5m tok/s: 7250427
112+
3000/20000 train_loss: 2.6969 train_time: 5.5m tok/s: 7096412
113+
3500/20000 train_loss: 2.6178 train_time: 6.6m tok/s: 6990908
114+
4000/20000 train_loss: 2.6167 train_time: 7.6m tok/s: 6912506
115+
4000/20000 val_loss: 2.6187 val_bpb: 1.1381
116+
4500/20000 train_loss: 2.5537 train_time: 8.6m tok/s: 6854112
117+
5000/20000 train_loss: 2.5098 train_time: 9.6m tok/s: 6808010
118+
5102/20000 val_loss: 2.5227 val_bpb: 1.0963
119+
stopping_early: wallclock_cap train_time: 590087ms step: 5102/20000
120+
peak memory allocated: 32292 MiB reserved: 32332 MiB
121+
ema:applying EMA weights
122+
pre-quantization post-ema val_loss:2.51973533 val_bpb:1.09504795 eval_time:2150ms
123+
Serialized model: 132406149 bytes
124+
Code size: 83569 bytes
125+
GPTQ:collecting Hessians from calibration data...
126+
GPTQ:collected 66 Hessians in 10.4s
127+
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
128+
selective_prune: unpruned=15.89MB target=16.0MB
129+
selective_prune: already fits, no pruning needed
130+
Serialized model int6+brotli: 15806848 bytes
131+
Total submission size int6+brotli: 15890417 bytes
132+
final_int6_roundtrip val_loss:2.54747195 val_bpb:1.10710196 eval_time:8613ms
133+
final_int6_sliding_window val_loss:2.50459726 val_bpb:1.08846912 eval_time:81043ms

0 commit comments

Comments
 (0)