11#pragma once
22
3- #include <stddef.h>
43#include <stdbool.h>
4+ #include <stddef.h>
55
66#define MAX_LAYERS 128
77#define MAX_EXPERTS 64
1010#define KV_SINKS 2
1111
1212struct Config {
13- int dim ; // transformer dimension
14- int hidden_dim ; // for ffn layers
15- int head_dim ; // for attention heads; usually dim / n_heads
16- int n_layers ; // number of layers
17- int n_heads ; // number of query heads
18- int n_kv_heads ; // number of key/value heads (can be < query heads because of multiquery)
19- int vocab_size ; // vocabulary size, usually 256 (byte-level)
20- int seq_len ; // max sequence length
21- float rope_theta ; // RoPE theta
22- int rotary_dim ; // RoPE rotary dimension (elements after that don't get rotated)
23- int n_experts ; // number of experts for MoE models
24- int n_experts_ac ; // number of active experts for MoE models
25- float norm_eps ; // epsilon for layer normalization
26- bool act_gelu ; // use GELU activation function
27- bool norm_ln ; // use full LN normalization
28- bool norm_par ; // use parallel MLP/attention by omitting intermediate normalization
29- float qkv_clip ; // clip qkv values to [-clip, clip]
13+ int dim ; // transformer dimension
14+ int hidden_dim ; // for ffn layers
15+ int head_dim ; // for attention heads; usually dim / n_heads
16+ int n_layers ; // number of layers
17+ int n_heads ; // number of query heads
18+ int n_kv_heads ; // number of key/value heads (can be < query heads because of multiquery)
19+ int vocab_size ; // vocabulary size, usually 256 (byte-level)
20+ int seq_len ; // max sequence length
21+ float rope_theta ; // RoPE theta
22+ int rotary_dim ; // RoPE rotary dimension (elements after that don't get rotated)
23+ int n_experts ; // number of experts for MoE models
24+ int n_experts_ac ; // number of active experts for MoE models
25+ float norm_eps ; // epsilon for layer normalization
26+ bool act_gelu ; // use GELU activation function
27+ bool norm_ln ; // use full LN normalization
28+ bool norm_par ; // use parallel MLP/attention by omitting intermediate normalization
29+ bool qk_norm ; // use qk normalization
30+ float qkv_clip ; // clip qkv values to [-clip, clip]
3031};
3132
3233struct Weights {
@@ -37,6 +38,8 @@ struct Weights {
3738 // weights for norms
3839 float * rms_att_weight [MAX_LAYERS ]; // (dim) rmsnorm weights
3940 float * rms_ffn_weight [MAX_LAYERS ]; // (dim)
41+ float * qnorm_weight [MAX_LAYERS ]; // (head_dim)
42+ float * knorm_weight [MAX_LAYERS ]; // (head_dim)
4043 // weights for matmuls
4144 void * wq [MAX_LAYERS ]; // (n_heads * head_dim, dim)
4245 void * wk [MAX_LAYERS ]; // (n_kv_heads * head_dim, dim)
0 commit comments