-# Taken from https://github.com/FluxML/model-zoo/pull/410
-using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
-    DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
+using ConcreteStructs,
+    MLUtils,
+    Lux,
+    Random,
+    Optimisers,
+    Printf,
+    Statistics,
+    DataDeps,
+    OneHotArrays,
+    Reactant,
+    Enzyme,
+    BytePairEncoding
 using Comonicon: @main

-if !haskey(DataDeps.registry, "nanogpt")
-    register(DataDep(
-        "nanogpt",
-        "Shakespeare Input Text for training NanoGPT",
-        "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
-        "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca"
-    ))
+if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+    register(
+        DataDep(
+            "nanogpt_shakespeare_input",
+            "Shakespeare Input Text for training NanoGPT",
+            "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
+            "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
+        ),
+    )
 end

 # Setup the model definition
-@concrete struct CausalSelfAttention <:
-                 AbstractLuxContainerLayer{(:causal_attn, :proj, :attn_drop)}
-    causal_attn
-    proj
-    attn_drop
-    n_embed::Int
-    n_heads::Int
-end

-function CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
-    causal_attn = Dense(n_embed, 3 * n_embed; use_bias)
-    proj = Chain(
-        Dense(n_embed, n_embed; use_bias),
-        Dropout(dropout_rate)
-    )
-    attn_drop = Dropout(dropout_rate)
-    return CausalSelfAttention(causal_attn, proj, attn_drop, n_embed, n_heads)
+@concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
+    mha
 end

-function (attn::CausalSelfAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
-    qkv, qkv_st = attn.causal_attn(x, ps.causal_attn, st.causal_attn)
-    q, k, v = (
-        selectdim(qkv, 1, 1:(attn.n_heads)),
-        selectdim(qkv, 1, (attn.n_heads + 1):(2 * attn.n_heads)),
-        selectdim(qkv, 1, (2 * attn.n_heads + 1):(3 * attn.n_heads))
-    )
-    dp = StatefulLuxLayer{true}(attn.attn_drop, ps.attn_drop, st.attn_drop)
-    mha, _ = NNlib.dot_product_attention(
-        q, k, v, nothing; mask=NNlib.make_causal_mask(x), fdrop=dp, nheads=attn.n_heads
-    )
-    proj, proj_st = attn.proj(mha, ps.proj, st.proj)
-    return proj, (; causal_attn=qkv_st, proj=proj_st, attn_drop=dp.attn_drop)
+function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
+    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
 end

 @concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
     block
 end

 function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(Chain(
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
+    return GPTBlock(
+        Chain(
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                ),
+                +,
             ),
-            +
-        ),
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                Dense(4 * n_embed => n_embed; use_bias),
-                Dropout(dropout_rate)
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
+                    Dense(4 * n_embed => n_embed; use_bias),
+                    Dropout(dropout_rate),
+                ),
+                +,
             ),
-            +
-        )
-    ))
+        ),
+    )
 end

 struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
 end

 function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(Chain(
-        Parallel(
-            +,
-            Embedding(n_vocab => n_embed),
-            PositionalEmbedding(block_size => n_embed)
+    return GPT(
+        Chain(
+            Parallel(
+                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
+            ),
+            Dropout(dropout_rate),
+            Chain(
+                ntuple(
+                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                )...,
+            ),
+            LayerNorm((n_embed, 1)),
+            Dense(n_embed => n_vocab; use_bias),
         ),
-        Dropout(dropout_rate),
-        Chain(ntuple(
-            Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
-        )...),
-        LayerNorm((n_embed, 1)),
-        Dense(n_embed => n_vocab; use_bias)
-    ))
+    )
 end

 #=
@@ -117,9 +107,7 @@ ps, st = Lux.setup(rng, model) |> dev
 =#

 # Use the model to generate some text.
-function generate_text(
-    model, ps, st, seed; alphabet, output_length, sequence_length
-)
+function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
     dev = get_device((ps, st))
     @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \
         dynamically sized tensors."
@@ -192,14 +180,23 @@ function get_nanogpt_data(; sequence_length, test_split)
 end

 @main function main(;
-    n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
-    v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
-    dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
-    epochs::Int=100,
-    # Only inference options
-    inference::Bool=false, model_path::String="",
-    seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
-    output_length::Int=1024
+    n_embed::Int=64,
+    n_hidden::Int=256,
+    n_heads::Int=4,
+    qk_dim::Int=16,
+    v_dim::Int=16,
+    n_layers::Int=6,
+    sequence_length::Int=64,
+    batchsize::Int=128,
+    dropout_rate::Float32=0.0f0,
+    test_split::Float64=0.1,
+    lr::Float64=1e-2,
+    epochs::Int=100,
+    # Only inference options
+    inference::Bool=false,
+    model_path::String="",
+    seed::Union{String,Vector{String}}=["_", "The", "Julia", "Lux.jl"],
+    output_length::Int=1024,
 )
     rng = Random.default_rng()
     Random.seed!(rng, 1234)
@@ -220,16 +217,14 @@ end
         alphabet = JLD2.load(model_path, "alphabet")
         sequence_length = model_config.sequence_length

-        texts = generate_text(
-            model, ps, st, seed; alphabet, output_length, sequence_length
-        )
+        texts = generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)

         for (i, (text, s)) in enumerate(zip(texts, seed))
             @printf "[Info] Seed [%d]: %s\n" i s
             @printf "[Generated Text] %s\n\n" text
         end

-        return
+        return nothing
     end

     alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
@@ -238,13 +233,19 @@ end
     @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
     @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)

-    train_loader = DataLoader(
-        (trainX, trainY); batchsize, shuffle=true, parallel=true
-    ) |> dev
+    train_loader =
+        DataLoader((trainX, trainY); batchsize, shuffle=true, parallel=true) |> dev

     model_config = (;
-        n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
-        n_layers, dropout_rate, n_heads, qk_dim, v_dim
+        n_vocab=length(alphabet),
+        n_embed,
+        sequence_length,
+        n_hidden,
+        n_layers,
+        dropout_rate,
+        n_heads,
+        qk_dim,
+        v_dim,
     )
     model = GPT(; model_config...)
     ps, st = Lux.setup(rng, model) |> dev

     # Generate some text here...
     texts = generate_text(
-        model, ps |> cdev, st |> cdev, seed;
-        alphabet, output_length, sequence_length
+        model, ps |> cdev, st |> cdev, seed; alphabet, output_length, sequence_length
     )
     for (i, (text, s)) in enumerate(zip(texts, seed))
         @printf "[Info] Seed [%d]: %s\n" i s
             parameters=train_state.parameters |> cdev,
             states=train_state.states |> cdev,
             alphabet=alphabet,
-            model_config=model_config
+            model_config=model_config,
         )
     end
 end
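
A quick standalone sketch (not part of the PR) of the causal mask both versions of the attention forward pass rely on: NNlib.make_causal_mask(x; dims=2) builds a square Bool matrix of side size(x, 2) with mask[i, j] == (i <= j), so query position j only attends to key positions 1:j, whether the mask is passed to NNlib.dot_product_attention (old code) or to the wrapped attention layer (new code).

using NNlib

x = rand(Float32, 8, 4, 2)         # (n_embed, sequence_length, batch)
mask = NNlib.make_causal_mask(x)   # 4x4 Bool matrix over the sequence dimension
@assert mask == [i <= j for i in 1:4, j in 1:4]   # upper-triangular: keys <= query kept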