3 | 3 | using ConcreteStructs, Lux, Random, Reactant |
4 | 4 | using HuggingFaceTokenizers, Scratch, PythonCall, JSON3 |
5 | 5 |
6 | | -## Load some python libraries for loading pretrained weights |
| 6 | +# Load some python libraries for loading pretrained weights |
7 | 7 |
8 | 8 | const huggingface_hub = pyimport("huggingface_hub") |
9 | 9 | const torch = pyimport("torch") |
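| | +# Note that PythonCall resolves these imports through its configured Python
| | +# environment (managed via CondaPkg by default), so `huggingface_hub` and
| | +# `torch` must be installed there.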
@@ -286,9 +286,9 @@ function (ssm::SSM)(x::AbstractArray{T,3}, ps, st) where {T} |
286 | 286 |
287 | 287 | x_dbl, st_x_proj = ssm.x_proj(x, ps.x_proj, st.x_proj) |
288 | 288 |
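| | +    # Split the shared projection into Δ (the timestep) and the SSM input
| | +    # matrices B and C along the feature dimension; plain indexing creates
| | +    # materialized copies rather than lazy views, which plays more nicely
| | +    # with Reactant's traced arrays.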
289 | | - Δ = @view x_dbl[1:(ssm.dt_rank), :, :] |
290 | | - B = @view x_dbl[(ssm.dt_rank + 1):(ssm.dt_rank + n), :, :] |
291 | | - C = @view x_dbl[(ssm.dt_rank + n + 1):end, :, :] |
| 289 | + Δ = x_dbl[1:(ssm.dt_rank), :, :] |
| 290 | + B = x_dbl[(ssm.dt_rank + 1):(ssm.dt_rank + n), :, :] |
| 291 | + C = x_dbl[(ssm.dt_rank + n + 1):end, :, :] |
292 | 292 |
293 | 293 | Δ, st_dt_proj = ssm.dt_proj(Δ, ps.dt_proj, st.dt_proj) |
294 | 294 |
@@ -368,6 +368,9 @@ function download_mamba_weights_from_huggingface(pretrained_model_name::String) |
368 | 368 | d_model=config[:d_model], n_layer=config[:n_layer], vocab_size=config[:vocab_size] |
369 | 369 | ) |
370 | 370 |
| 371 | + weights_file = huggingface_hub.hf_hub_download(; |
| 372 | + repo_id=pretrained_model_name, filename="pytorch_model.bin", local_dir=local_dir |
| 373 | + ) |
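| | +    # `weights_only=true` restricts unpickling to plain tensors/containers,
| | +    # and `mmap=true` memory-maps the checkpoint instead of loading it eagerly.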
371 | 374 | weights = torch.load(weights_file; weights_only=true, mmap=true, map_location="cpu") |
372 | 375 |
373 | 376 | return mamba_config, weights |
@@ -424,13 +427,244 @@ function load_weights_from_dict(weights_dict, config::MambaModelArgs, dev) |
424 | 427 | ) |
425 | 428 | end |
426 | 429 |
| 430 | +# ## Tokenizer |
| 431 | +
| 432 | +struct MambaTokenizer |
| 433 | + tokenizer::Tokenizer |
| 434 | + pad_token_id::Int32 |
| 435 | + eos_token_id::Int32 |
| 436 | +end |
| 437 | +
| 438 | +const SPLIT_RE = r"(<\|[^>]+?\|>)" |
| 439 | +
| 440 | +function MambaTokenizer() |
| 441 | + tok = HuggingFaceTokenizers.from_pretrained(Tokenizer, "EleutherAI/gpt-neox-20b") |
| 442 | + return MambaTokenizer( |
| 443 | + tok, token_to_id(tok, "<|endoftext|>"), token_to_id(tok, "<|endoftext|>") |
| 444 | + ) |
| 445 | +end |
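| | +# The GPT-NeoX vocabulary has no dedicated padding token, so `<|endoftext|>`
| | +# doubles as both the padding and the end-of-sequence id.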
| 446 | +
| 447 | +token_to_id(tokenizer::MambaTokenizer, s) = token_to_id(tokenizer.tokenizer, s) |
| 448 | +function token_to_id(tokenizer::Tokenizer, s) |
| 449 | + return pyconvert(Int32, tokenizer.py_tokenizer.token_to_id(s)) + Int32(1) |
| 450 | +end |
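| | +# HuggingFace token ids are 0-based while Julia indexing (and the embedding
| | +# lookup) is 1-based, hence the `+ 1` here and the matching `- 1` in `decode`
| | +# below.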
| 451 | +
| 452 | +function split_with_delims(text::String, re::Regex) |
| 453 | + parts = String[] |
| 454 | + last_end = 1 |
| 455 | + for m in eachmatch(re, text) |
| 456 | + if m.offset > last_end |
| 457 | + push!(parts, text[last_end:(m.offset - 1)]) |
| 458 | + elseif m.offset == 1 |
| 459 | + push!(parts, "") |
| 460 | + end |
| 461 | + push!(parts, m.match) |
| 462 | +        last_end = m.offset + ncodeunits(m.match)  # byte offset, not character count
| 463 | + end |
| 464 | + if last_end ≤ lastindex(text) |
| 465 | + push!(parts, text[last_end:end]) |
| 466 | + end |
| 467 | + return parts |
| 468 | +end |
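| | +# Special tokens survive the split as standalone chunks; for example,
| | +# `split_with_delims("Hello<|endoftext|>World", SPLIT_RE)` returns
| | +# `["Hello", "<|endoftext|>", "World"]`.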
| 469 | +
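| | +# Splitting before encoding ensures that a literal `<|endoftext|>` in the
| | +# prompt maps to its single special-token id rather than being tokenized as
| | +# ordinary text.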
| 470 | +function HuggingFaceTokenizers.encode(tok::MambaTokenizer, text) |
| 471 | + ids = Int32[] |
| 472 | + for part in filter(!isempty, split_with_delims(text, SPLIT_RE)) |
| 473 | + append!(ids, encode(tok.tokenizer, string(part)).ids .+ Int16(1)) |
| 474 | + end |
| 475 | + return ids |
| 476 | +end |
| 477 | +
| 478 | +function HuggingFaceTokenizers.decode(tok::MambaTokenizer, ids::Vector{<:Integer}) |
| 479 | + return decode(tok.tokenizer, ids .- Int16(1); skip_special_tokens=false) |
| 480 | +end |
| 481 | +
| 482 | +# ## Text Generation Utilities |
| 483 | +
| 484 | +function weighted_sample( |
| 485 | + rng, items::AbstractVector, weights::AbstractVector, n::Int; temperature::Number=1 |
| 486 | +) |
| 487 | + @assert length(items) == length(weights) |
| 488 | +
| 489 | + weights = weights .^ inv(eltype(weights)(temperature)) |
| 490 | + weights = weights ./ sum(weights) |
| 491 | + cumprobs = reshape(cumsum(weights), :, 1) |
| 492 | + random_vals = rand(rng, 1, n) |
| 493 | +
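| | +    # Inverse-CDF sampling: counting how many cumulative probabilities fall
| | +    # below each uniform draw yields the index of the sampled item.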
| 494 | + indices = dropdims(sum(cumprobs .< random_vals; dims=1); dims=1) .+ 1 |
| 495 | + return items[indices] |
| 496 | +end |
| 497 | +
| 498 | +# Setting `top_k` to 1 disables sampling and returns the argmax. For larger
| 499 | +# values of `top_k` we sample from the `top_k` most likely tokens, with
| | +# `temperature` flattening (`> 1`) or sharpening (`< 1`) the distribution.
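| | +# As a quick standalone check of the sampler (with made-up weights):
| | +# `count(==(3), weighted_sample(Random.default_rng(), 1:3, [0.1, 0.1, 0.8], 1000))`
| | +# comes out near 800, i.e. item 3 is drawn with probability ≈ 0.8.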
| 500 | +
| 501 | +function predict_next_token( |
| 502 | + rng, |
| 503 | + model, |
| 504 | + token_ids::AbstractVector{T}, |
| 505 | + input_mask_len, |
| 506 | + ps, |
| 507 | + st; |
| 508 | + top_k::Int=32, |
| 509 | + temperature::Number=1, |
| 510 | +) where {T} |
| 511 | + token_ids = Reactant.materialize_traced_array(reshape(token_ids, :, 1)) |
| 512 | +
| 513 | + logits, stₙ = model(token_ids, ps, st) |
| 514 | + next_token_logits = logits[:, end - input_mask_len, 1] |
| 515 | +
| 516 | + if top_k == 1 |
| 517 | + predictions = T.(argmax(next_token_logits)) |
| 518 | + else |
| 519 | +        top_k_idxs = partialsortperm(next_token_logits, 1:top_k; rev=true)
| 520 | +        top_k_logits = next_token_logits[Reactant.materialize_traced_array(top_k_idxs)]
| | +        # `weighted_sample` expects non-negative weights, so convert the raw
| | +        # logits into (unnormalized) probabilities first.
| | +        top_k_probs = exp.(top_k_logits .- maximum(top_k_logits))
| 521 | +        predictions = weighted_sample(rng, T.(top_k_idxs), top_k_probs, 1; temperature)
| 522 | + end |
| 523 | +
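| | +    # `mod1` wraps any out-of-range index back into the vocabulary range.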
| 524 | + predictions = mod1.(predictions, T(size(logits, 1))) |
| 525 | + return predictions, stₙ |
| 526 | +end |
| 527 | +
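| | +# Once every padding slot has been filled, the buffer turns into a sliding
| | +# window: the oldest token is dropped and the new one appended, so the model
| | +# always conditions on the most recent tokens that fit in the compiled buffer.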
| 528 | +function update_token_ids_and_mask!( |
| 529 | + padded_token_ids::AbstractVector, input_mask_len, cur_num_tokens, next_token::Number |
| 530 | +) |
| 531 | +    @trace if input_mask_len == 0
| | +        # Buffer is full: drop the oldest token by shifting everything left,
| | +        # then append the new token at the end.
| 532 | +        L = length(padded_token_ids)
| 533 | +        padded_token_ids[1:(L - 1)] = padded_token_ids[2:L]
| 534 | +        @allowscalar padded_token_ids[L] = next_token
| 535 | +    else
| | +        # Free padding slots remain: write into the next free position.
| 536 | +        cur_num_tokens += eltype(cur_num_tokens)(1)
| 537 | +        @allowscalar padded_token_ids[cur_num_tokens] = next_token
| 538 | +    end
| | +    # Clamp at zero so the mask length never goes negative once the buffer
| | +    # is full.
| 539 | +    return max(input_mask_len - eltype(input_mask_len)(1), eltype(input_mask_len)(0)), cur_num_tokens
| 540 | +end |
| 541 | +
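| | +# Generating one token at a time from the host would pay a dispatch and a
| | +# host-device round trip per token, so we instead generate `n_tokens` inside
| | +# a single compiled call and synchronize with the host once per chunk.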
| 542 | +function generate_chunk_of_text( |
| 543 | + rng, |
| 544 | + model, |
| 545 | + padded_token_ids, |
| 546 | + input_mask_len, |
| 547 | + cur_num_tokens, |
| 548 | + ps, |
| 549 | + st, |
| 550 | + n_tokens, |
| 551 | + top_k, |
| 552 | + temperature, |
| 553 | +) |
| 554 | + next_n_tokens = similar(padded_token_ids, n_tokens) |
| 555 | + @trace track_numbers = false for i in 1:n_tokens |
| 556 | + next_token, st = predict_next_token( |
| 557 | + rng, model, padded_token_ids, input_mask_len, ps, st; top_k, temperature |
| 558 | + ) |
| 559 | + next_token_scalar = @allowscalar next_token[1] |
| 560 | + input_mask_len, cur_num_tokens = update_token_ids_and_mask!( |
| 561 | + padded_token_ids, input_mask_len, cur_num_tokens, next_token_scalar |
| 562 | + ) |
| 563 | + @allowscalar next_n_tokens[i] = next_token_scalar |
| 564 | + end |
| 565 | + return next_n_tokens, input_mask_len, cur_num_tokens, st |
| 566 | +end |
| 567 | +
| 568 | +function generate_text( |
| 569 | + model::Mamba, |
| 570 | + prompt::String, |
| 571 | + ps, |
| 572 | + st, |
| 573 | + max_new_tokens::Int, |
| 574 | + tokenizer::MambaTokenizer; |
| 575 | + chunk_size::Int=128, |
| 576 | + top_k::Int=32, |
| 577 | + temperature::Number=1, |
| 578 | +) |
| 579 | + rdev = reactant_device() |
| 580 | +
| 581 | + token_ids = encode(tokenizer, prompt) |
| 582 | + print(decode(tokenizer, token_ids)) |
| 583 | + padding_size_to_compile = min(2048, max(length(token_ids) + max_new_tokens, 512)) |
| 584 | + if length(token_ids) > padding_size_to_compile |
| 585 | + @warn "Prompt is longer than $(padding_size_to_compile) tokens; truncating to \ |
| 586 | + last $(padding_size_to_compile) tokens." |
| 587 | + padded_token_ids = token_ids[(end - padding_size_to_compile + 1):end] |
| 588 | + else |
| 589 | + padded_token_ids = pad_constant( |
| 590 | + token_ids, |
| 591 | + (0, padding_size_to_compile - length(token_ids)), |
| 592 | + eltype(token_ids)(tokenizer.pad_token_id), |
| 593 | + ) |
| 594 | + @assert length(padded_token_ids) == padding_size_to_compile |
| 595 | + end |
| 596 | + padded_token_ids = rdev(padded_token_ids) |
| 597 | +
| 598 | + rng = Random.default_rng() |> rdev |
| | +    # Track how many real (non-pad) tokens the buffer holds and how many
| | +    # padding slots remain free at its end.
| 599 | +    cur_num_tokens = ConcreteRNumber(Int32(min(length(token_ids), padding_size_to_compile)))
| 600 | +    input_mask_len = ConcreteRNumber(Int32(max(padding_size_to_compile - length(token_ids), 0)))
| 601 | +
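| | +    # Compile the chunked generation step once; the compiled function is
| | +    # reused for every chunk in the loop below.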
| 602 | + chunked_text_genfn = @compile generate_chunk_of_text( |
| 603 | + rng, |
| 604 | + model, |
| 605 | + padded_token_ids, |
| 606 | + input_mask_len, |
| 607 | + cur_num_tokens, |
| 608 | + ps, |
| 609 | + st, |
| 610 | + chunk_size, |
| 611 | + top_k, |
| 612 | + temperature, |
| 613 | + ) |
| 614 | +
| 615 | + n_tokens_generated = 0 |
| 616 | + total_time = 0.0 |
| 617 | + while n_tokens_generated < max_new_tokens |
| 618 | + start_time = time() |
| 619 | + next_n_tokens, input_mask_len, cur_num_tokens, st = chunked_text_genfn( |
| 620 | + rng, |
| 621 | + model, |
| 622 | + padded_token_ids, |
| 623 | + input_mask_len, |
| 624 | + cur_num_tokens, |
| 625 | + ps, |
| 626 | + st, |
| 627 | + chunk_size, |
| 628 | + top_k, |
| 629 | + temperature, |
| 630 | + ) |
| 631 | + total_time += time() - start_time |
| 632 | +
| 633 | + n_tokens_generated += length(next_n_tokens) |
| 634 | + next_n_tokens_jl = vec(Array(next_n_tokens)) |
| 635 | + for token in next_n_tokens_jl |
| 636 | + token == tokenizer.eos_token_id && return nothing |
| 637 | + print(decode(tokenizer, [token])) |
| 638 | + end |
| 639 | + end |
| 640 | + tokens_per_second = n_tokens_generated / total_time |
| 641 | + println() |
| 642 | + @info "Tokens per second: $(tokens_per_second)" |
| 643 | +
| 644 | + return nothing |
| 645 | +end |
| 646 | +
| 647 | +tokenizer = MambaTokenizer() |
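| | +
| | +# A quick round trip through the tokenizer (the exact ids depend on the
| | +# GPT-NeoX vocabulary):
| | +ids = encode(tokenizer, "Mamba is the ")
| | +decode(tokenizer, ids)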
| 648 | +
427 | 649 | config, weights_dict = download_mamba_weights_from_huggingface("state-spaces/mamba-130m"); |
428 | 650 | model = Mamba(config) |
429 | 651 |
430 | | -ps_from_hgf = load_weights_from_dict(weights_dict, config, cpu_device()); |
431 | | -ps = Lux.initialparameters(Random.default_rng(), model); |
432 | | -st = Lux.initialstates(Random.default_rng(), model); |
| 652 | +rdev = reactant_device() |
| 653 | +
| 654 | +ps_from_hgf = load_weights_from_dict(weights_dict, config, rdev); |
| 655 | +st = Lux.initialstates(Random.default_rng(), model) |> rdev; |
| 656 | +
| | +# Wrap the end-to-end generation in `main` so the entry point at the bottom
| | +# of the script has something to call.
| 657 | +function main()
| | +    return generate_text(
| | +        model, "Mamba is the ", ps_from_hgf, st, 128, tokenizer; chunk_size=32
| | +    )
| | +end
| 658 | +
| | +# As a sanity check, the raw forward pass can also be run directly; greedy
| | +# decoding of the final position should produce a plausible next token (note
| | +# `Reactant.@jit` to execute the model on the Reactant device):
| | +#
| | +# x = rdev(reshape(encode(tokenizer, "Mamba is the "), :, 1))
| | +# res = Reactant.@jit model(x, ps_from_hgf, st)
| | +# decode(tokenizer, [argmax(Array(res[1])[:, end, 1])])
433 | 661 |
434 | | -x = Int32.(rand(1:(config.vocab_size), 4, 32))
435 | 663 |
436 | | -model(x, ps_from_hgf, st)
| 665 | +
| 666 | +# ## Entry Point |
| 667 | +
| 668 | +if abspath(PROGRAM_FILE) == @__FILE__ |
| 669 | + main() |
| 670 | +end |