Skip to content

Commit 27e1cef

Browse files
committed
Refactor code structure for improved readability and maintainability
1 parent 2fc0a24 commit 27e1cef

2 files changed

Lines changed: 1987 additions & 424 deletions

File tree

microgpt_cuda.cu

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
#include <utility>
1919
#include <vector>
2020

21+
/*
22+
The most atomic way to train and inference a GPT on CUDA.
23+
This file is the complete fused GPU algorithm.
24+
Everything else is just efficiency.
25+
*/
26+
2127
#define CUDA_CHECK(call) \
2228
do { \
2329
cudaError_t err__ = (call); \
@@ -28,15 +34,17 @@
2834
} \
2935
} while (0)
3036

31-
constexpr int kNEmbd = 16;
32-
constexpr int kNHead = 4;
33-
constexpr int kNLayer = 1;
34-
constexpr int kBlockSize = 8;
37+
// Let there be fixed model geometry, matching microgpt.py exactly.
38+
constexpr int kNEmbd = 16; // embedding dimension
39+
constexpr int kNHead = 4; // number of attention heads
40+
constexpr int kNLayer = 1; // this fused kernel path currently supports one layer
41+
constexpr int kBlockSize = 8; // maximum sequence length
3542
constexpr int kHeadDim = kNEmbd / kNHead;
36-
constexpr int kFcDim = 4 * kNEmbd;
37-
constexpr int kMaxVocab = 256;
38-
constexpr int kMaxTokens = kBlockSize + 1;
43+
constexpr int kFcDim = 4 * kNEmbd; // hidden size of the MLP expansion
44+
constexpr int kMaxVocab = 256; // hard safety cap for static kernel buffers
45+
constexpr int kMaxTokens = kBlockSize + 1; // [token_t, ..., token_{t+n}] for next-token targets
3946

