Cache MPSGraph instances for matmul to reduce overhead#722
Merged
christiangnrd merged 10 commits intoJuliaGPU:mainfrom Jan 1, 2026
Merged
Cache MPSGraph instances for matmul to reduce overhead#722christiangnrd merged 10 commits intoJuliaGPU:mainfrom
christiangnrd merged 10 commits intoJuliaGPU:mainfrom
Conversation
This comment was marked as off-topic.
This comment was marked as off-topic.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #722 +/- ##
==========================================
+ Coverage 80.90% 81.28% +0.38%
==========================================
Files 59 62 +3
Lines 2896 2918 +22
==========================================
+ Hits 2343 2372 +29
+ Misses 553 546 -7 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Contributor
There was a problem hiding this comment.
Metal Benchmarks
Details
| Benchmark suite | Current: 6baaff0 | Previous: 67d668c | Ratio |
|---|---|---|---|
latency/precompile |
24946982834 ns |
24820843000 ns |
1.01 |
latency/ttfp |
2278260208 ns |
2257593833 ns |
1.01 |
latency/import |
1445704333 ns |
1431203750 ns |
1.01 |
integration/metaldevrt |
828875 ns |
834875 ns |
0.99 |
integration/byval/slices=1 |
1538875 ns |
1525666.5 ns |
1.01 |
integration/byval/slices=3 |
8816916.5 ns |
8498958 ns |
1.04 |
integration/byval/reference |
1527833 ns |
1538166 ns |
0.99 |
integration/byval/slices=2 |
2577416 ns |
2552562 ns |
1.01 |
kernel/indexing |
591208 ns |
593833 ns |
1.00 |
kernel/indexing_checked |
617417 ns |
575750 ns |
1.07 |
kernel/launch |
11375 ns |
11250 ns |
1.01 |
kernel/rand |
555375 ns |
557187.5 ns |
1.00 |
array/construct |
6000 ns |
6000 ns |
1 |
array/broadcast |
588708 ns |
591209 ns |
1.00 |
array/random/randn/Float32 |
782958 ns |
836917 ns |
0.94 |
array/random/randn!/Float32 |
610375 ns |
619542 ns |
0.99 |
array/random/rand!/Int64 |
543812.5 ns |
548834 ns |
0.99 |
array/random/rand!/Float32 |
567459 ns |
593333 ns |
0.96 |
array/random/rand/Int64 |
763708 ns |
735667 ns |
1.04 |
array/random/rand/Float32 |
585312.5 ns |
631792 ns |
0.93 |
array/accumulate/Int64/1d |
1256896 ns |
1237125 ns |
1.02 |
array/accumulate/Int64/dims=1 |
1808541 ns |
1795625 ns |
1.01 |
array/accumulate/Int64/dims=2 |
2113812 ns |
2130458 ns |
0.99 |
array/accumulate/Int64/dims=1L |
11640833 ns |
11609562.5 ns |
1.00 |
array/accumulate/Int64/dims=2L |
9605167 ns |
9610834 ns |
1.00 |
array/accumulate/Float32/1d |
1123500 ns |
1111187.5 ns |
1.01 |
array/accumulate/Float32/dims=1 |
1527959 ns |
1518146 ns |
1.01 |
array/accumulate/Float32/dims=2 |
1825479 ns |
1836167 ns |
0.99 |
array/accumulate/Float32/dims=1L |
9787667 ns |
9757375 ns |
1.00 |
array/accumulate/Float32/dims=2L |
7206166 ns |
7203562.5 ns |
1.00 |
array/reductions/reduce/Int64/1d |
1506312.5 ns |
1498333 ns |
1.01 |
array/reductions/reduce/Int64/dims=1 |
1067334 ns |
1076542 ns |
0.99 |
array/reductions/reduce/Int64/dims=2 |
1133791 ns |
1129417 ns |
1.00 |
array/reductions/reduce/Int64/dims=1L |
2001604 ns |
2002083.5 ns |
1.00 |
array/reductions/reduce/Int64/dims=2L |
4194708 ns |
4214895.5 ns |
1.00 |
array/reductions/reduce/Float32/1d |
1000584 ns |
991375 ns |
1.01 |
array/reductions/reduce/Float32/dims=1 |
799125 ns |
827000 ns |
0.97 |
array/reductions/reduce/Float32/dims=2 |
832209 ns |
833917 ns |
1.00 |
array/reductions/reduce/Float32/dims=1L |
1316416 ns |
1305125 ns |
1.01 |
array/reductions/reduce/Float32/dims=2L |
1773750 ns |
1788375 ns |
0.99 |
array/reductions/mapreduce/Int64/1d |
1512375 ns |
1549292 ns |
0.98 |
array/reductions/mapreduce/Int64/dims=1 |
1079042 ns |
1085333 ns |
0.99 |
array/reductions/mapreduce/Int64/dims=2 |
1120375 ns |
1201959 ns |
0.93 |
array/reductions/mapreduce/Int64/dims=1L |
1987333 ns |
2019583 ns |
0.98 |
array/reductions/mapreduce/Int64/dims=2L |
3604917 ns |
3628521 ns |
0.99 |
array/reductions/mapreduce/Float32/1d |
960000 ns |
1036542 ns |
0.93 |
array/reductions/mapreduce/Float32/dims=1 |
810375 ns |
819667 ns |
0.99 |
array/reductions/mapreduce/Float32/dims=2 |
833083 ns |
843917 ns |
0.99 |
array/reductions/mapreduce/Float32/dims=1L |
1287000 ns |
1280500 ns |
1.01 |
array/reductions/mapreduce/Float32/dims=2L |
1849625 ns |
1784500 ns |
1.04 |
array/private/copyto!/gpu_to_gpu |
628584 ns |
635375 ns |
0.99 |
array/private/copyto!/cpu_to_gpu |
776270.5 ns |
786625 ns |
0.99 |
array/private/copyto!/gpu_to_cpu |
787334 ns |
773833 ns |
1.02 |
array/private/iteration/findall/int |
1594541.5 ns |
1620458 ns |
0.98 |
array/private/iteration/findall/bool |
1440625 ns |
1430125 ns |
1.01 |
array/private/iteration/findfirst/int |
2034104 ns |
2024937.5 ns |
1.00 |
array/private/iteration/findfirst/bool |
1992459 ns |
2010916 ns |
0.99 |
array/private/iteration/scalar |
4540541 ns |
5600375 ns |
0.81 |
array/private/iteration/logical |
2527250 ns |
2504521 ns |
1.01 |
array/private/iteration/findmin/1d |
2193666 ns |
2209917 ns |
0.99 |
array/private/iteration/findmin/2d |
1511791 ns |
1498584 ns |
1.01 |
array/private/copy |
553084 ns |
558312.5 ns |
0.99 |
array/shared/copyto!/gpu_to_gpu |
83125 ns |
82042 ns |
1.01 |
array/shared/copyto!/cpu_to_gpu |
81750 ns |
79750 ns |
1.03 |
array/shared/copyto!/gpu_to_cpu |
81500 ns |
82125 ns |
0.99 |
array/shared/iteration/findall/int |
1589250 ns |
1600354 ns |
0.99 |
array/shared/iteration/findall/bool |
1435000 ns |
1452458 ns |
0.99 |
array/shared/iteration/findfirst/int |
1627875 ns |
1621520.5 ns |
1.00 |
array/shared/iteration/findfirst/bool |
1621104.5 ns |
1607916.5 ns |
1.01 |
array/shared/iteration/scalar |
199604.5 ns |
202916 ns |
0.98 |
array/shared/iteration/logical |
2428666 ns |
2386416.5 ns |
1.02 |
array/shared/iteration/findmin/1d |
1798041 ns |
1799396 ns |
1.00 |
array/shared/iteration/findmin/2d |
1510125 ns |
1500416.5 ns |
1.01 |
array/shared/copy |
243167 ns |
230791 ns |
1.05 |
array/permutedims/4d |
2338854.5 ns |
2358000 ns |
0.99 |
array/permutedims/2d |
1125917 ns |
1133208 ns |
0.99 |
array/permutedims/3d |
1644250 ns |
1645604 ns |
1.00 |
metal/synchronization/stream |
17500 ns |
18500 ns |
0.95 |
metal/synchronization/context |
18458 ns |
19625 ns |
0.94 |
This comment was automatically generated by workflow using github-action-benchmark.
christiangnrd
requested changes
Dec 8, 2025
Member
christiangnrd
left a comment
There was a problem hiding this comment.
This is going to be great! Thanks for taking the time to write this up.
MPSGraph construction takes ~2ms per call, which dominated matmul latency for the MPSGraph path. This adds a thread-safe cache keyed by structural parameters (shapes, types, transpose flags, alpha/beta). Performance impact by use case: FASTER (3-7x improvement on subsequent calls): - Large matrices (>6000x6000 Float32, >2000x2000 Integer) - Mixed-precision matmul (Int8->Float32, Float16->Float32) - Matrix-vector multiplication with supported types - Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage - Batched matrix multiplication (3D+ arrays) UNCHANGED (uses MPS path, not affected): - Small/medium Float32 matrices (<=6000x6000 on Apple9+ GPUs) - Small Integer matrices (<=2000x2000) - Most typical ML inference workloads SLIGHTLY SLOWER on first call only: - First matmul of each unique shape/type adds cache lookup overhead - Negligible compared to the ~2ms saved on all subsequent calls The cache is process-global and grows with unique configurations. Typical ML workloads use few distinct shapes, so memory overhead is minimal (each cached graph is ~1-2KB).
- Reorder struct fields for consistency (alpha/beta before transpose, place_c before place_a) - Remove dead code (unused get_batch_size helper function) - Revert inadvertent change to broadcast logic (Na==1 vs nBatchA==1) - Update speedup claim in comment to be less specific
194af11 to
6baaff0
Compare
christiangnrd
approved these changes
Jan 1, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds a thread-safe cache for MPSGraph instances in the matmul path, eliminating the graph construction overhead on repeated operations with the same configuration.
Motivation
The Metal.jl 1.6 release notes acknowledged that "for simple operations MPSGraph requires a lot of extra boilerplate without much benefit." This PR addresses that overhead by caching the compiled graphs, making the MPSGraph path viable for repeated operations.
Benchmarks
MPSGraph matmul performance is now essentially the same or better than with the MPS implementation, even at small matrix sizes. See #566 for a graph of previous performance.
