KarthikSundar2002/GPT2-from-scratch

Analysis of Optimizations

Started with Batch Size = 2, Seq Length = 1024.

Running on an A4000 (I know it is pretty old, but that's all I have access to :( )

Training for 1 epoch

  1. Without any optimizations - Time to train: 48.72 seconds
  2. float32_matmul_precision set to "high", allowing TF32 matmuls - Time to train: 29.84 seconds
  3. torch.autocast to bfloat16 + (2) - Time to train: 24.93 seconds
  4. torch.compile + (3) - Time to train: 32.85 seconds (likely slower here because the compilation overhead is not amortized over one short epoch; see the sketch after this list)
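
A minimal sketch of how settings (2)-(4) combine in a training loop. It assumes an existing `model`, `optimizer`, and `train_loader`, and that the model returns `(logits, loss)`; none of those names are taken from the repo.

```python
import torch

torch.set_float32_matmul_precision("high")   # (2) allow TF32 matmuls on Ampere+ GPUs

model = model.to("cuda")
model = torch.compile(model)                 # (4) compile the model once up front

for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad(set_to_none=True)
    # (3) run the forward pass in bfloat16 autocast; master weights stay in float32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
```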

For 5 epochs,

  1. With torch.compile - Time to train: 81.20 seconds
  2. With torch.autocast to bfloat16 and float32 matmul precision set to high - Time to train: 125.26 seconds
  3. With torch.compile and Flash Attention - Time to train: 69.03 seconds
  4. Vocab size changed to a nicer round number + (3) - Time to train: 65.72 seconds (see the sketch after this list)
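
A sketch of the two changes in items (3) and (4). Flash Attention is available through PyTorch's fused `scaled_dot_product_attention`; the "nice number" padding shown here (to a multiple of 64) is an assumption, since the exact value isn't stated above.

```python
import torch
import torch.nn.functional as F

# Flash Attention: replace a manual softmax(Q K^T / sqrt(d)) @ V with the fused kernel.
# q, k, v have shape (batch, n_heads, seq_len, head_dim).
def attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# "Nice" vocab size: pad the GPT-2 vocab (50257) up to a multiple of 64 so the
# embedding / LM-head matmuls hit well-tiled kernels.
def nice_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(nice_vocab_size(50257))  # 50304
```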

After adding an LR scheduler, weight decay, and gradient-norm clipping, training for 5 epochs:

  1. Without fused AdamW - Time to train: 62.57 seconds
  2. With fused AdamW - Time to train: 57.59 seconds, Avg Loss - 5.642 (see the sketch after this list)
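
A sketch of these training knobs together: fused AdamW with weight decay, a warmup-plus-cosine LR schedule, and gradient-norm clipping. It assumes an existing `model` and a `step` counter, and the hyperparameter values are illustrative, not the repo's.

```python
import math
import torch

max_lr, min_lr = 6e-4, 6e-5
warmup_steps, max_steps = 100, 1000

optimizer = torch.optim.AdamW(
    model.parameters(), lr=max_lr, weight_decay=0.1,
    betas=(0.9, 0.95), fused=True,   # fused=True runs the parameter update in one CUDA kernel
)

def get_lr(step: int) -> float:
    # linear warmup followed by cosine decay down to min_lr
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    t = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))

# inside the training loop, after loss.backward():
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # gradient-norm clipping
for group in optimizer.param_groups:
    group["lr"] = get_lr(step)
optimizer.step()
```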

Added QK-norm. Training for 5 epochs, Time - 69.68 seconds, Avg Loss - 5.588.
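
A sketch of QK-norm: normalize the queries and keys per head before the dot product, which keeps the attention logits bounded. L2 normalization is one simple variant; the repo's exact formulation (e.g. RMSNorm with a learned scale) may differ.

```python
import torch
import torch.nn.functional as F

def qk_norm_attention(q, k, v):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    q = F.normalize(q, dim=-1)   # unit-norm queries per head
    k = F.normalize(k, dim=-1)   # unit-norm keys per head
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```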

Then added RoPE embeddings (YaRN is available, but the sequence length is fixed for now). Training for 5 epochs, Time - 70.94 seconds, Avg Loss - 4.064.
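
A sketch of RoPE: rotate pairs of query/key channels by position-dependent angles. It assumes inputs of shape (batch, n_heads, seq_len, head_dim) with an even head_dim and the usual base of 10000; YaRN would rescale these frequencies to extend the context length.

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    b, h, t, d = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, device=x.device).float() / d))
    angles = torch.arange(t, device=x.device).float()[:, None] * inv_freq[None, :]  # (t, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]          # interleaved channel pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin         # 2-D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```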

Added ReLU^2 activation instead of the GELU activation. Training for 5 epochs, Time - 68.68 seconds, Avg Loss - 3.853.
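
ReLU^2 (squared ReLU) is a drop-in replacement for GELU inside the MLP; a minimal sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUSquared(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x).square()   # relu(x) ** 2
```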

Switched from LayerNorm to RMSNorm. 5-epoch train time - 67.59 seconds, Avg Loss - 3.876.
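
A sketch of RMSNorm: scale by the root-mean-square of the features, with no mean subtraction and no bias (recent PyTorch versions also ship `torch.nn.RMSNorm`).

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight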

Zero initialization of the output projection of the attention module and of the second (output) layer of the MLP. 5-epoch train time - 59.58 seconds, Avg Loss - 3.676.
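
A sketch of the zero initialization: with both residual-branch projections starting at zero, every block initially acts as the identity through its residual connection. The attribute names (`attn.c_proj`, `mlp.c_proj`) are assumptions, not necessarily the repo's.

```python
import torch.nn as nn

def zero_init_block(block: nn.Module) -> None:
    # attention output projection and MLP output layer start at zero
    nn.init.zeros_(block.attn.c_proj.weight)
    nn.init.zeros_(block.mlp.c_proj.weight)
    if block.attn.c_proj.bias is not None:
        nn.init.zeros_(block.attn.c_proj.bias)
    if block.mlp.c_proj.bias is not None:
        nn.init.zeros_(block.mlp.c_proj.bias)
```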

Untied the embedding weight from the LM head. 5-epoch train time - 72.79 seconds, Avg Loss - 3.337.
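
A sketch of tied vs. untied weights: with tying (the GPT-2 default), the LM head reuses the token-embedding matrix; untying gives the head its own parameters. The sizes here are illustrative.

```python
import torch.nn as nn

vocab_size, d_model = 50304, 768

wte = nn.Embedding(vocab_size, d_model)            # token embedding
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# tied (GPT-2 default):  lm_head.weight = wte.weight
# untied (this change):  leave lm_head.weight as its own parameter
```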

Added a skip connection from every block to the first block and averaged their outputs. Time - 65.09 seconds, Avg Loss - 3.592. Not much difference, even when the weighting between them is made adaptive.
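
A sketch of the first-block skip, under one reading of the note above: each later block's output is averaged with the first block's output. `blocks` is assumed to be the model's list of transformer blocks; an adaptive variant would replace the fixed 0.5 with a learned weight.

```python
import torch
import torch.nn as nn

def forward_with_first_block_skip(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    x = blocks[0](x)
    first = x
    for block in blocks[1:]:
        x = 0.5 * (block(x) + first)   # average with the first block's output
    return x
```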

Changed to a new dataset - OpenWebText-10K - with everything so far except the first-block skip connections. 1-epoch time - 285.07 seconds, Avg Loss - 5.971.

With the Muon optimizer and the linear weights cast to bfloat16:

  Epoch 0, Step 900: loss = 5.506, norm = 0.767, time = 246.73 seconds
  Epoch 0, Step 1200: loss = 5.309, norm = 0.694, time = 326.77 seconds
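
The core of Muon, sketched: orthogonalize each 2-D weight's gradient/momentum via a Newton-Schulz iteration before applying the update (2-D linear weights go to Muon, embeddings and norms stay on AdamW). The quintic coefficients below follow the widely used reference implementation; the optimizer plumbing and the bfloat16 cast of the linear weights are omitted, and details may differ from this repo.

```python
import torch

@torch.no_grad()
def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    X = X / (X.norm() + 1e-7)        # bound the spectral norm by ~1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X            # quintic Newton-Schulz step
    if G.size(0) > G.size(1):
        X = X.T
    return X.to(G.dtype)
```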

Added UNet-style skip connections:

  Epoch 0, Step 900: loss = 5.506, norm = 0.767, time = 242.98 seconds
  Epoch 0, Step 1200: loss = 5.309, norm = 0.695, time = 323.39 seconds
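
A sketch of UNet-style skips: the first half of the blocks act as an "encoder" whose activations are saved, and each "decoder" block in the second half adds back the activation of its mirrored encoder block. Whether the combination is a sum, an average, or learned is an assumption here.

```python
import torch
import torch.nn as nn

def forward_unet_skips(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    half = len(blocks) // 2
    saved = []
    for block in blocks[:half]:      # "encoder" half: save activations
        x = block(x)
        saved.append(x)
    for block in blocks[half:]:      # "decoder" half: add back mirrored activations
        if saved:
            x = x + saved.pop()
        x = block(x)
    return x
```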
