@@ -9,13 +9,14 @@ using ConcreteStructs,
     OneHotArrays,
     Reactant,
     Enzyme,
-    BytePairEncoding
+    BytePairEncoding,
+    NNlib
 using Comonicon: @main
 
-if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+if !haskey(DataDeps.registry, "shakespeare_char")
     register(
         DataDep(
-            "nanogpt_shakespeare_input",
+            "shakespeare_char",
             "Shakespeare Input Text for training NanoGPT",
             "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
             "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
@@ -27,31 +28,51 @@
 
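+# Self-attention with a causal mask: position i may only attend to positions ≤ i,
+# so a token's prediction never sees future tokens.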
 @concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
     mha
+
+    function CausalSelfAttention(args...; kwargs...)
+        mha = MultiHeadAttention(args...; kwargs...)
+        return new{typeof(mha)}(mha)
+    end
 end
 
 function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
-    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
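+    # mha returns (output, attention weights); discard the weights so the
+    # wrapper keeps the usual (y, st) layer return signature.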
+    (y, α), stₙ = attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
+    return y, stₙ
 end
 
-@concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
+@concrete struct GPT2Block <: AbstractLuxWrapperLayer{:block}
     block
 end
 
-function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(
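+# Pre-LayerNorm transformer block: LayerNorm → causal self-attention, then
+# LayerNorm → MLP, each wrapped in a residual SkipConnection(+).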
+function GPT2Block(; embed_dim, num_heads, hidden_dim, dropout_rate)
+    return GPT2Block(
         Chain(
             SkipConnection(
                 Chain(
-                    LayerNorm((n_embed, 1)),
-                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    CausalSelfAttention(
+                        embed_dim;
+                        nheads=num_heads,
+                        attention_dropout_probability=dropout_rate,
+                        dense_kwargs=(; init_weight=glorot_uniform, init_bias=zeros32),
+                    ),
                 ),
                 +,
             ),
             SkipConnection(
                 Chain(
-                    LayerNorm((n_embed, 1)),
-                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                    Dense(4 * n_embed => n_embed; use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    Dense(
+                        embed_dim => hidden_dim,
+                        gelu;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
+                    Dense(
+                        hidden_dim => embed_dim;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
                     Dropout(dropout_rate),
                 ),
                 +,
@@ -60,51 +81,62 @@ function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
     )
 end
 
-struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
-    embedding::E
-
-    function PositionalEmbedding(args...; kwargs...)
-        embed = Embedding(args...; kwargs...)
-        return new{typeof(embed)}(embed)
-    end
-end
-
-(pe::PositionalEmbedding)(x, ps, st) = pe.embedding(1:size(x, 1), ps, st)
-
-@concrete struct GPT <: AbstractLuxWrapperLayer{:layer}
-    layer
+@concrete struct GPT2 <: AbstractLuxContainerLayer{(:tok_emb, :pos_emb, :gpt_blocks)}
+    tok_emb
+    pos_emb
+    gpt_blocks
 end
 
-function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(
+function GPT2(;
+    n_vocab, embed_dim, num_heads, hidden_dim, dropout_rate, block_size, n_layers
+)
+    return GPT2(
+        Embedding(n_vocab => embed_dim),
+        Embedding(block_size => embed_dim),
         Chain(
-            Parallel(
-                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
-            ),
             Dropout(dropout_rate),
             Chain(
                 ntuple(
-                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                    Returns(GPT2Block(; embed_dim, num_heads, dropout_rate, hidden_dim)),
+                    n_layers,
                 )...,
             ),
-            LayerNorm((n_embed, 1)),
-            Dense(n_embed => n_vocab; use_bias),
+            LayerNorm(embed_dim; dims=nothing),
         ),
     )
 end
 
-#=
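+# Explicit forward pass: sum token and positional embeddings, run the block
+# stack, then project back onto the vocabulary.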
+function (model::GPT2)(x, ps, st)
+    token_embeddings, st_tok_emb = model.tok_emb(x, ps.tok_emb, st.tok_emb)
+    pos_embeddings, st_pos_emb = model.pos_emb(1:size(x, 1), ps.pos_emb, st.pos_emb)
+    embedding_output = token_embeddings .+ pos_embeddings
+
+    query, st_gpt_blocks = model.gpt_blocks(embedding_output, ps.gpt_blocks, st.gpt_blocks)
+    _, seq_len, batch_size = size(query)
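+    # Weight tying: reuse the (transposed) token-embedding matrix as the output
+    # projection, so the logits share parameters with tok_emb, as in GPT-2.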
+    outputs = reshape(
+        ps.tok_emb.weight' * reshape(query, :, seq_len * batch_size), :, seq_len, batch_size
+    )
+
+    return outputs, (; tok_emb=st_tok_emb, pos_emb=st_pos_emb, gpt_blocks=st_gpt_blocks)
+end
 
 dev = reactant_device(; force=true)
 rng = Random.default_rng()
 
-model = GPT(;
-    n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
-    use_bias=true
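+# GPT-2 "small" widths (768-dim embeddings, 12 heads, hidden_dim = 4 × embed_dim),
+# but with n_layers cut from 12 to 3, presumably to keep the example light.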
+model = GPT2(;
+    n_vocab=50304,
+    embed_dim=768,
+    hidden_dim=3072,
+    block_size=1024,
+    n_layers=3,
+    dropout_rate=0.0,
+    num_heads=12,
 )
-ps, st = Lux.setup(rng, model) |> dev
+ps, st = Lux.setup(rng, model) |> dev;
+
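+# Dummy batch of token ids with shape (sequence_length, batch) = (1024, 32),
+# matching the block_size above.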
+x = rand(1:50304, 1024, 32) |> dev;
 
-=#
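+# @code_hlo (from Reactant) traces the call and prints the HLO the model
+# compiles to, without running a full training step.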
+@code_hlo model(x, ps, st)
 
 # Use the model to generate some text.
 function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
@@ -180,7 +212,7 @@ function get_nanogpt_data(; sequence_length, test_split)
 end
 
 @main function main(;
-    n_embed::Int=64,
+    embed_dim::Int=64,
     n_hidden::Int=256,
     n_heads::Int=4,
     qk_dim::Int=16,
@@ -238,7 +270,7 @@
 
     model_config = (;
         n_vocab=length(alphabet),
-        n_embed,
+        embed_dim,
         sequence_length,
         n_hidden,
         n_layers,