47+
// Let there be training knobs, kept numerically aligned with the Python reference.
4048
struct TrainConfig {
4149
int num_steps = 500;
4250
int val_every = 100;
@@ -54,6 +62,7 @@ struct TrainConfig {
5462
float max_grad_norm = 1.0f;
5563
};
5664

65+
// Let there be dataset and tokenizer state.
5766
struct DataBundle {
5867
std::vector<std::string> docs;
5968
std::vector<std::string> train_docs;
@@ -63,6 +72,7 @@ struct DataBundle {
6372
int bos = 0;
6473
};
6574

75+
// Let there be a tiny RAII wrapper over device memory.
6676
template <typename T>
6777
class DeviceBuffer {
6878
public:
@@ -159,6 +169,7 @@ private:
159169
size_t size_ = 0;
160170
};
161171

172+
// Let there be a plain pointer view for passing model buffers into kernels.
162173
struct ModelPtrs {
163174
int vocab_size = 0;
164175

@@ -199,6 +210,7 @@ struct ModelPtrs {
199210
float* v_mlp_fc2 = nullptr;
200211
};
201212

213+
// Let there be parameters and optimizer state resident on the GPU.
202214
struct DeviceModel {
203215
int vocab_size = 0;
204216

@@ -238,6 +250,7 @@ struct DeviceModel {
238250
DeviceBuffer<float> v_mlp_fc1;
239251
DeviceBuffer<float> v_mlp_fc2;
240252

253+
// Initialize weights with the same random scheme as microgpt.py.
241254
DeviceModel(int vocab, std::mt19937& rng)
242255
: vocab_size(vocab),
243256
wte(static_cast<size_t>(vocab) * kNEmbd),
@@ -272,6 +285,7 @@ struct DeviceModel {
272285
v_attn_wo(static_cast<size_t>(kNEmbd) * kNEmbd),
273286
v_mlp_fc1(static_cast<size_t>(kFcDim) * kNEmbd),
274287
v_mlp_fc2(static_cast<size_t>(kNEmbd) * kFcDim) {
288+
// Parameter init: Gaussian(0, 0.02), with selected projections zero-initialized.
275289
init_matrix(wte, 0.02f, rng);
276290
init_matrix(wpe, 0.02f, rng);
277291
init_matrix(attn_wq, 0.02f, rng);
@@ -309,6 +323,7 @@ struct DeviceModel {
309323
v_mlp_fc2.zero();
310324
}
311325

326+
// Count all trainable scalars, for observability parity with microgpt.py.
312327
size_t num_params() const {
313328
return wte.size() + wpe.size() + attn_wq.size() + attn_wk.size() + attn_wv.size() +
314329
attn_wo.size() + mlp_fc1.size() + mlp_fc2.size();
@@ -357,6 +372,7 @@ struct DeviceModel {
357372
}
358373

359374
private:
375+
// Host-side random init, uploaded once; parameters stay device-resident afterward.
360376
static void init_matrix(DeviceBuffer<float>& dst, float stddev, std::mt19937& rng) {
361377
std::vector<float> host(dst.size(), 0.0f);
362378
if (stddev != 0.0f) {
@@ -369,6 +385,7 @@ private:
369385
}
370386
};
371387

388+
// Trim a line from the input corpus.
372389
static std::string trim_copy(const std::string& s) {
373390
size_t start = 0;
374391
while (start < s.size() && std::isspace(static_cast<unsigned char>(s[start]))) {
@@ -381,6 +398,7 @@ static std::string trim_copy(const std::string& s) {
381398
return s.substr(start, end - start);
382399
}
383400

401+
// Let there be an input dataset file, matching Python behavior on first run.
384402
static void ensure_input_file() {
385403
namespace fs = std::filesystem;
386404
if (fs::exists("input.txt")) {
@@ -400,6 +418,7 @@ static void ensure_input_file() {
400418
#endif
401419
}
402420

421+
// Load docs, split train/val, and build a character tokenizer with <BOS>.
403422
static DataBundle load_data(std::mt19937& rng) {
404423
ensure_input_file();
405424

@@ -456,6 +475,7 @@ static DataBundle load_data(std::mt19937& rng) {
456475
return data;
457476
}
458477

478+
// Encode one document as [BOS] + chars + [BOS], exactly like microgpt.py.
459479
static std::vector<int> encode_doc(
460480
const std::string& doc,
461481
const std::unordered_map<char, int>& stoi,
@@ -474,6 +494,7 @@ static std::vector<int> encode_doc(
474494
return tokens;
475495
}
476496

497+
// Top-k sampling with temperature, used at inference time.
477498
static int sample_top_k(
478499
const std::vector<float>& logits,
479500
int top_k,
@@ -502,6 +523,7 @@ static int sample_top_k(
502523
return ids[static_cast<size_t>(dist(rng))];
503524
}
504525

526+
// Parse optional CLI overrides for fast experiments.
505527
static TrainConfig parse_args(int argc, char** argv) {
506528
TrainConfig cfg;
507529
for (int i = 1; i < argc; ++i) {
@@ -550,6 +572,8 @@ static TrainConfig parse_args(int argc, char** argv) {
550572
}
551573
return cfg;
552574
}
575+
576+
// Below live the scalar CUDA building blocks used by fused kernels.
553577
__device__ inline void d_vec_copy(const float* src, float* dst, int n) {
554578
for (int i = 0; i < n; ++i) {
555579
dst[i] = src[i];
@@ -562,6 +586,7 @@ __device__ inline void d_vec_add_inplace(float* dst, const float* src, int n) {
562586
}
563587
}
564588

589+
// y = W * x
565590
__device__ inline void d_matvec(const float* w, const float* x, float* y, int out, int in) {
566591
for (int row = 0; row < out; ++row) {
567592
float acc = 0.0f;
@@ -573,6 +598,7 @@ __device__ inline void d_matvec(const float* w, const float* x, float* y, int ou
573598
}
574599
}
575600

601+
// dx = W^T * dy
576602
__device__ inline void d_matvec_t(const float* w, const float* dy, float* dx, int out, int in) {
577603
for (int col = 0; col < in; ++col) {
578604
float acc = 0.0f;
@@ -583,6 +609,7 @@ __device__ inline void d_matvec_t(const float* w, const float* dy, float* dx, in
583609
}
584610
}
585611

612+
// dW += dy outer x
586613
__device__ inline void d_outer_add(float* dw, const float* dy, const float* x, int out, int in) {
587614
for (int row = 0; row < out; ++row) {
588615
int base = row * in;
@@ -592,6 +619,7 @@ __device__ inline void d_outer_add(float* dw, const float* dy, const float* x, i
592619
}
593620
}
594621

622+
// Fused linear backward: accumulate dW and produce dx.
595623
__device__ inline void d_linear_backward(
596624
const float* w,
597625
float* dw,
@@ -604,6 +632,7 @@ __device__ inline void d_linear_backward(
604632
d_matvec_t(w, dy, dx, out, in);
605633
}
606634

635+
// RMSNorm forward and backward, matching Python math.
607636
__device__ inline void d_rmsnorm_forward(const float* x, float* y, float* inv_rms, int n) {
608637
float ms = 0.0f;
609638
for (int i = 0; i < n; ++i) {
@@ -632,6 +661,7 @@ __device__ inline void d_rmsnorm_backward(
632661
}
633662
}
634663

664+
// Stable softmax and fused CE used throughout train/eval.
635665
__device__ inline void d_softmax(const float* logits, int n, float* probs) {
636666
float mx = logits[0];
637667
for (int i = 1; i < n; ++i) {
@@ -677,6 +707,7 @@ __device__ inline void d_zero(float* x, int n) {
677707
}
678708
}
679709

710+
// One-array AdamW update with bias correction and optional grad scaling.
680711
__device__ inline void d_adamw_array(
681712
float* w,
682713
float* g,
@@ -718,6 +749,7 @@ __global__ void train_step_kernel(
718749
float max_grad_norm,
719750
float* out_loss,
720751
float* out_grad_norm) {
752+
// Single-thread fused reference kernel: one launch = one full optimizer step.
721753
if (blockIdx.x != 0 || threadIdx.x != 0) {
722754
return;
723755
}
@@ -774,6 +806,7 @@ __global__ void train_step_kernel(
774806

775807
float total_loss = 0.0f;
776808

809+
// Forward pass over sequence positions: build activations and per-token CE.
777810
for (int pos = 0; pos < n; ++pos) {
778811
int token_id = tokens[pos];
779812
int target_id = tokens[pos + 1];
@@ -784,6 +817,7 @@ __global__ void train_step_kernel(
784817
}
785818
d_rmsnorm_forward(x_tokpos[pos], x0[pos], &inv_rms0[pos], kNEmbd);
786819

820+
// 1) Attention block.
787821
d_vec_copy(x0[pos], x_resid_attn[pos], kNEmbd);
788822
d_rmsnorm_forward(x_resid_attn[pos], x_norm1[pos], &inv_rms1[pos], kNEmbd);
789823

@@ -819,6 +853,7 @@ __global__ void train_step_kernel(
819853
x_after_attn[pos][i] = wo_out[i] + x_resid_attn[pos][i];
820854
}
821855

856+
// 2) MLP block.
822857
d_rmsnorm_forward(x_after_attn[pos], x_norm2[pos], &inv_rms2[pos], kNEmbd);
823858
d_matvec(model.mlp_fc1, x_norm2[pos], fc1_pre[pos], kFcDim, kNEmbd);
824859
for (int i = 0; i < kFcDim; ++i) {
@@ -847,6 +882,7 @@ __global__ void train_step_kernel(
847882

848883
const float inv_n = 1.0f / static_cast<float>(n);
849884

885+
// Backward pass through time: reverse sequence order.
850886
for (int pos = n - 1; pos >= 0; --pos) {
851887
int token_id = tokens[pos];
852888
int target_id = tokens[pos + 1];
@@ -860,6 +896,7 @@ __global__ void train_step_kernel(
860896
float d_x[kNEmbd];
861897
d_linear_backward(model.wte, model.g_wte, vocab, kNEmbd, x_final[pos], dlogits, d_x);
862898

899+
// 2) MLP backward.
863900
float d_x_after_attn[kNEmbd];
864901
d_vec_copy(d_x, d_x_after_attn, kNEmbd);
865902

@@ -879,6 +916,7 @@ __global__ void train_step_kernel(
879916
d_rmsnorm_backward(x_after_attn[pos], inv_rms2[pos], d_x_norm2, d_norm2_in, kNEmbd);
880917
d_vec_add_inplace(d_x_after_attn, d_norm2_in, kNEmbd);
881918

919+
// 1) Attention backward.
882920
float d_x_resid_attn[kNEmbd];
883921
d_vec_copy(d_x_after_attn, d_x_resid_attn, kNEmbd);
884922

@@ -947,6 +985,7 @@ __global__ void train_step_kernel(
947985
}
948986
}
949987

988+
// Gradient clipping by global norm.
950989
double sum_sq = 0.0;
951990
for (int i = 0; i < wte_size; ++i) {
952991
double g = model.g_wte[i];
@@ -981,6 +1020,7 @@ __global__ void train_step_kernel(
9811020
float one_minus_b1_prod = 1.0f - b1_prod;
9821021
float one_minus_b2_prod = 1.0f - b2_prod;
9831022

1023+
// AdamW update for every parameter tensor.
9841024
d_adamw_array(
9851025
model.wte,
9861026
model.g_wte,
@@ -1105,6 +1145,7 @@ __device__ inline void d_forward_token_logits(
11051145
float k_cache[kBlockSize][kNEmbd],
11061146
float v_cache[kBlockSize][kNEmbd],
11071147
float* logits_out) {
1148+
// Stateless per-token forward, reusing externally provided KV caches.
11081149
const float inv_sqrt_head = 1.0f / sqrtf(static_cast<float>(kHeadDim));
11091150

11101151
int token_id = tokens[pos];
@@ -1117,6 +1158,7 @@ __device__ inline void d_forward_token_logits(
11171158
}
11181159
d_rmsnorm_forward(x, x_norm, &inv_rms, kNEmbd);
11191160

1161+
// 1) Attention block.
11201162
float x_resid_attn[kNEmbd];
11211163
d_vec_copy(x_norm, x_resid_attn, kNEmbd);
11221164

@@ -1159,6 +1201,7 @@ __device__ inline void d_forward_token_logits(
11591201
x_after_attn[i] = wo_out[i] + x_resid_attn[i];
11601202
}
11611203

1204+
// 2) MLP block.
11621205
float x_norm2[kNEmbd];
11631206
d_rmsnorm_forward(x_after_attn, x_norm2, &inv_rms, kNEmbd);
11641207
float fc1_pre[kFcDim];
@@ -1175,10 +1218,12 @@ __device__ inline void d_forward_token_logits(
11751218
x_out[i] = fc2_out[i] + x_after_attn[i];
11761219
}
11771220

1221+
// Weight tying: output projection reuses token embedding matrix.
11781222
d_matvec(model.wte, x_out, logits_out, model.vocab_size, kNEmbd);
11791223
}
11801224

11811225
__global__ void eval_sequence_nll_kernel(ModelPtrs model, const int* tokens, int n, float* out_nll) {
1226+
// Validation kernel: forward-only NLL accumulation, no gradient work.
11821227
if (blockIdx.x != 0 || threadIdx.x != 0) {
11831228
return;
11841229
}
@@ -1203,6 +1248,7 @@ __global__ void eval_sequence_nll_kernel(ModelPtrs model, const int* tokens, int
12031248
}
12041249

12051250
__global__ void forward_last_logits_kernel(ModelPtrs model, const int* tokens, int seq_len, float* out_logits) {
1251+
// Inference kernel: run causal forward, return logits at final position.
12061252
if (blockIdx.x != 0 || threadIdx.x != 0) {
12071253
return;
12081254
}
@@ -1228,8 +1274,10 @@ __global__ void forward_last_logits_kernel(ModelPtrs model, const int* tokens, i
12281274
}
12291275
int main(int argc, char** argv) {
12301276
try {
1277+
// This fused path is intentionally minimal and currently specialized to one layer.
12311278
static_assert(kNLayer == 1, "current fused kernels are implemented for n_layer=1");
12321279

1280+
// Let there be deterministic setup and data loading.
12331281
TrainConfig cfg = parse_args(argc, argv);
12341282
std::mt19937 rng(cfg.seed);
12351283
DataBundle data = load_data(rng);
@@ -1257,10 +1305,12 @@ int main(int argc, char** argv) {
12571305
std::cout << "vocab size: " << data.itos.size() << "\n";
12581306
std::cout << "num params: " << model.num_params() << "\n";
12591307

1308+
// Adam running products for bias correction.
12601309
float b1_prod = 1.0f;
12611310
float b2_prod = 1.0f;
12621311
constexpr float kPi = 3.14159265358979323846f;
12631312

1313+
// Repeat in sequence: one document per step.
12641314
for (int step = 0; step < cfg.num_steps; ++step) {
12651315
auto t0 = std::chrono::high_resolution_clock::now();
12661316

@@ -1274,10 +1324,12 @@ int main(int argc, char** argv) {
12741324

12751325
d_tokens.upload_raw(tokens.data(), static_cast<size_t>(n + 1));
12761326

1327+
// Cosine LR schedule, same functional form as microgpt.py.
12771328
float lr_t = cfg.learning_rate * 0.5f * (1.0f + std::cos(kPi * step / cfg.num_steps));
12781329
b1_prod *= cfg.beta1;
12791330
b2_prod *= cfg.beta2;
12801331

1332+
// One fused launch performs forward, backward, clipping, and AdamW.
12811333
train_step_kernel<<<1, 1>>>(
12821334
model_ptrs,
12831335
d_tokens.data(),
@@ -1305,6 +1357,7 @@ int main(int argc, char** argv) {
13051357
<< " | loss " << std::fixed << std::setprecision(4) << loss
13061358
<< " | " << ms << "ms\n";
13071359

1360+
// Periodic validation on held-out docs.
13081361
if ((step + 1) % cfg.val_every == 0) {
13091362
float val_loss = 0.0f;
13101363
int val_n = 0;
@@ -1334,6 +1387,7 @@ int main(int argc, char** argv) {
13341387
}
13351388
}
13361389

1390+
// Inference: top-k sampled autoregressive decoding.
13371391
std::cout << "\n--- inference ---\n";
13381392
std::vector<float> host_logits(static_cast<size_t>(model.vocab_size), 0.0f);
13391393
for (int sample_idx = 0; sample_idx < cfg.num_samples; ++sample_idx) {

0 commit comments

Comments
 (0)