A structured, test-driven transformer learning course for software engineers.
- Code-first, math-for-intuition: Lead with implementations, explain math when it builds understanding
- Test-driven & verifiable: Each concept has runnable tests, not notebooks
- Chapter + Labs structure: Chapters cover concepts, labs are hands-on exercises
- Progressive complexity: Laptop-friendly basics → GPU/TPU advanced topics
- Real ecosystem: Use PyTorch, HuggingFace, vLLM, etc., and learn the tools professionals use
- First-principles when valuable: Build from scratch only when it deepens understanding
| Phase | Chapter | Title | Labs | Hardware |
|---|---|---|---|---|
| Foundation | 1 | The Attention Mechanism | 4 | Laptop |
| | 2 | Building a Transformer Block | 4 | Laptop |
| | 3 | The Complete Transformer | 4 | Laptop |
| | 4 | Training Fundamentals | 4 | Laptop |
| Attention Variants | 5 | Linear Attention | 4 | Laptop |
| | 6 | Flash Linear Attention & State-Space | 5 | Laptop |
| | 7 | Sparse Attention (DeepSeek MLA, MoE) | 5 | Laptop |
| Production | 8 | Memory & Inference Optimization | 4 | Laptop |
| | 9 | Production Inference Frameworks | 5 | Laptop/GPU |
| Hardware | 10 | Flash Attention Deep Dive | 5 | GPU |
| | 11 | Distributed Training | 5 | Multi-GPU |
| | 12 | Custom Kernels & Hardware | 5 | GPU/TPU |
Total: 12 chapters, 54 labs
Chapter 1: The Attention Mechanism
Learning Objectives:
- Understand attention as a soft lookup / weighted retrieval mechanism
- Derive and implement scaled dot-product attention
- Understand why we use multiple heads and how they specialize
Key Concepts:
- Query, Key, Value abstraction (database analogy)
- Dot product as similarity measure
- Softmax for normalization → attention weights
- Scaling factor √d_k and why it matters (variance control)
- Multi-head: parallel attention in different subspaces
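A minimal sketch of the concepts above, written in NumPy rather than PyTorch so it runs without a framework install. The function names (`softmax`, `scaled_dot_product_attention`, `multi_head_attention`) are illustrative, not the lab's required API, and bias terms are omitted. The last lines check the variance argument for the √d_k scaling empirically.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max before exponentiating so large scores cannot overflow
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (..., n_q, d_k), K: (..., n_k, d_k), V: (..., n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)  # query-key similarities
    weights = softmax(scores, axis=-1)                  # each row sums to 1 over the keys
    return weights @ V, weights

def multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads):
    """x: (n, d_model); every projection matrix is (d_model, d_model)."""
    n, d_model = x.shape
    d_head = d_model // n_heads

    def split_heads(t):
        # (n, d_model) -> (n_heads, n, d_head): each head works in its own subspace
        return t.reshape(n, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = (split_heads(x @ W) for W in (W_q, W_k, W_v))
    heads, _ = scaled_dot_product_attention(Q, K, V)       # (n_heads, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # concatenate the heads
    return concat @ W_o

# Why divide by sqrt(d_k): for unit-variance q and k, q . k has variance d_k
rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))
raw = (q * k).sum(axis=-1)
print(raw.var(), (raw / np.sqrt(d_k)).var())  # roughly 64 and roughly 1
```

Without the scaling, softmax inputs grow with d_k and the weights saturate toward one-hot, which shrinks gradients.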
Docs:
- attention_intuition.md - What problem does attention solve? (seq2seq bottleneck, alignment)
- scaled_dot_product.md - The math: Q, K, V, softmax, scaling
- multihead_attention.md - Why multiple heads? How do they combine?
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Dot-Product Attention | attention(Q, K, V) from scratch | Core formula, softmax |
| lab02 | Attention Visualization | Heatmaps of attention weights | See what attention "looks at" |
| lab03 | Multi-Head Attention | MultiHeadAttention class | Heads, projections, concatenation |
| lab04 | PyTorch Comparison | Match nn.MultiheadAttention output | Validate correctness, learn API |
Milestone: Your multi-head attention matches PyTorch's output within 1e-5 tolerance.
Chapter 2: Building a Transformer Block
Learning Objectives:
- Understand the components that surround attention in a transformer
- Implement layer normalization and understand pre-norm vs post-norm
- Understand positional encodings and why transformers need them
- Build a complete transformer block
Key Concepts:
- Residual connections (why they help training)
- Layer normalization (vs batch norm, why per-token)
- Pre-norm vs post-norm architecture
- Positional encodings: sinusoidal, learned, RoPE
- Feed-forward network: expand → activate → contract
- Activation functions: ReLU → GELU → SwiGLU evolution
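A NumPy sketch of two of these components: per-token LayerNorm used inside a pre-norm residual, and RoPE. The function names are hypothetical. It assumes the interleaved-pair RoPE convention (dimensions 2i and 2i+1 rotate together); some implementations pair dimension i with i + d/2 instead, which changes the weight layout but not the relative-position property the rotation provides.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token's feature vector on its own (last axis); batch norm
    # would instead normalize each feature across the batch
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def pre_norm_residual(x, sublayer, gamma, beta):
    # Pre-norm: normalize only the sublayer's input; the residual path stays an identity
    return x + sublayer(layer_norm(x, gamma, beta))

def rope(x, positions, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) of x: (n, d) by angle position * base**(-2i/d)."""
    n, d = x.shape
    inv_freq = base ** (-np.arange(0, d, 2) / d)     # one frequency per pair, (d/2,)
    angles = positions[:, None] * inv_freq[None, :]  # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

Because each pair is rotated by an angle proportional to its position, the dot product of a rotated query and key depends only on their position difference, and rotation leaves vector norms unchanged.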
Docs:
- residuals_and_normalization.md - Why residuals? Pre-norm vs post-norm
- positional_encoding.md - Why position matters, different approaches
- feed_forward_network.md - The "MLP" part, activation functions
- transformer_block.md - Putting it all together
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Layer Normalization | LayerNorm from scratch | Mean/var, learnable params |
| lab02 | Positional Encodings | Sinusoidal + RoPE | Frequency intuition, rotation |
| lab03 | Feed-Forward Network | FFN with GELU/SwiGLU | Gating, activation patterns |
| lab04 | Transformer Block | Complete block assembly | Residuals, ordering |
Milestone: Your transformer block forward pass matches HuggingFace GPT-2 block output.
Chapter 3: The Complete Transformer
Learning Objectives:
- Understand how blocks stack to form complete models
- Learn the differences between encoder, decoder, encoder-decoder
- Understand tokenization and embeddings
- Load and run pretrained models
Key Concepts:
- Encoder vs decoder vs encoder-decoder architectures
- Causal masking for autoregressive generation
- Token embeddings and vocabulary
- Tied embeddings (input/output sharing)
- Output heads: LM head, classification head
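A NumPy sketch of the causal mask with the -inf trick, followed by an embedding lookup and a tied LM head. The sizes are hypothetical, positional embeddings are omitted, and a single attention step stands in for a full stack of blocks.

```python
import numpy as np

def causal_mask(n):
    # True strictly above the diagonal: query i may not look at keys j > i
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def causal_attention(Q, K, V):
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(causal_mask(n), -np.inf, scores)    # -inf trick: exp(-inf) == 0
    scores = scores - scores.max(axis=-1, keepdims=True)  # diagonal is never masked
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V, w

# Embedding lookup with a tied LM head (hypothetical tiny sizes)
rng = np.random.default_rng(0)
vocab_size, d_model = 50, 16
E = rng.standard_normal((vocab_size, d_model)) * 0.02  # token embedding table
token_ids = np.array([3, 17, 42])
h = E[token_ids]                  # (seq, d_model): one row lookup per token
h, _ = causal_attention(h, h, h)  # one causal self-attention step
logits = h @ E.T                  # tied output head reuses E: (seq, vocab_size)
```

Tying means the same matrix maps token ids to vectors on the way in and vectors to vocabulary scores on the way out, which saves vocab_size × d_model parameters.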
Docs:
- encoder_decoder_architectures.md - When to use which
- causal_masking.md - Preventing future peeking
- embeddings_and_vocabulary.md - Tokens, subwords, embedding tables
- pretrained_models.md - Loading weights, HuggingFace ecosystem
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Causal Masking | Implement attention mask | Triangle mask, -inf trick |
| lab02 | Token Embeddings | Embedding layer + position | Vocabulary, lookup tables |
| lab03 | Decoder-Only Transformer | Stack blocks into GPT-style model | Full architecture |
| lab04 | Load Pretrained Weights | Load GPT-2 weights into your code | Weight mapping, shapes |
Milestone: Your implementation produces the same logits as HuggingFace GPT-2 for the same input.
Chapter 4: Training Fundamentals
Learning Objectives:
- Understand the loss functions used to train language models
- Visualize and understand gradient flow through attention
- Learn modern optimizer and LR scheduling techniques
- Train a small model from scratch
Key Concepts:
- Cross-entropy loss for next-token prediction
- Perplexity as evaluation metric
- Gradient flow and vanishing/exploding gradients
- AdamW optimizer (weight decay done right)
- Learning rate schedules: warmup, cosine decay
- Gradient clipping
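A small sketch of the quantities above: next-token cross-entropy, perplexity, a linear-warmup plus cosine-decay schedule, and global-norm gradient clipping (similar in spirit to PyTorch's `torch.nn.utils.clip_grad_norm_`). All function names and the schedule's exact shape are assumptions for illustration, not the lab's API.

```python
import math
import numpy as np

def cross_entropy(logits, targets):
    """logits: (n, vocab); targets: (n,) next-token ids. Returns mean NLL in nats."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # stable log-softmax
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def perplexity(mean_nll):
    # exp(mean NLL): roughly "how many tokens the model is choosing between"
    return math.exp(mean_nll)

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    # Linear warmup to max_lr, then cosine decay down to min_lr at total_steps
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm):
    # Rescale all gradients together if their combined L2 norm exceeds max_norm
    total = math.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total
```

A useful sanity check: a model that outputs uniform logits over a vocabulary of V tokens has loss log V and perplexity V, so a freshly initialized model should start near that value.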
Docs:
- loss_and_perplexity.md - What are we optimizing?
- gradient_flow.md - How gradients move through attention
- optimizers.md - Adam, AdamW, why weight decay matters
- lr_schedules.md - Warmup, decay strategies
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Loss Functions | Cross-entropy, perplexity | Probability interpretation |
| lab02 | Gradient Visualization | Plot gradients through layers | See vanishing/exploding |
| lab03 | Training Loop | Complete training loop | Data loading, optimization |
| lab04 | Train Tiny Model | Train on tiny_shakespeare | End-to-end training |
Milestone: Train a ~1M param model that generates coherent Shakespeare-like text.
Chapter 5: Linear Attention
Learning Objectives:
- Understand why O(n²) attention is problematic for long sequences
- Learn the kernel trick that enables O(n) attention
- Implement linear attention and understand its trade-offs
Key Concepts:
- Attention complexity: O(n²) memory and compute
- The associativity trick: (QK^T)V → Q(K^T V)
- Feature maps / kernel functions: φ(x)
- Causal linear attention: cumulative sum formulation
- Trade-off: efficiency vs expressiveness
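A NumPy sketch of the associativity trick, assuming the ELU+1 feature map (one of the φ choices compared in lab03). Once softmax is replaced by φ(Q)φ(K)^T, the product can be regrouped so the n × n matrix is never formed. The causal version uses prefix sums, which is the recurrent view; note this naive cumsum materializes every prefix state (O(n · d · d_v) memory), which real kernels avoid.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = ELU(x) + 1 keeps features positive so the normalizer stays > 0
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_quadratic(Q, K, V, phi=elu_plus_one):
    # Reference: build the full (n, n) similarity matrix, O(n^2)
    A = phi(Q) @ phi(K).T
    return (A @ V) / A.sum(axis=-1, keepdims=True)

def linear_attention(Q, K, V, phi=elu_plus_one):
    # (phi(Q) phi(K)^T) V == phi(Q) (phi(K)^T V): the (d, d_v) summary costs O(n)
    fQ, fK = phi(Q), phi(K)
    KV = fK.T @ V        # (d, d_v)
    z = fK.sum(axis=0)   # (d,) normalizer
    return (fQ @ KV) / (fQ @ z)[:, None]

def causal_linear_attention(Q, K, V, phi=elu_plus_one):
    # Prefix sums of phi(k_s) v_s^T: position t only sees s <= t
    fQ, fK = phi(Q), phi(K)
    S = np.cumsum(fK[:, :, None] * V[:, None, :], axis=0)  # (n, d, d_v) running states
    z = np.cumsum(fK, axis=0)                                # (n, d) running normalizers
    num = np.einsum("nd,nde->ne", fQ, S)
    den = np.einsum("nd,nd->n", fQ, z)
    return num / den[:, None]
```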
Docs:
- quadratic_bottleneck.md - Why O(n²) hurts, real-world examples
- kernel_trick.md - Math behind linearization, associativity
- feature_maps.md - Different φ functions and their properties
- causal_linear.md - Making it work for autoregressive models
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Complexity Analysis | Benchmark standard attention | See the O(n²) wall |
| lab02 | Kernel Trick | Implement Q(K^T V) formulation | Associativity insight |
| lab03 | Feature Maps | Try different φ: ELU+1, ReLU, exp | Impact on quality |
| lab04 | Causal Linear Attention | Cumsum-based implementation | Recurrent view |
Milestone: Linear attention that's 10x faster than standard for seq_len=4096.
Chapter 6: Flash Linear Attention & State-Space
Learning Objectives:
- Understand linear attention's connection to RNNs and state-space models
- Learn chunkwise parallel algorithms for efficient training
- Implement Gated Linear Attention (GLA) and understand the Kimi/Moonshot line of work
Key Concepts:
- Linear attention as RNN: hidden state = K^T V
- Parallel vs recurrent: training vs inference trade-off
- Chunkwise computation: best of both worlds
- Flash Linear Attention: memory-efficient training
- Gating mechanisms: data-dependent forgetting
- DeltaNet, GLA, Mamba connections
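A NumPy sketch of one gated linear attention computed three ways: recurrent (one state update per token), parallel (a decay-masked n × n matrix), and chunkwise (parallel within a chunk, recurrent across chunks). It assumes a simplified scalar per-step gate α_t; GLA itself uses data-dependent gates per key dimension, and DeltaNet uses a different (delta-rule) state update, so treat this as the shared skeleton rather than either method. Like RetNet- and GLA-style layers, it omits the softmax-style normalizer.

```python
import numpy as np

def gated_recurrent(Q, K, V, alpha):
    """alpha: (n,) forget gates in (0, 1]. State S_t = alpha_t * S_{t-1} + k_t v_t^T."""
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))
    out = np.empty((n, V.shape[1]))
    for t in range(n):
        S = alpha[t] * S + np.outer(K[t], V[t])  # fixed-size state: inference-friendly
        out[t] = Q[t] @ S
    return out

def gated_parallel(Q, K, V, alpha):
    # o_t = sum_{s<=t} (prod_{r=s+1..t} alpha_r) (q_t . k_s) v_s, all positions at once
    b = np.cumsum(np.log(alpha))                      # b_t = sum_{r<=t} log alpha_r
    decay = np.tril(np.exp(b[:, None] - b[None, :]))  # causal decay matrix
    return ((Q @ K.T) * decay) @ V

def gated_chunkwise(Q, K, V, alpha, chunk=4):
    # Parallel inside each chunk, recurrent state handoff between chunks
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))
    out = np.empty((n, V.shape[1]))
    for c0 in range(0, n, chunk):
        sl = slice(c0, c0 + chunk)
        q, k, v = Q[sl], K[sl], V[sl]
        b = np.cumsum(np.log(alpha[sl]))              # log-decay local to this chunk
        intra = np.tril(np.exp(b[:, None] - b[None, :])) * (q @ k.T)
        out[sl] = np.exp(b)[:, None] * (q @ S) + intra @ v
        # Decay the incoming state across the chunk, then add this chunk's contributions
        S = np.exp(b[-1]) * S + (np.exp(b[-1] - b)[:, None] * k).T @ v
    return out
```

The chunk size controls the trade-off: larger chunks mean more matmul-friendly parallel work, smaller chunks mean a smaller quadratic intra-chunk term.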
Docs:
- linear_attention_as_rnn.md - The recurrent interpretation
- chunkwise_parallel.md - Chunking for efficient training
- flash_linear_attention.md - The algorithm explained
- gated_linear_attention.md - GLA, DeltaNet, Kimi variants
- state_space_connection.md - How this relates to Mamba/S4
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | RNN View | Linear attention in recurrent form | State accumulation |
| lab02 | Chunkwise Parallel | Hybrid parallel/recurrent | Efficiency trick |
| lab03 | Flash Linear Attention | Memory-efficient version | Tiling for linear attn |
| lab04 | Gated Linear Attention | Implement GLA | Data-dependent decay |
| lab05 | DeltaNet | Implement DeltaNet variant | Delta rule, Kimi approach |
Milestone: A working GLA that matches the reference implementation in the fla (flash-linear-attention) library.
Chapter 7: Sparse Attention (DeepSeek MLA, MoE)
Learning Objectives:
- Understand sparse attention patterns and when to use them
- Implement sliding window and global token patterns
- Deep dive into DeepSeek's MLA (Multi-head Latent Attention)
- Understand how MoE integrates with attention
Key Concepts:
- Sparse patterns: local, strided, dilated, fixed
- Sliding window attention (Longformer, Mistral)
- Global tokens for long-range dependencies
- DeepSeek MLA: latent compression of KV
- Low-rank KV projection
- Mixture-of-Experts (MoE) basics
- Router design and load balancing
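A NumPy sketch of the low-rank KV idea behind MLA: cache a small per-token latent and rebuild per-head keys and values from it when attending. The sizes and projection names are hypothetical, and the sketch ignores DeepSeek's decoupled RoPE key and the matrix-absorption trick used at inference, so the 8x figure below describes this toy configuration only, not DeepSeek's models.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_head, d_latent = 256, 8, 32, 64  # hypothetical sizes

# Hypothetical learned projections (random here)
W_down = rng.standard_normal((d_model, d_latent)) / np.sqrt(d_model)
W_up_k = rng.standard_normal((d_latent, n_heads * d_head)) / np.sqrt(d_latent)
W_up_v = rng.standard_normal((d_latent, n_heads * d_head)) / np.sqrt(d_latent)

def compress(x):
    # Only this (n, d_latent) latent goes into the KV cache
    return x @ W_down

def expand(c):
    # Rebuild per-head keys and values from the cached latent when attending
    n = c.shape[0]
    K = (c @ W_up_k).reshape(n, n_heads, d_head)
    V = (c @ W_up_v).reshape(n, n_heads, d_head)
    return K, V

x = rng.standard_normal((100, d_model))  # hidden states for 100 tokens
cache = compress(x)
K, V = expand(cache)

floats_standard = 2 * n_heads * d_head   # full K and V per token per layer
floats_latent = d_latent                 # one shared latent per token per layer
print(floats_standard / floats_latent)   # 8.0
```

The cost of the compression is that K and V can have rank at most d_latent; whether that loses quality is exactly what the chapter milestone measures.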
Docs:
- sparse_patterns.md - Taxonomy of sparse attention
- sliding_window.md - Local attention + global tokens
- deepseek_mla.md - Multi-head Latent Attention explained
- kv_compression.md - Why compress KV, how it works
- mixture_of_experts.md - MoE basics, routing, load balancing
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Sparse Patterns | Implement local/strided masks | Mask construction |
| lab02 | Sliding Window | Longformer-style attention | Local + global |
| lab03 | KV Compression | Low-rank KV projection | Latent space idea |
| lab04 | DeepSeek MLA | Full MLA implementation | Latent attention |
| lab05 | Basic MoE | Simple MoE layer | Routing, top-k |
Milestone: MLA implementation that reduces KV cache by 4x while maintaining quality.
Chapter 8: Memory & Inference Optimization
Learning Objectives:
- Understand why inference is memory-bound, not compute-bound
- Implement KV-cache and understand its memory implications
- Learn batching strategies for throughput optimization
- Understand quantization basics
Key Concepts:
- Memory bandwidth vs compute (roofline model)
- KV-cache: caching key/value for autoregressive generation
- Memory growth: O(batch × seq × layers × heads × dim)
- Continuous batching (iteration-level scheduling)
- Quantization: int8, int4, fp8
- Quantization-aware training vs post-training quantization
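A sketch of two calculations from this chapter: KV-cache size (the O(batch × seq × layers × heads × dim) formula with the factor of 2 for keys plus values, times bytes per element), and symmetric per-tensor int8 post-training quantization. The model configuration is a hypothetical 7B-class shape, and the function names are illustrative.

```python
import numpy as np

def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # Factor 2: both keys and values are cached for every layer, head and position
    return 2 * batch * seq_len * n_layers * n_kv_heads * d_head * bytes_per_elem

# Hypothetical 7B-class config (32 layers, 32 heads, head dim 128) in fp16
gib = kv_cache_bytes(batch=1, seq_len=4096, n_layers=32, n_kv_heads=32, d_head=128) / 2**30
print(gib)  # 2.0 GiB for a single 4096-token sequence

def quantize_int8(w):
    # Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127]
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

def int8_linear(x, q_w, scale):
    # Weights are stored in int8 (4x smaller than fp32); dequantize before the matmul
    return x @ dequantize_int8(q_w, scale)
```

The cache grows linearly with batch size and sequence length, so at long contexts it, rather than the weights, can dominate memory.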
Docs:
- memory_bound_inference.md - Why inference is memory-limited
- kv_cache.md - What it is, how it grows, memory analysis
- batching_strategies.md - Static vs dynamic vs continuous
- quantization_basics.md - Types, trade-offs, when to use
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | KV-Cache | Add KV-cache to your transformer | Incremental decoding |
| lab02 | Generation Loop | Complete text generation | Sampling, temperature |
| lab03 | Batched Generation | Handle multiple sequences | Padding, attention masks |
| lab04 | Basic Quantization | int8 linear layers | Quantize/dequantize |
Milestone: Generation that's 10x faster with a KV-cache than with full recomputation.
Chapter 9: Production Inference Frameworks
Learning Objectives:
- Learn the production inference ecosystem
- Understand PagedAttention and its memory benefits
- Use vLLM, llama.cpp, and SGLang for real workloads
- Compare frameworks for different use cases
Key Concepts:
- PagedAttention: virtual memory for KV-cache
- vLLM architecture and optimizations
- llama.cpp: GGUF format, CPU inference, quantization
- SGLang: structured generation, constrained decoding
- Speculative decoding
- Choosing the right framework
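A toy model of the bookkeeping behind PagedAttention: the KV-cache is split into fixed-size physical blocks, and each sequence keeps a block table, like a page table, mapping logical token positions to blocks that need not be contiguous. This is not vLLM's code or API; the class, method names and the block size of 16 are hypothetical, and the attention computation itself is left out.

```python
import math

class PagedKVAllocator:
    """Toy block table mapping each sequence's token positions to fixed-size blocks."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # physical block ids not in use
        self.tables = {}                     # seq_id -> list of physical block ids
        self.lengths = {}                    # seq_id -> tokens stored so far

    def append_tokens(self, seq_id, n_tokens):
        table = self.tables.get(seq_id, [])
        length = self.lengths.get(seq_id, 0) + n_tokens
        needed = math.ceil(length / self.block_size) - len(table)
        if needed > len(self.free):
            # A real scheduler would preempt or queue the request here
            raise MemoryError("out of KV-cache blocks")
        for _ in range(needed):
            table.append(self.free.pop())    # blocks need not be contiguous
        self.tables[seq_id] = table
        self.lengths[seq_id] = length

    def physical_slot(self, seq_id, pos):
        # Like a page table: logical position -> (physical block, offset in block)
        return self.tables[seq_id][pos // self.block_size], pos % self.block_size

    def release(self, seq_id):
        # A finished sequence returns all its blocks to the free pool at once
        self.free.extend(self.tables.pop(seq_id))
        del self.lengths[seq_id]
```

Compared with reserving max-length contiguous buffers per request, waste is bounded by at most block_size - 1 unused slots per sequence, which is what lets more sequences share the GPU at once.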
Docs:
- paged_attention.md - Virtual memory for KV-cache
- vllm_architecture.md - How vLLM works
- llama_cpp.md - CPU inference, GGUF, quantization schemes
- sglang.md - Structured generation, RadixAttention
- framework_comparison.md - When to use what
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | HuggingFace Basics | Load, generate, fine-tune | Ecosystem entry point |
| lab02 | vLLM Serving | Deploy model with vLLM | High-throughput serving |
| lab03 | llama.cpp | Quantize and run on CPU | GGUF, efficient CPU |
| lab04 | SGLang | Structured JSON generation | Constrained decoding |
| lab05 | Benchmark | Compare all frameworks | Throughput, latency, memory |
Milestone: Serve a 7B model with vLLM, achieve >100 tokens/sec throughput.
Chapter 10: Flash Attention Deep Dive
Learning Objectives:
- Understand GPU memory hierarchy in depth
- Learn how Flash Attention achieves memory efficiency
- Implement the core ideas of tiling and recomputation
- Use Flash Attention in practice
Key Concepts:
- GPU memory hierarchy: registers → shared memory (SRAM) → HBM
- Memory bandwidth bottleneck
- Tiling: compute attention in blocks
- Online softmax (numerically stable incremental)
- Recomputation in backward pass
- Flash Attention 2 and 3 improvements
- Gradient checkpointing
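A NumPy sketch of online softmax applied to tiled attention for one query at a time: K and V are visited block by block, and the running max, denominator and numerator are rescaled whenever a larger score appears. This shows only the core recurrence; real Flash Attention also tiles queries, keeps tiles in on-chip SRAM, and fuses everything into one kernel. Function names and the block size are assumptions.

```python
import numpy as np

def tiled_attention_row(q, K, V, block=4):
    """Attention output for one query, visiting K/V one block at a time."""
    d = q.shape[0]
    m = -np.inf                 # running max of scores seen so far
    l = 0.0                     # running softmax denominator (relative to m)
    acc = np.zeros(V.shape[1])  # running numerator (relative to m)
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d)  # scores for this block only
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)               # rescale old sums to the new max
        p = np.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ V[start:start + block]
        m = m_new
    return acc / l

def tiled_attention(Q, K, V, block=4):
    return np.stack([tiled_attention_row(q, K, V, block) for q in Q])
```

Only O(block) scores exist at any moment instead of a full row of n, which is why the full n × n matrix never has to be written to HBM.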
Docs:
- gpu_memory_hierarchy.md - SRAM vs HBM, bandwidth limits
- tiling_and_blocking.md - Why and how to tile attention
- online_softmax.md - Incremental, numerically stable softmax
- flash_attention_algorithm.md - The full algorithm explained
- flash_attention_v2_v3.md - Improvements and optimizations
- gradient_checkpointing.md - Trade compute for memory
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Memory Profiling | Profile attention memory usage | See the problem |
| lab02 | Online Softmax | Incremental softmax | Numerical stability |
| lab03 | Tiled Attention | Block-by-block attention | Core Flash idea |
| lab04 | Use Flash Attention | Integrate flash-attn library | Practical usage |
| lab05 | Gradient Checkpointing | Implement checkpointing | Memory/compute trade |
Milestone: Train with 4x longer sequences using Flash Attention + checkpointing.
Chapter 11: Distributed Training
Learning Objectives:
- Understand parallelism strategies for large model training
- Implement data parallelism with DDP
- Understand model parallelism concepts
- Use FSDP and DeepSpeed for real training
Key Concepts:
- Data parallelism: replicate model, partition data
- Distributed Data Parallel (DDP): gradient all-reduce
- Model parallelism: tensor vs pipeline
- ZeRO stages: optimizer, gradient, parameter sharding
- FSDP (Fully Sharded Data Parallel)
- Communication primitives: all-reduce, all-gather, reduce-scatter
- Mixed precision training (fp16, bf16)
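A single-process NumPy simulation of the collectives above, with "ranks" represented as list entries. It shows all-reduce decomposed into reduce-scatter followed by all-gather (the same two primitives ZeRO/FSDP use separately), then uses it for a DDP-style step: averaging per-shard gradients of a hypothetical linear model reproduces the full-batch gradient. Real training would use `torch.distributed` over NCCL; these helper names are illustrative only.

```python
import numpy as np

def reduce_scatter(tensors):
    """tensors[r] is rank r's full tensor; rank r ends up owning the sum of shard r."""
    world = len(tensors)
    shards = [np.array_split(t, world) for t in tensors]  # shards[src][i]
    return [sum(shards[src][r] for src in range(world)) for r in range(world)]

def all_gather(owned):
    """Every rank receives every rank's shard and concatenates them in rank order."""
    full = np.concatenate(owned)
    return [full.copy() for _ in owned]

def all_reduce(tensors):
    # All-reduce = reduce-scatter followed by all-gather
    return all_gather(reduce_scatter(tensors))

# DDP-style step: each rank computes gradients on its own data shard, then averages
def mse_grad(w, X, y):
    return 2 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X, y, w = rng.standard_normal((32, 10)), rng.standard_normal(32), rng.standard_normal(10)
world_size = 4
local_grads = [mse_grad(w, Xs, ys)
               for Xs, ys in zip(np.split(X, world_size), np.split(y, world_size))]
synced = [g / world_size for g in all_reduce(local_grads)]
```

The equivalence holds here because every shard has the same size; with uneven shards the average would need per-shard weighting.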
Docs:
- parallelism_strategies.md - Data, tensor, pipeline, expert
- ddp.md - How DDP works, gradient synchronization
- model_parallelism.md - Tensor and pipeline parallelism
- zero_and_fsdp.md - Memory-efficient data parallelism
- mixed_precision.md - fp16, bf16, loss scaling
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Multi-GPU Setup | Configure multi-GPU environment | Environment basics |
| lab02 | DDP Training | Train with DistributedDataParallel | Gradient sync |
| lab03 | FSDP | Train with Fully Sharded DP | Memory efficiency |
| lab04 | Mixed Precision | Add AMP to training | fp16/bf16 training |
| lab05 | DeepSpeed | Use DeepSpeed ZeRO | Production setup |
Milestone: Train a model too large for a single GPU using FSDP.
Chapter 12: Custom Kernels & Hardware
Learning Objectives:
- Write custom GPU kernels in Triton
- Understand the XLA/JAX compilation model
- Learn TPU programming basics
- Master profiling and optimization workflow
Key Concepts:
- Triton: Python-like GPU kernel programming
- Kernel fusion: combine operations to reduce memory traffic
- XLA: graph compilation and optimization
- JAX: functional transformations (jit, vmap, pmap)
- TPU architecture: systolic arrays, HBM
- Profiling tools: nsight, torch profiler, JAX profiler
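A back-of-envelope sketch of why kernel fusion pays off for memory-bound ops, using the roofline model. The hardware numbers are hypothetical, each element-wise op is assumed to cost about one FLOP per element, and extra operands (such as a bias vector) are ignored. Under these assumptions, fusing a chain of 4 ops cuts HBM traffic by 4x, and because the chain stays memory-bound, attainable throughput rises by the same factor.

```python
def attainable_flops(intensity, peak_flops, bandwidth):
    # Roofline: throughput is capped by compute or by how fast bytes arrive from memory
    return min(peak_flops, intensity * bandwidth)

def chain_traffic_bytes(n_elems, n_ops, bytes_per_elem=4, fused=False):
    # Unfused: every op reads its input from HBM and writes its result back.
    # Fused: one kernel reads once, keeps intermediates in registers, writes once.
    passes = 1 if fused else n_ops
    return passes * 2 * n_elems * bytes_per_elem

n_elems, n_ops = 1 << 24, 4     # a chain of 4 element-wise ops on fp32 data
flops = n_elems * n_ops         # assume ~1 FLOP per element per op
peak, bandwidth = 100e12, 2e12  # hypothetical accelerator: 100 TFLOP/s, 2 TB/s

for fused in (False, True):
    traffic = chain_traffic_bytes(n_elems, n_ops, fused=fused)
    intensity = flops / traffic  # FLOPs per byte moved
    print(f"fused={fused}: {traffic / 2**20:.0f} MiB moved, "
          f"{intensity:.3f} FLOP/B, "
          f"{attainable_flops(intensity, peak, bandwidth) / 1e12:.2f} TFLOP/s")
```

Both variants sit far below the compute roof, which is the usual signature of an op that profiling will flag as bandwidth-limited and a good candidate for a fused Triton kernel.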
Docs:
- triton_basics.md - Writing kernels in Triton
- kernel_fusion.md - Why fuse, what to fuse
- xla_compilation.md - How XLA optimizes
- jax_transformations.md - jit, vmap, pmap explained
- tpu_architecture.md - How TPUs work
- profiling.md - Finding and fixing bottlenecks
Labs:
| Lab | Title | What you build | Key learning |
|---|---|---|---|
| lab01 | Triton Basics | Simple Triton kernels | Block-level programming |
| lab02 | Fused Attention | Attention kernel in Triton | Kernel fusion |
| lab03 | JAX Intro | Attention in JAX | Functional ML |
| lab04 | JAX JIT & vmap | Optimize with transformations | Compilation, batching |
| lab05 | Profiling | Profile and optimize a model | Find bottlenecks |
Milestone: A custom Triton attention kernel that reaches at least 80% of Flash Attention's performance.