fix: missing onehotarrays dispatch for cpu matmul #1655

Merged

avik-pal merged 2 commits into main from ap/onehot on Feb 3, 2026
Conversation

@avik-pal (Member) commented Feb 2, 2026

fixes #1197

@gemini-code-assist
Summary of Changes

Hello @avik-pal, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a performance issue in CPU matrix multiplication when dealing with OneHotArrays. It achieves this by implementing a flexible dispatch mechanism that allows LuxLib to leverage the highly optimized mul! function for OneHotLike arrays, rather than relying on its more generic internal routines. This change is crucial for improving the efficiency of operations involving one-hot encoded data.

Highlights

  • Specialized Dispatch Mechanism: Introduced a new Utils.force_3arg_mul!_dispatch function that allows LuxLib to conditionally bypass its internal matrix multiplication logic and instead dispatch to the optimized mul! function for specific array types.
  • OneHotArrays Integration: Added OneHotArrays as a dependency and created a new extension (OneHotArraysExt.jl) that sets Utils.force_3arg_mul!_dispatch to true when the third argument to matmul! is a OneHotLike array, enabling specialized CPU matrix multiplication.
  • Performance Improvement: Modified the matmul! function to utilize this new dispatch mechanism, which is expected to significantly improve performance for operations involving OneHotLike arrays on CPU.
  • New Test Coverage: Added a dedicated test file (onehotarray_tests.jl) to verify the correctness and benchmark the performance gains achieved by the specialized OneHotArrays dispatch.
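The dispatch mechanism described in the highlights can be sketched as follows. This is an illustrative reconstruction, not the PR's exact code: the trait name `force_3arg_mul!_dispatch` and the `matmul!` branching come from the summary above, while `internal_matmul!` and the `Marked` wrapper type are stand-ins so the snippet runs without LuxLib or OneHotArrays installed.

```julia
using LinearAlgebra: mul!

# Trait: by default, fall through to LuxLib's internal routines.
# A package extension can override this for specific array types.
force_3arg_mul!_dispatch(::Type) = false

# Stand-in for OneHotArrays.OneHotLike so the example is self-contained;
# the real extension would instead define the method for OneHotLike.
struct Marked
    data::Matrix{Float64}
end
force_3arg_mul!_dispatch(::Type{Marked}) = true

# Placeholder for LuxLib's generic internal matmul path.
internal_matmul!(C, A, B) = mul!(C, A, B)

function matmul!(C, A, B)
    if force_3arg_mul!_dispatch(typeof(B))
        # Bypass the internal routines and use the optimized 3-arg mul!
        # method supplied by the array type's own package.
        mul!(C, A, B.data)
    else
        internal_matmul!(C, A, B)
    end
    return C
end

A = rand(3, 4)
B = Marked(rand(4, 2))
C = zeros(3, 2)
matmul!(C, A, B)
```

The point of routing through a trait rather than adding a method directly is that LuxLib's `matmul!` can stay generic while extensions opt specific array types into the fast path.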

@gemini-code-assist bot left a comment
Code Review

This pull request correctly fixes a performance issue for CPU matrix multiplication when using OneHotArrays by introducing a specialized dispatch for fused_dense_bias_activation. The implementation is clean, leveraging Julia's extension system to handle the optional dependency on OneHotArrays. The addition of tests that verify both correctness and performance is excellent. I have one suggestion to improve the new test file for better structure and to adhere to benchmarking best practices.
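The new test file itself is not shown in this thread. A minimal correctness check along the lines the review describes might look like the following; it uses a hand-built dense one-hot matrix instead of `OneHotArrays.onehotbatch` so the sketch stands alone, and it illustrates why `mul!` with a one-hot operand is cheap: the product just selects columns of the weight matrix.

```julia
using LinearAlgebra: mul!

# Dense stand-in for a one-hot matrix: column j has a single 1 at row idx[j].
# (The real tests would use the OneHotArrays type directly.)
function onehot_dense(idx, nclasses)
    B = zeros(nclasses, length(idx))
    for (j, i) in enumerate(idx)
        B[i, j] = 1.0
    end
    return B
end

W = rand(5, 10)
B = onehot_dense([1, 4, 7], 10)
C = zeros(5, 3)
mul!(C, W, B)
```

Because each column of `B` is a standard basis vector, `W * B` equals `W[:, [1, 4, 7]]`, which is what a specialized one-hot `mul!` can compute without any floating-point multiplies.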

@github-actions (Contributor) bot commented Feb 2, 2026

Benchmark Results (Julia v1.12)

Time benchmarks

| Benchmark | main | dff3f8c... | main / dff3f8c... |
|---|---|---|---|
| basics/MHA | 3.94 ± 0.37 μs | 3.83 ± 0.3 μs | 1.03 ± 0.13 |
| basics/MHA (first run) | 4.29 ± 0.47 μs | 4.03 ± 0.37 μs | 1.06 ± 0.15 |
| basics/MHA reactant | 0.0742 ± 0.0083 ms | 0.0631 ± 0.0059 ms | 1.18 ± 0.17 |
| basics/MHA reactant (comp + run) | 0.176 ± 0.0043 s | 0.168 ± 0.0066 s | 1.05 ± 0.048 |
| basics/conv | 14.3 ± 27 μs | 14.4 ± 23 μs | 0.987 ± 2.4 |
| basics/conv (first run) | 19.2 ± 27 μs | 15 ± 23 μs | 1.28 ± 2.7 |
| basics/conv reactant | 0.0551 ± 0.0024 ms | 0.0544 ± 0.0019 ms | 1.01 ± 0.057 |
| basics/conv reactant (comp + run) | 0.126 ± 0.0031 s | 0.128 ± 0.0047 s | 0.982 ± 0.044 |
| basics/dense | 0.19 ± 0.001 μs | 0.19 ± 0.001 μs | 1 ± 0.0074 |
| basics/dense (first run) | 0.191 ± 0.01 μs | 0.191 ± 0.01 μs | 1 ± 0.074 |
| basics/dense reactant | 0.0528 ± 0.0029 ms | 0.0517 ± 0.0017 ms | 1.02 ± 0.065 |
| basics/dense reactant (comp + run) | 0.108 ± 0.0054 s | 0.108 ± 0.003 s | 0.998 ± 0.057 |
| time_to_load | 0.649 ± 0.0056 s | 0.643 ± 0.016 s | 1.01 ± 0.027 |

Memory benchmarks

| Benchmark | main | dff3f8c... | main / dff3f8c... |
|---|---|---|---|
| basics/MHA | 0.087 k allocs: 5.81 kB | 0.087 k allocs: 5.81 kB | 1 |
| basics/MHA (first run) | 0.087 k allocs: 5.81 kB | 0.087 k allocs: 5.81 kB | 1 |
| basics/MHA reactant | 19 allocs: 0.578 kB | 19 allocs: 0.578 kB | 1 |
| basics/MHA reactant (comp + run) | 18.5 k allocs: 1.83 MB | 18.5 k allocs: 1.83 MB | 1 |
| basics/conv | 0.039 k allocs: 4.55 kB | 0.039 k allocs: 4.55 kB | 1 |
| basics/conv (first run) | 0.039 k allocs: 4.55 kB | 0.039 k allocs: 4.55 kB | 1 |
| basics/conv reactant | 15 allocs: 0.438 kB | 15 allocs: 0.438 kB | 1 |
| basics/conv reactant (comp + run) | 6.76 k allocs: 1.29 MB | 6.76 k allocs: 1.29 MB | 1 |
| basics/dense | 2 allocs: 0.109 kB | 2 allocs: 0.109 kB | 1 |
| basics/dense (first run) | 2 allocs: 0.109 kB | 2 allocs: 0.109 kB | 1 |
| basics/dense reactant | 15 allocs: 0.422 kB | 15 allocs: 0.422 kB | 1 |
| basics/dense reactant (comp + run) | 6.18 k allocs: 1.25 MB | 6.18 k allocs: 1.25 MB | 1 |
| time_to_load | 0.145 k allocs: 11 kB | 0.145 k allocs: 11 kB | 1 |

@avik-pal force-pushed the ap/onehot branch 2 times, most recently from b486970 to 59392ac on February 2, 2026 at 20:22
@codecov bot commented Feb 2, 2026

Codecov Report

❌ Patch coverage is 80.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 75.11%. Comparing base (6ab9a57) to head (dff3f8c).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| lib/LuxLib/ext/OneHotArraysExt.jl | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1655      +/-   ##
==========================================
- Coverage   75.16%   75.11%   -0.05%     
==========================================
  Files         174      175       +1     
  Lines        7182     7182              
==========================================
- Hits         5398     5395       -3     
- Misses       1784     1787       +3     

☔ View full report in Codecov by Sentry.

@avik-pal merged commit fa3c732 into main on Feb 3, 2026
56 of 57 checks passed
@avik-pal deleted the ap/onehot branch on February 3, 2026 at 01:13


Development

Successfully merging this pull request may close these issues.

Warning from LuxLib when using OneHotArrays about Mixed Precision

1 participant