# Taken from https://github.com/FluxML/model-zoo/pull/410
using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
      DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
using Comonicon: @main
55
66if ! haskey(DataDeps. registry, " nanogpt" )
5151 block
5252end
5353
54- function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use_bias)
54+ function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
5555 return GPTBlock(Chain(
5656 SkipConnection(
5757 Chain(
@@ -63,8 +63,8 @@ function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use
6363 SkipConnection(
6464 Chain(
6565 LayerNorm((n_embed, 1 )),
66- Dense(n_embed => n_hidden , gelu),
67- Dense(n_hidden => n_embed),
66+ Dense(n_embed => 4 * n_embed , gelu; use_bias ),
67+ Dense(4 * n_embed => n_embed; use_bias ),
6868 Dropout(dropout_rate)
6969 ),
7070 +
8787 layer
8888end
8989
"""
    GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)

Build a GPT-style language model as a `Lux.Chain`:

  - summed token (`Embedding`) and learned positional (`PositionalEmbedding`)
    embeddings, followed by dropout,
  - `n_layers` transformer blocks (`GPTBlock`),
  - a final `LayerNorm` and a `Dense` projection back to `n_vocab` logits.

`block_size` is the maximum sequence length the positional embedding supports.
`use_bias` toggles bias terms in the dense layers, matching the GPTBlock config.
"""
function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
    return GPT(Chain(
        # Token embedding + positional embedding, combined elementwise with `+`.
        Parallel(
            +,
            Embedding(n_vocab => n_embed),
            PositionalEmbedding(block_size => n_embed)
        ),
        Dropout(dropout_rate),
        # n_layers identical block specs. NOTE(review): `Returns` reuses one
        # constructed GPTBlock spec for every entry — confirm Lux.setup still
        # allocates independent parameters per Chain entry (it should, since
        # each Chain position is set up separately).
        Chain(ntuple(
            Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
        )...),
        LayerNorm((n_embed, 1)),
        # Project hidden states to vocabulary logits.
        Dense(n_embed => n_vocab; use_bias)
    ))
end
108105
#=
Example usage (not executed when the file is included):

dev = reactant_device(; force=true)
rng = Random.default_rng()

model = GPT(;
    n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
    use_bias=true
)
ps, st = Lux.setup(rng, model) |> dev

=#
# Use the model to generate some text.
110120function generate_text(
111121 model, ps, st, seed; alphabet, output_length, sequence_length
@@ -152,13 +162,14 @@ function get_nanogpt_data(; sequence_length, test_split)
152162 data_file = joinpath(datadep" nanogpt" , " shakespeare_input.txt" )
153163 text = String(read(data_file))
154164
155- # For aesthetic reasons, replace newlines with strings. This is not necessary, but makes
156- # strings print nicer.
157- text = replace(text, r" \r ?\n " => " " )
165+ idx = ceil(Int, length(text) * (1 - test_split))
166+ train_text = text[1 : idx]
167+ test_text = text[(idx + 1 ): end ]
168+
169+ tokenizer = BytePairEncoding. load_gpt2()
158170
159- # # an array of all unique characters
160- alphabet = [unique(text). .. , ' _' ]
161- stop = alphabet[end ]
171+ train_tokens = tokenizer(train_text)
172+ test_tokens = tokenizer(test_text)
162173
163174 B = (length(text) - 1 ) ÷ sequence_length
164175 # We must collect() before indexing, because String indexing does strange things with multi-byte