Skip to content
This repository was archived by the owner on May 29, 2025. It is now read-only.

Commit 57031c2

Browse files
committed
Implement support for KQ norm for CPU inference
We currently assume the norm weights are shared between all heads for simplicity.
1 parent b9a9650 commit 57031c2

3 files changed

Lines changed: 39 additions & 18 deletions

File tree

src/infer.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,16 @@ float* forward(struct Transformer* transformer, int token, int pos, unsigned fla
361361
matmul(s->k, s->xb, w->wk[l], w->bqkv[l] ? w->bqkv[l] + q_dim : NULL, dim, kv_dim, dotprod);
362362
matmul(s->v, s->xb, w->wv[l], w->bqkv[l] ? w->bqkv[l] + q_dim + kv_dim : NULL, dim, kv_dim, dotprod);
363363

364+
// some models apply rmsnorm to qk values
365+
if (p->qk_norm) {
366+
for (int i = 0; i < p->n_heads; ++i) {
367+
rmsnorm(s->q + i * p->head_dim, s->q + i * p->head_dim, w->qnorm_weight[l], p->head_dim, p->norm_eps, false);
368+
}
369+
for (int i = 0; i < p->n_kv_heads; ++i) {
370+
rmsnorm(s->k + i * p->head_dim, s->k + i * p->head_dim, w->knorm_weight[l], p->head_dim, p->norm_eps, false);
371+
}
372+
}
373+
364374
// some models require clipping qkv values
365375
for (int i = 0; i < q_dim; i++) {
366376
s->q[i] = clip(s->q[i], p->qkv_clip);

src/model.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#include <stddef.h>
43
#include <stdbool.h>
4+
#include <stddef.h>
55

66
#define MAX_LAYERS 128
77
#define MAX_EXPERTS 64
@@ -10,23 +10,24 @@
1010
#define KV_SINKS 2
1111

1212
struct Config {
13-
int dim; // transformer dimension
14-
int hidden_dim; // for ffn layers
15-
int head_dim; // for attention heads; usually dim / n_heads
16-
int n_layers; // number of layers
17-
int n_heads; // number of query heads
18-
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
19-
int vocab_size; // vocabulary size, usually 256 (byte-level)
20-
int seq_len; // max sequence length
21-
float rope_theta; // RoPE theta
22-
int rotary_dim; // RoPE rotary dimension (elements after that don't get rotated)
23-
int n_experts; // number of experts for MoE models
24-
int n_experts_ac; // number of active experts for MoE models
25-
float norm_eps; // epsilon for layer normalization
26-
bool act_gelu; // use GELU activation function
27-
bool norm_ln; // use full LN normalization
28-
bool norm_par; // use parallel MLP/attention by omitting intermediate normalization
29-
float qkv_clip; // clip qkv values to [-clip, clip]
13+
int dim; // transformer dimension
14+
int hidden_dim; // for ffn layers
15+
int head_dim; // for attention heads; usually dim / n_heads
16+
int n_layers; // number of layers
17+
int n_heads; // number of query heads
18+
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
19+
int vocab_size; // vocabulary size, usually 256 (byte-level)
20+
int seq_len; // max sequence length
21+
float rope_theta; // RoPE theta
22+
int rotary_dim; // RoPE rotary dimension (elements after that don't get rotated)
23+
int n_experts; // number of experts for MoE models
24+
int n_experts_ac; // number of active experts for MoE models
25+
float norm_eps; // epsilon for layer normalization
26+
bool act_gelu; // use GELU activation function
27+
bool norm_ln; // use full LN normalization
28+
bool norm_par; // use parallel MLP/attention by omitting intermediate normalization
29+
bool qk_norm; // use qk normalization
30+
float qkv_clip; // clip qkv values to [-clip, clip]
3031
};
3132

3233
struct Weights {
@@ -37,6 +38,8 @@ struct Weights {
3738
// weights for norms
3839
float* rms_att_weight[MAX_LAYERS]; // (dim) rmsnorm weights
3940
float* rms_ffn_weight[MAX_LAYERS]; // (dim)
41+
float* qnorm_weight[MAX_LAYERS]; // (head_dim)
42+
float* knorm_weight[MAX_LAYERS]; // (head_dim)
4043
// weights for matmuls
4144
void* wq[MAX_LAYERS]; // (n_heads * head_dim, dim)
4245
void* wk[MAX_LAYERS]; // (n_kv_heads * head_dim, dim)

src/run.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ void get_config(struct Config* config, struct Tensors* tensors, int context) {
6464
config->norm_ln = norm_type && strncmp(norm_type, "layernorm", 9) == 0; // note: we currently don't support layernorm bias
6565
config->norm_par = norm_type && strcmp(norm_type, "layernorm_par") == 0; // note: we currently don't support layernorm bias
6666

67+
const char* qk_norm = tensors_metadata_find(tensors, "qk_norm");
68+
config->qk_norm = qk_norm && atoi(qk_norm);
69+
6770
const char* qkv_clip = tensors_metadata_find(tensors, "qkv_clip");
6871
config->qkv_clip = qkv_clip ? atof(qkv_clip) : FLT_MAX;
6972
}
@@ -90,6 +93,11 @@ void get_weights(struct Config* config, struct Weights* weights, struct Tensors*
9093
weights->wv[l] = tensors_get(tensors, "model.layers.%d.attn.wv.weight", l, wtype, (int[]){config->n_kv_heads * config->head_dim, config->dim / gsize, 0, 0});
9194
weights->wo[l] = tensors_get(tensors, "model.layers.%d.attn.wo.weight", l, wtype, (int[]){config->dim, config->n_heads * config->head_dim / gsize, 0, 0});
9295

96+
if (config->qk_norm) {
97+
weights->qnorm_weight[l] = tensors_get(tensors, "model.layers.%d.attn.qnorm.weight", l, dt_f32, (int[]){config->head_dim, 0, 0, 0});
98+
weights->knorm_weight[l] = tensors_get(tensors, "model.layers.%d.attn.knorm.weight", l, dt_f32, (int[]){config->head_dim, 0, 0, 0});
99+
}
100+
93101
if (tensors_find(tensors, "model.layers.%d.attn.wqkv.bias", l)) {
94102
weights->bqkv[l] = (float*)tensors_get(tensors, "model.layers.%d.attn.wqkv.bias", l, dt_f32, (int[]){(config->n_heads + config->n_kv_heads * 2) * config->head_dim, 0, 0, 0});
95103
}

0 commit comments

Comments
 (0)