Commit 4aa7d94

feat: towards completion
1 parent 4927f92 commit 4aa7d94


examples/MinimalMamba/main.jl

Lines changed: 243 additions & 9 deletions
@@ -3,7 +3,7 @@
 using ConcreteStructs, Lux, Random, Reactant
 using HuggingFaceTokenizers, Scratch, PythonCall, JSON3
 
-## Load some python libraries for loading pretrained weights
+# Load some python libraries for loading pretrained weights
 
 const huggingface_hub = pyimport("huggingface_hub")
 const torch = pyimport("torch")
@@ -286,9 +286,9 @@ function (ssm::SSM)(x::AbstractArray{T,3}, ps, st) where {T}
 
     x_dbl, st_x_proj = ssm.x_proj(x, ps.x_proj, st.x_proj)
 
-    Δ = @view x_dbl[1:(ssm.dt_rank), :, :]
-    B = @view x_dbl[(ssm.dt_rank + 1):(ssm.dt_rank + n), :, :]
-    C = @view x_dbl[(ssm.dt_rank + n + 1):end, :, :]
+    Δ = x_dbl[1:(ssm.dt_rank), :, :]
+    B = x_dbl[(ssm.dt_rank + 1):(ssm.dt_rank + n), :, :]
+    C = x_dbl[(ssm.dt_rank + n + 1):end, :, :]
 
     Δ, st_dt_proj = ssm.dt_proj(Δ, ps.dt_proj, st.dt_proj)
 
@@ -368,6 +368,9 @@ function download_mamba_weights_from_huggingface(pretrained_model_name::String)
         d_model=config[:d_model], n_layer=config[:n_layer], vocab_size=config[:vocab_size]
     )
 
+    weights_file = huggingface_hub.hf_hub_download(;
+        repo_id=pretrained_model_name, filename="pytorch_model.bin", local_dir=local_dir
+    )
     weights = torch.load(weights_file; weights_only=true, mmap=true, map_location="cpu")
 
     return mamba_config, weights
@@ -424,13 +427,244 @@ function load_weights_from_dict(weights_dict, config::MambaModelArgs, dev)
     )
 end
 
+# ## Tokenizer
+
+struct MambaTokenizer
+    tokenizer::Tokenizer
+    pad_token_id::Int32
+    eos_token_id::Int32
+end
+
+const SPLIT_RE = r"(<\|[^>]+?\|>)"
+
+function MambaTokenizer()
+    tok = HuggingFaceTokenizers.from_pretrained(Tokenizer, "EleutherAI/gpt-neox-20b")
+    return MambaTokenizer(
+        tok, token_to_id(tok, "<|endoftext|>"), token_to_id(tok, "<|endoftext|>")
+    )
+end
+
+token_to_id(tokenizer::MambaTokenizer, s) = token_to_id(tokenizer.tokenizer, s)
+function token_to_id(tokenizer::Tokenizer, s)
+    return pyconvert(Int32, tokenizer.py_tokenizer.token_to_id(s)) + Int32(1)
+end
+
+function split_with_delims(text::String, re::Regex)
+    parts = String[]
+    last_end = 1
+    for m in eachmatch(re, text)
+        if m.offset > last_end
+            push!(parts, text[last_end:(m.offset - 1)])
+        elseif m.offset == 1
+            push!(parts, "")
+        end
+        push!(parts, m.match)
+        last_end = m.offset + length(m.match)
+    end
+    if last_end ≤ lastindex(text)
+        push!(parts, text[last_end:end])
+    end
+    return parts
+end
+
+function HuggingFaceTokenizers.encode(tok::MambaTokenizer, text)
+    ids = Int32[]
+    for part in filter(!isempty, split_with_delims(text, SPLIT_RE))
+        append!(ids, encode(tok.tokenizer, string(part)).ids .+ Int16(1))
+    end
+    return ids
+end
+
+function HuggingFaceTokenizers.decode(tok::MambaTokenizer, ids::Vector{<:Integer})
+    return decode(tok.tokenizer, ids .- Int16(1); skip_special_tokens=false)
+end
+
+# ## Text Generation Utilities
+
+function weighted_sample(
+    rng, items::AbstractVector, weights::AbstractVector, n::Int; temperature::Number=1
+)
+    @assert length(items) == length(weights)
+
+    weights = weights .^ inv(eltype(weights)(temperature))
+    weights = weights ./ sum(weights)
+    cumprobs = reshape(cumsum(weights), :, 1)
+    random_vals = rand(rng, 1, n)
+
+    indices = dropdims(sum(cumprobs .< random_vals; dims=1); dims=1) .+ 1
+    return items[indices]
+end
+
+# Setting top_k to 1 will disable sampling and instead return the argmax. For larger
+# values of top_k, we sample from the top_k most likely tokens.
+
+function predict_next_token(
+    rng,
+    model,
+    token_ids::AbstractVector{T},
+    input_mask_len,
+    ps,
+    st;
+    top_k::Int=32,
+    temperature::Number=1,
+) where {T}
+    token_ids = Reactant.materialize_traced_array(reshape(token_ids, :, 1))
+
+    logits, stₙ = model(token_ids, ps, st)
+    next_token_logits = logits[:, end - input_mask_len, 1]
+
+    if top_k == 1
+        predictions = T.(argmax(next_token_logits))
+    else
+        top_k_idxs = partialsortperm(next_token_logits, 1:top_k; rev=true)
+        top_k_logits = next_token_logits[Reactant.materialize_traced_array(top_k_idxs)]
+        predictions = weighted_sample(rng, T.(top_k_idxs), top_k_logits, 1; temperature)
+    end
+
+    predictions = mod1.(predictions, T(size(logits, 1)))
+    return predictions, stₙ
+end
+
+function update_token_ids_and_mask!(
+    padded_token_ids::AbstractVector, input_mask_len, cur_num_tokens, next_token::Number
+)
+    @trace if input_mask_len == 0
+        cur_num_tokens += eltype(cur_num_tokens)(1)
+        @allowscalar padded_token_ids[cur_num_tokens] = next_token
+    else
+        L = length(padded_token_ids)
+        padded_token_ids[1:(L - 1)] = padded_token_ids[2:L]
+        @allowscalar padded_token_ids[L] = next_token
+    end
+    return input_mask_len - eltype(input_mask_len)(1), cur_num_tokens
+end
+
+function generate_chunk_of_text(
+    rng,
+    model,
+    padded_token_ids,
+    input_mask_len,
+    cur_num_tokens,
+    ps,
+    st,
+    n_tokens,
+    top_k,
+    temperature,
+)
+    next_n_tokens = similar(padded_token_ids, n_tokens)
+    @trace track_numbers = false for i in 1:n_tokens
+        next_token, st = predict_next_token(
+            rng, model, padded_token_ids, input_mask_len, ps, st; top_k, temperature
+        )
+        next_token_scalar = @allowscalar next_token[1]
+        input_mask_len, cur_num_tokens = update_token_ids_and_mask!(
+            padded_token_ids, input_mask_len, cur_num_tokens, next_token_scalar
+        )
+        @allowscalar next_n_tokens[i] = next_token_scalar
+    end
+    return next_n_tokens, input_mask_len, cur_num_tokens, st
+end
+
+function generate_text(
+    model::Mamba,
+    prompt::String,
+    ps,
+    st,
+    max_new_tokens::Int,
+    tokenizer::MambaTokenizer;
+    chunk_size::Int=128,
+    top_k::Int=32,
+    temperature::Number=1,
+)
+    rdev = reactant_device()
+
+    token_ids = encode(tokenizer, prompt)
+    print(decode(tokenizer, token_ids))
+    padding_size_to_compile = min(2048, max(length(token_ids) + max_new_tokens, 512))
+    if length(token_ids) > padding_size_to_compile
+        @warn "Prompt is longer than $(padding_size_to_compile) tokens; truncating to \
+            last $(padding_size_to_compile) tokens."
+        padded_token_ids = token_ids[(end - padding_size_to_compile + 1):end]
+    else
+        padded_token_ids = pad_constant(
+            token_ids,
+            (0, padding_size_to_compile - length(token_ids)),
+            eltype(token_ids)(tokenizer.pad_token_id),
+        )
+        @assert length(padded_token_ids) == padding_size_to_compile
+    end
+    padded_token_ids = rdev(padded_token_ids)
+
+    rng = Random.default_rng() |> rdev
+    cur_num_tokens = ConcreteRNumber(Int32(length(padded_token_ids)))
+    input_mask_len = ConcreteRNumber(Int32(padding_size_to_compile - length(token_ids)))
+
+    chunked_text_genfn = @compile generate_chunk_of_text(
+        rng,
+        model,
+        padded_token_ids,
+        input_mask_len,
+        cur_num_tokens,
+        ps,
+        st,
+        chunk_size,
+        top_k,
+        temperature,
+    )
+
+    n_tokens_generated = 0
+    total_time = 0.0
+    while n_tokens_generated < max_new_tokens
+        start_time = time()
+        next_n_tokens, input_mask_len, cur_num_tokens, st = chunked_text_genfn(
+            rng,
+            model,
+            padded_token_ids,
+            input_mask_len,
+            cur_num_tokens,
+            ps,
+            st,
+            chunk_size,
+            top_k,
+            temperature,
+        )
+        total_time += time() - start_time
+
+        n_tokens_generated += length(next_n_tokens)
+        next_n_tokens_jl = vec(Array(next_n_tokens))
+        for token in next_n_tokens_jl
+            token == tokenizer.eos_token_id && return nothing
+            print(decode(tokenizer, [token]))
+        end
+    end
+    tokens_per_second = n_tokens_generated / total_time
+    println()
+    @info "Tokens per second: $(tokens_per_second)"
+
+    return nothing
+end
+
+tokenizer = MambaTokenizer()
+
 config, weights_dict = download_mamba_weights_from_huggingface("state-spaces/mamba-130m");
 model = Mamba(config)
 
-ps_from_hgf = load_weights_from_dict(weights_dict, config, cpu_device());
-ps = Lux.initialparameters(Random.default_rng(), model);
-st = Lux.initialstates(Random.default_rng(), model);
+rdev = reactant_device()
+
+ps_from_hgf = load_weights_from_dict(weights_dict, config, rdev);
+st = Lux.initialstates(Random.default_rng(), model) |> rdev;
+
+generate_text(model, "Mamba is the ", ps_from_hgf, st, 128, tokenizer; chunk_size=32)
+
+# x = Int32.(rand(1:(config.vocab_size), 4, 32))
+# x = reshape(encode(tok, "Mamba is the "), :, 1)
 
-x = Int32.(rand(1:(config.vocab_size), 4, 32))
+# res = model(x, ps_from_hgf, st)
 
-model(x, ps_from_hgf, st)
+# decode(tok, [argmax(res[1][:, end, 1])])
+
+# ## Entry Point
+
+if abspath(PROGRAM_FILE) == @__FILE__
+    main()
+end
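
The temperature-scaled top-k sampler added above can be exercised on its own, outside of Reactant tracing. A minimal sketch of `weighted_sample` on plain Julia arrays, with made-up token ids and non-negative scores (the real call site passes raw logits from the traced model, and `top_k=1` skips sampling in favour of `argmax`):

using Random

rng = Xoshiro(0)
items = Int32[17, 4, 256]        # hypothetical top-k token ids
scores = Float32[0.7, 0.2, 0.1]  # hypothetical non-negative scores

# temperature < 1 sharpens the distribution, temperature > 1 flattens it
weighted_sample(rng, items, scores, 5; temperature=0.5)
weighted_sample(rng, items, scores, 5; temperature=2.0)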
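
The tokenizer wrapper can be sanity-checked in isolation. A small sketch, assuming the GPT-NeoX-20B tokenizer can be fetched from the Hugging Face Hub; note that `encode`/`decode` shift the ids by one so they are 1-based on the Julia side:

tok = MambaTokenizer()
ids = encode(tok, "Mamba is the <|endoftext|>")  # SPLIT_RE keeps special tokens intact
decode(tok, ids)                                 # should reproduce the original string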
