Started with Batch Size = 2, Seq Length = 1024.
Running on an A4000 (I know it is pretty old, but that's all I have access to :( )
Training for 1 epoch
- (1) Without any optimizations - Time to train: 48.72043180465698 seconds
- (2) float32_matmul_precision set to "high", allowing TF32 matmuls - Time to train: 29.83514142036438 seconds
- (3) torch.autocast to bfloat16, on top of (2) - Time to train: 24.927905797958374 seconds
- (4) torch.compile, on top of (3) - Time to train: 32.85438013076782 seconds (combined sketch after this list)
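A minimal sketch of how (2)-(4) combine in the training loop; the model, shapes, and optimizer here are placeholders, not the actual training code:

```python
import torch
import torch.nn as nn

torch.set_float32_matmul_precision("high")   # (2) allow TF32 matmuls

# Stand-in for the real GPT model, just to show where the knobs go.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
model = torch.compile(model)                 # (4) compile forward + backward

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(2, 1024, device="cuda")      # batch size 2, as in the setup above

for _ in range(10):
    # (3) run the forward pass in bfloat16 via autocast
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```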
For 5 epochs,
- (1) With torch.compile - Time to train: 81.19576597213745 seconds
- (2) With torch.autocast to bfloat16 and float32 matmul precision set to high - Time to train: 125.25839257240295 seconds
- (3) With torch.compile and Flash Attention - Time to train: 69.03489518165588 seconds
- (4) Vocab size changed to a nicer (more kernel-friendly) number, on top of (3) - Time to train: 65.71632099151611 seconds (sketch after this list)
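For reference, a hedged sketch of the two new pieces: Flash Attention through PyTorch's fused scaled_dot_product_attention, and rounding the vocab size up to a kernel-friendly multiple (the 50257 -> 50304 figure is the usual GPT-2 example, not necessarily this tokenizer):

```python
import torch
import torch.nn.functional as F

# Flash Attention: let PyTorch dispatch to the fused kernel instead of an
# explicit softmax(q @ k^T / sqrt(d)) @ v.
B, H, T, D = 2, 12, 1024, 64   # batch, heads, seq length, head dim (illustrative)
q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# "Nice" vocab size: round up to a multiple of 64 so the LM-head matmul hits
# efficient tile sizes (e.g. 50257 -> 50304 for the GPT-2 tokenizer).
def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    return ((vocab_size + multiple - 1) // multiple) * multiple
```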
After adding LR scheduling, weight decay, and gradient norm clipping, training for 5 epochs (optimizer sketch after this list):
- Without Fused AdamW - Time to train: 62.56659150123596 seconds
- With Fused AdamW - Time to train: 57.591166973114014 seconds, Avg Loss - 5.64244685606523
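A sketch of how these fit together; the model and step count are placeholders, and the fused flag simply moves the AdamW parameter update into a single CUDA kernel:

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()   # stand-in for the real model
steps = 100                            # placeholder step count

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              weight_decay=0.1, fused=True)   # fused CUDA update
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)

for _ in range(steps):
    x = torch.randn(2, 1024, device="cuda")
    loss = model(x).pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # grad norm clipping
    optimizer.step()
    scheduler.step()   # step the LR schedule once per optimizer step
```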
Added QK-norm. Training for 5 Epochs, Time - 69.68068528175354 seconds. Avg Loss - 5.588010857321999
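A minimal sketch of what QK-norm means here: normalize queries and keys per head before the dot product so the attention logits stay at a controlled scale. This is the simple L2 variant; some implementations use RMSNorm with learnable gains instead:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    """Causal self-attention with QK-norm: q and k are normalized per head
    before the dot product (simple L2 variant, no learnable gains)."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)   # QK-norm
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))
```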
Then added RoPE embeddings (YaRN is available, but for now the sequence length is fixed). Training for 5 Epochs, Time - 70.93886852264404 seconds. Avg Loss - 4.063881944887566
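A fixed-length RoPE sketch (no YaRN scaling), applied to q and k before attention; the shapes and the rotate-half convention are assumptions about the implementation:

```python
import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotary position embeddings for x of shape (B, H, T, D), D even.
    Fixed sequence length only; YaRN would rescale the frequencies."""
    *_, T, D = x.shape
    half = D // 2
    freqs = base ** (-torch.arange(half, device=x.device) / half)          # (half,)
    angles = torch.arange(T, device=x.device)[:, None] * freqs[None, :]    # (T, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # rotate each (x1, x2) pair by its position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```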
Added ReLU^2 activation instead of GELU. Training for 5 Epochs, Time - 68.68175554275513 seconds. Avg Loss - 3.8532666105212585
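The change itself is tiny; a sketch of the MLP with squared ReLU in place of GELU (layer names are placeholders):

```python
import torch
import torch.nn as nn

class ReluSquaredMLP(nn.Module):
    """Transformer MLP block using ReLU^2 (squared ReLU) instead of GELU."""
    def __init__(self, dim: int, hidden_mult: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_mult * dim, bias=False)
        self.fc2 = nn.Linear(hidden_mult * dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)).square())   # relu(x)^2
```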
Switched from LayerNorm to RMSNorm. 5 Epochs Train Time - 67.5941743850708 seconds. Avg Loss - 3.8758670272249165
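For reference, a minimal RMSNorm: it drops LayerNorm's mean subtraction and bias and only rescales by the root-mean-square, which is what makes it slightly cheaper:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # scale by 1 / RMS(x); no mean subtraction, no bias
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight
```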
Zero initialization of the attention module's projection layer and of the second MLP layer. 5 Epochs Train Time - 59.57936501502991 seconds. Avg Loss - 3.6762325648105505
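A sketch of the zero init; the attribute names (attn.proj, mlp.fc2) are placeholders for whatever the actual modules are called. The effect is that every residual branch starts out contributing nothing, so each block begins as an identity mapping:

```python
import torch.nn as nn

def zero_init_residual_branches(block: nn.Module) -> None:
    """Zero the attention output projection and the MLP's second linear layer.
    `block.attn.proj` and `block.mlp.fc2` are hypothetical attribute names."""
    for layer in (block.attn.proj, block.mlp.fc2):
        nn.init.zeros_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
```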
Untied the embedding weights and the LM head: 5 Epochs Train Time - 72.79475784301758 seconds. Avg Loss - 3.337182854161118
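For clarity, the difference between tied and untied weights (sizes are placeholders):

```python
import torch.nn as nn

vocab_size, dim = 50304, 768                       # placeholder sizes
wte = nn.Embedding(vocab_size, dim)                # token embedding
lm_head = nn.Linear(dim, vocab_size, bias=False)   # separate output projection
# Tied weights would instead share one matrix:
#   lm_head.weight = wte.weight
```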
Added skip connections between every block and the first block, averaging their outputs. Time - 65.08958983421326 seconds. Avg Loss - 3.5915792017272024. Not much difference, even when I make the weighting between them adaptive.
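One way to read this experiment, as a hedged sketch (the block structure and the sigmoid gating for the adaptive variant are my assumptions):

```python
import torch
import torch.nn as nn

class FirstBlockSkip(nn.Module):
    """Runs the blocks sequentially, mixing each later block's output with the
    first block's output: a plain average, or a learnable per-block weight
    when `adaptive=True`."""
    def __init__(self, blocks, adaptive: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.alphas = nn.Parameter(torch.zeros(len(blocks))) if adaptive else None

    def forward(self, x):
        x = first = self.blocks[0](x)
        for i, block in enumerate(self.blocks[1:], start=1):
            out = block(x)
            if self.alphas is None:
                x = 0.5 * (out + first)                # plain average with block 1
            else:
                a = torch.sigmoid(self.alphas[i])      # adaptive weighting
                x = a * out + (1.0 - a) * first
        return x
```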
Changed to a new dataset, OpenWebText-10K, keeping everything so far except the skip connections from the first block. 1 Epoch, Time - 285.06610012054443 seconds. Avg Loss - 5.971239207874645
With the Muon optimizer and linear weights cast to bfloat16 (rough sketch of the core update below):
- Epoch 0, Step 900: loss = 5.505822658538818, norm = 0.7672624588012695, time = 246.73382449150085 seconds
- Epoch 0, Step 1200: loss = 5.30943489074707, norm = 0.6942930221557617, time = 326.76921463012695 seconds
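A rough sketch of Muon's core update on a 2D weight: momentum, then an approximate orthogonalization of the update via a quintic Newton-Schulz iteration run in bfloat16. The reference implementation also handles Nesterov momentum, per-shape scaling, and distributed training; the coefficients and learning rate below are the commonly used ones, not necessarily what this run used:

```python
import torch

@torch.no_grad()
def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7):
    """Approximately map a 2D matrix to the nearest orthogonal one (Muon's core)."""
    a, b, c = 3.4445, -4.7750, 2.0315          # quintic iteration coefficients
    X = G.to(torch.bfloat16)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    X = X / (X.norm() + eps)                   # rescale so the iteration converges
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return (X.T if transposed else X).to(G.dtype)

@torch.no_grad()
def muon_step(param, grad, momentum_buf, lr=0.02, momentum=0.95):
    """One simplified Muon step for a 2D weight matrix."""
    momentum_buf.mul_(momentum).add_(grad)     # heavy-ball momentum
    update = newton_schulz_orthogonalize(momentum_buf)
    param.add_(update, alpha=-lr)
```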
Added UNet-style skip connections (sketch below):
- Epoch 0, Step 900: loss = 5.506043910980225, norm = 0.7674642205238342, time = 242.98071599006653 seconds
- Epoch 0, Step 1200: loss = 5.309330940246582, norm = 0.6946200728416443, time = 323.3906297683716 seconds
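A sketch of what I mean by UNet-style skips: the first half of the blocks push their activations onto a stack, and each block in the second half adds back the matching early activation before running; the learnable skip weights are an assumption about the exact form:

```python
import torch
import torch.nn as nn

class UNetSkipStack(nn.Module):
    """Transformer trunk with UNet-style long-range skips between the first
    and second halves of the block stack."""
    def __init__(self, blocks):
        super().__init__()
        assert len(blocks) % 2 == 0, "needs an even number of blocks"
        self.blocks = nn.ModuleList(blocks)
        self.half = len(blocks) // 2
        self.skip_weights = nn.Parameter(torch.ones(self.half))

    def forward(self, x):
        skips = []
        for block in self.blocks[: self.half]:                # "encoder" half
            x = block(x)
            skips.append(x)
        for i, block in enumerate(self.blocks[self.half:]):   # "decoder" half
            x = x + self.skip_weights[i] * skips.pop()        # long-range skip
            x = block(x)
        return x
```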