1818#include < utility>
1919#include < vector>
2020
21+ /*
22+ The most atomic way to train and inference a GPT on CUDA.
23+ This file is the complete fused GPU algorithm.
24+ Everything else is just efficiency.
25+ */
26+
2127#define CUDA_CHECK (call ) \
2228 do { \
2329 cudaError_t err__ = (call); \
2834 } \
2935 } while (0 )
3036
31- constexpr int kNEmbd = 16 ;
32- constexpr int kNHead = 4 ;
33- constexpr int kNLayer = 1 ;
34- constexpr int kBlockSize = 8 ;
37+ // Let there be fixed model geometry, matching microgpt.py exactly.
38+ constexpr int kNEmbd = 16 ; // embedding dimension
39+ constexpr int kNHead = 4 ; // number of attention heads
40+ constexpr int kNLayer = 1 ; // this fused kernel path currently supports one layer
41+ constexpr int kBlockSize = 8 ; // maximum sequence length
3542constexpr int kHeadDim = kNEmbd / kNHead ;
36- constexpr int kFcDim = 4 * kNEmbd ;
37- constexpr int kMaxVocab = 256 ;
38- constexpr int kMaxTokens = kBlockSize + 1 ;
43+ constexpr int kFcDim = 4 * kNEmbd ; // hidden size of the MLP expansion
44+ constexpr int kMaxVocab = 256 ; // hard safety cap for static kernel buffers
45+ constexpr int kMaxTokens = kBlockSize + 1 ; // [token_t, ..., token_{t+n}] for next-token targets
3946
47+ // Let there be training knobs, kept numerically aligned with the Python reference.
4048struct TrainConfig {
4149 int num_steps = 500 ;
4250 int val_every = 100 ;
@@ -54,6 +62,7 @@ struct TrainConfig {
5462 float max_grad_norm = 1 .0f ;
5563};
5664
65+ // Let there be dataset and tokenizer state.
5766struct DataBundle {
5867 std::vector<std::string> docs;
5968 std::vector<std::string> train_docs;
@@ -63,6 +72,7 @@ struct DataBundle {
6372 int bos = 0 ;
6473};
6574
75+ // Let there be a tiny RAII wrapper over device memory.
6676template <typename T>
6777class DeviceBuffer {
6878public:
@@ -159,6 +169,7 @@ private:
159169 size_t size_ = 0 ;
160170};
161171
172+ // Let there be a plain pointer view for passing model buffers into kernels.
162173struct ModelPtrs {
163174 int vocab_size = 0 ;
164175
@@ -199,6 +210,7 @@ struct ModelPtrs {
199210 float * v_mlp_fc2 = nullptr ;
200211};
201212
213+ // Let there be parameters and optimizer state resident on the GPU.
202214struct DeviceModel {
203215 int vocab_size = 0 ;
204216
@@ -238,6 +250,7 @@ struct DeviceModel {
238250 DeviceBuffer<float > v_mlp_fc1;
239251 DeviceBuffer<float > v_mlp_fc2;
240252
253+ // Initialize weights with the same random scheme as microgpt.py.
241254 DeviceModel (int vocab, std::mt19937& rng)
242255 : vocab_size(vocab),
243256 wte (static_cast <size_t >(vocab) * kNEmbd),
@@ -272,6 +285,7 @@ struct DeviceModel {
272285 v_attn_wo(static_cast <size_t >(kNEmbd ) * kNEmbd),
273286 v_mlp_fc1(static_cast <size_t >(kFcDim ) * kNEmbd),
274287 v_mlp_fc2(static_cast <size_t >(kNEmbd ) * kFcDim) {
288+ // Parameter init: Gaussian(0, 0.02), with selected projections zero-initialized.
275289 init_matrix (wte, 0 .02f , rng);
276290 init_matrix (wpe, 0 .02f , rng);
277291 init_matrix (attn_wq, 0 .02f , rng);
@@ -309,6 +323,7 @@ struct DeviceModel {
309323 v_mlp_fc2.zero ();
310324 }
311325
326+ // Count all trainable scalars, for observability parity with microgpt.py.
312327 size_t num_params () const {
313328 return wte.size () + wpe.size () + attn_wq.size () + attn_wk.size () + attn_wv.size () +
314329 attn_wo.size () + mlp_fc1.size () + mlp_fc2.size ();
@@ -357,6 +372,7 @@ struct DeviceModel {
357372 }
358373
359374private:
375+ // Host-side random init, uploaded once; parameters stay device-resident afterward.
360376 static void init_matrix (DeviceBuffer<float >& dst, float stddev, std::mt19937& rng) {
361377 std::vector<float > host (dst.size (), 0 .0f );
362378 if (stddev != 0 .0f ) {
@@ -369,6 +385,7 @@ private:
369385 }
370386};
371387
388+ // Trim a line from the input corpus.
372389static std::string trim_copy (const std::string& s) {
373390 size_t start = 0 ;
374391 while (start < s.size () && std::isspace (static_cast <unsigned char >(s[start]))) {
@@ -381,6 +398,7 @@ static std::string trim_copy(const std::string& s) {
381398 return s.substr (start, end - start);
382399}
383400
401+ // Let there be an input dataset file, matching Python behavior on first run.
384402static void ensure_input_file () {
385403 namespace fs = std::filesystem;
386404 if (fs::exists (" input.txt" )) {
@@ -400,6 +418,7 @@ static void ensure_input_file() {
400418#endif
401419}
402420
421+ // Load docs, split train/val, and build a character tokenizer with <BOS>.
403422static DataBundle load_data (std::mt19937& rng) {
404423 ensure_input_file ();
405424
@@ -456,6 +475,7 @@ static DataBundle load_data(std::mt19937& rng) {
456475 return data;
457476}
458477
478+ // Encode one document as [BOS] + chars + [BOS], exactly like microgpt.py.
459479static std::vector<int > encode_doc (
460480 const std::string& doc,
461481 const std::unordered_map<char , int >& stoi,
@@ -474,6 +494,7 @@ static std::vector<int> encode_doc(
474494 return tokens;
475495}
476496
497+ // Top-k sampling with temperature, used at inference time.
477498static int sample_top_k (
478499 const std::vector<float >& logits,
479500 int top_k,
@@ -502,6 +523,7 @@ static int sample_top_k(
502523 return ids[static_cast <size_t >(dist (rng))];
503524}
504525
526+ // Parse optional CLI overrides for fast experiments.
505527static TrainConfig parse_args (int argc, char ** argv) {
506528 TrainConfig cfg;
507529 for (int i = 1 ; i < argc; ++i) {
@@ -550,6 +572,8 @@ static TrainConfig parse_args(int argc, char** argv) {
550572 }
551573 return cfg;
552574}
575+
576+ // Below live the scalar CUDA building blocks used by fused kernels.
553577__device__ inline void d_vec_copy (const float * src, float * dst, int n) {
554578 for (int i = 0 ; i < n; ++i) {
555579 dst[i] = src[i];
@@ -562,6 +586,7 @@ __device__ inline void d_vec_add_inplace(float* dst, const float* src, int n) {
562586 }
563587}
564588
589+ // y = W * x
565590__device__ inline void d_matvec (const float * w, const float * x, float * y, int out, int in) {
566591 for (int row = 0 ; row < out; ++row) {
567592 float acc = 0 .0f ;
@@ -573,6 +598,7 @@ __device__ inline void d_matvec(const float* w, const float* x, float* y, int ou
573598 }
574599}
575600
601+ // dx = W^T * dy
576602__device__ inline void d_matvec_t (const float * w, const float * dy, float * dx, int out, int in) {
577603 for (int col = 0 ; col < in; ++col) {
578604 float acc = 0 .0f ;
@@ -583,6 +609,7 @@ __device__ inline void d_matvec_t(const float* w, const float* dy, float* dx, in
583609 }
584610}
585611
612+ // dW += dy outer x
586613__device__ inline void d_outer_add (float * dw, const float * dy, const float * x, int out, int in) {
587614 for (int row = 0 ; row < out; ++row) {
588615 int base = row * in;
@@ -592,6 +619,7 @@ __device__ inline void d_outer_add(float* dw, const float* dy, const float* x, i
592619 }
593620}
594621
622+ // Fused linear backward: accumulate dW and produce dx.
595623__device__ inline void d_linear_backward (
596624 const float * w,
597625 float * dw,
@@ -604,6 +632,7 @@ __device__ inline void d_linear_backward(
604632 d_matvec_t (w, dy, dx, out, in);
605633}
606634
635+ // RMSNorm forward and backward, matching Python math.
607636__device__ inline void d_rmsnorm_forward (const float * x, float * y, float * inv_rms, int n) {
608637 float ms = 0 .0f ;
609638 for (int i = 0 ; i < n; ++i) {
@@ -632,6 +661,7 @@ __device__ inline void d_rmsnorm_backward(
632661 }
633662}
634663
664+ // Stable softmax and fused CE used throughout train/eval.
635665__device__ inline void d_softmax (const float * logits, int n, float * probs) {
636666 float mx = logits[0 ];
637667 for (int i = 1 ; i < n; ++i) {
@@ -677,6 +707,7 @@ __device__ inline void d_zero(float* x, int n) {
677707 }
678708}
679709
710+ // One-array AdamW update with bias correction and optional grad scaling.
680711__device__ inline void d_adamw_array (
681712 float * w,
682713 float * g,
@@ -718,6 +749,7 @@ __global__ void train_step_kernel(
718749 float max_grad_norm,
719750 float * out_loss,
720751 float * out_grad_norm) {
752+ // Single-thread fused reference kernel: one launch = one full optimizer step.
721753 if (blockIdx .x != 0 || threadIdx .x != 0 ) {
722754 return ;
723755 }
@@ -774,6 +806,7 @@ __global__ void train_step_kernel(
774806
775807 float total_loss = 0 .0f ;
776808
809+ // Forward pass over sequence positions: build activations and per-token CE.
777810 for (int pos = 0 ; pos < n; ++pos) {
778811 int token_id = tokens[pos];
779812 int target_id = tokens[pos + 1 ];
@@ -784,6 +817,7 @@ __global__ void train_step_kernel(
784817 }
785818 d_rmsnorm_forward (x_tokpos[pos], x0[pos], &inv_rms0[pos], kNEmbd );
786819
820+ // 1) Attention block.
787821 d_vec_copy (x0[pos], x_resid_attn[pos], kNEmbd );
788822 d_rmsnorm_forward (x_resid_attn[pos], x_norm1[pos], &inv_rms1[pos], kNEmbd );
789823
@@ -819,6 +853,7 @@ __global__ void train_step_kernel(
819853 x_after_attn[pos][i] = wo_out[i] + x_resid_attn[pos][i];
820854 }
821855
856+ // 2) MLP block.
822857 d_rmsnorm_forward (x_after_attn[pos], x_norm2[pos], &inv_rms2[pos], kNEmbd );
823858 d_matvec (model.mlp_fc1 , x_norm2[pos], fc1_pre[pos], kFcDim , kNEmbd );
824859 for (int i = 0 ; i < kFcDim ; ++i) {
@@ -847,6 +882,7 @@ __global__ void train_step_kernel(
847882
848883 const float inv_n = 1 .0f / static_cast <float >(n);
849884
885+ // Backward pass through time: reverse sequence order.
850886 for (int pos = n - 1 ; pos >= 0 ; --pos) {
851887 int token_id = tokens[pos];
852888 int target_id = tokens[pos + 1 ];
@@ -860,6 +896,7 @@ __global__ void train_step_kernel(
860896 float d_x[kNEmbd ];
861897 d_linear_backward (model.wte , model.g_wte , vocab, kNEmbd , x_final[pos], dlogits, d_x);
862898
899+ // 2) MLP backward.
863900 float d_x_after_attn[kNEmbd ];
864901 d_vec_copy (d_x, d_x_after_attn, kNEmbd );
865902
@@ -879,6 +916,7 @@ __global__ void train_step_kernel(
879916 d_rmsnorm_backward (x_after_attn[pos], inv_rms2[pos], d_x_norm2, d_norm2_in, kNEmbd );
880917 d_vec_add_inplace (d_x_after_attn, d_norm2_in, kNEmbd );
881918
919+ // 1) Attention backward.
882920 float d_x_resid_attn[kNEmbd ];
883921 d_vec_copy (d_x_after_attn, d_x_resid_attn, kNEmbd );
884922
@@ -947,6 +985,7 @@ __global__ void train_step_kernel(
947985 }
948986 }
949987
988+ // Gradient clipping by global norm.
950989 double sum_sq = 0.0 ;
951990 for (int i = 0 ; i < wte_size; ++i) {
952991 double g = model.g_wte [i];
@@ -981,6 +1020,7 @@ __global__ void train_step_kernel(
9811020 float one_minus_b1_prod = 1 .0f - b1_prod;
9821021 float one_minus_b2_prod = 1 .0f - b2_prod;
9831022
1023+ // AdamW update for every parameter tensor.
9841024 d_adamw_array (
9851025 model.wte ,
9861026 model.g_wte ,
@@ -1105,6 +1145,7 @@ __device__ inline void d_forward_token_logits(
11051145 float k_cache[kBlockSize ][kNEmbd ],
11061146 float v_cache[kBlockSize ][kNEmbd ],
11071147 float * logits_out) {
1148+ // Stateless per-token forward, reusing externally provided KV caches.
11081149 const float inv_sqrt_head = 1 .0f / sqrtf (static_cast <float >(kHeadDim ));
11091150
11101151 int token_id = tokens[pos];
@@ -1117,6 +1158,7 @@ __device__ inline void d_forward_token_logits(
11171158 }
11181159 d_rmsnorm_forward (x, x_norm, &inv_rms, kNEmbd );
11191160
1161+ // 1) Attention block.
11201162 float x_resid_attn[kNEmbd ];
11211163 d_vec_copy (x_norm, x_resid_attn, kNEmbd );
11221164
@@ -1159,6 +1201,7 @@ __device__ inline void d_forward_token_logits(
11591201 x_after_attn[i] = wo_out[i] + x_resid_attn[i];
11601202 }
11611203
1204+ // 2) MLP block.
11621205 float x_norm2[kNEmbd ];
11631206 d_rmsnorm_forward (x_after_attn, x_norm2, &inv_rms, kNEmbd );
11641207 float fc1_pre[kFcDim ];
@@ -1175,10 +1218,12 @@ __device__ inline void d_forward_token_logits(
11751218 x_out[i] = fc2_out[i] + x_after_attn[i];
11761219 }
11771220
1221+ // Weight tying: output projection reuses token embedding matrix.
11781222 d_matvec (model.wte , x_out, logits_out, model.vocab_size , kNEmbd );
11791223}
11801224
11811225__global__ void eval_sequence_nll_kernel (ModelPtrs model, const int * tokens, int n, float * out_nll) {
1226+ // Validation kernel: forward-only NLL accumulation, no gradient work.
11821227 if (blockIdx .x != 0 || threadIdx .x != 0 ) {
11831228 return ;
11841229 }
@@ -1203,6 +1248,7 @@ __global__ void eval_sequence_nll_kernel(ModelPtrs model, const int* tokens, int
12031248}
12041249
12051250__global__ void forward_last_logits_kernel (ModelPtrs model, const int * tokens, int seq_len, float * out_logits) {
1251+ // Inference kernel: run causal forward, return logits at final position.
12061252 if (blockIdx .x != 0 || threadIdx .x != 0 ) {
12071253 return ;
12081254 }
@@ -1228,8 +1274,10 @@ __global__ void forward_last_logits_kernel(ModelPtrs model, const int* tokens, i
12281274}
12291275int main (int argc, char ** argv) {
12301276 try {
1277+ // This fused path is intentionally minimal and currently specialized to one layer.
12311278 static_assert (kNLayer == 1 , " current fused kernels are implemented for n_layer=1" );
12321279
1280+ // Let there be deterministic setup and data loading.
12331281 TrainConfig cfg = parse_args (argc, argv);
12341282 std::mt19937 rng (cfg.seed );
12351283 DataBundle data = load_data (rng);
@@ -1257,10 +1305,12 @@ int main(int argc, char** argv) {
12571305 std::cout << " vocab size: " << data.itos .size () << " \n " ;
12581306 std::cout << " num params: " << model.num_params () << " \n " ;
12591307
1308+ // Adam running products for bias correction.
12601309 float b1_prod = 1 .0f ;
12611310 float b2_prod = 1 .0f ;
12621311 constexpr float kPi = 3 .14159265358979323846f ;
12631312
1313+ // Repeat in sequence: one document per step.
12641314 for (int step = 0 ; step < cfg.num_steps ; ++step) {
12651315 auto t0 = std::chrono::high_resolution_clock::now ();
12661316
@@ -1274,10 +1324,12 @@ int main(int argc, char** argv) {
12741324
12751325 d_tokens.upload_raw (tokens.data (), static_cast <size_t >(n + 1 ));
12761326
1327+ // Cosine LR schedule, same functional form as microgpt.py.
12771328 float lr_t = cfg.learning_rate * 0 .5f * (1 .0f + std::cos (kPi * step / cfg.num_steps ));
12781329 b1_prod *= cfg.beta1 ;
12791330 b2_prod *= cfg.beta2 ;
12801331
1332+ // One fused launch performs forward, backward, clipping, and AdamW.
12811333 train_step_kernel<<<1 , 1 >>> (
12821334 model_ptrs,
12831335 d_tokens.data (),
@@ -1305,6 +1357,7 @@ int main(int argc, char** argv) {
13051357 << " | loss " << std::fixed << std::setprecision (4 ) << loss
13061358 << " | " << ms << " ms\n " ;
13071359
1360+ // Periodic validation on held-out docs.
13081361 if ((step + 1 ) % cfg.val_every == 0 ) {
13091362 float val_loss = 0 .0f ;
13101363 int val_n = 0 ;
@@ -1334,6 +1387,7 @@ int main(int argc, char** argv) {
13341387 }
13351388 }
13361389
1390+ // Inference: top-k sampled autoregressive decoding.
13371391 std::cout << " \n --- inference ---\n " ;
13381392 std::vector<float > host_logits (static_cast <size_t >(model.vocab_size ), 0 .0f );
13391393 for (int sample_idx = 0 ; sample_idx < cfg.num_samples ; ++sample_idx) {
0 commit comments