
Cache MPSGraph instances for matmul to reduce overhead #722

Merged: christiangnrd merged 10 commits into JuliaGPU:main from KaanKesginLW:perf/mpsgraph-caching on Jan 1, 2026

Conversation

@KaanKesginLW (Contributor) commented on Dec 5, 2025

Summary

This PR adds a thread-safe cache for MPSGraph instances in the matmul path, eliminating the graph construction overhead on repeated operations with the same configuration.
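
For intuition, here is a minimal sketch of the caching pattern in plain Julia. It is illustrative only: the actual implementation lives in lib/mpsgraphs/matmul.jl and caches real MPSGraph objects, and the `MatMulKey`, `GRAPH_CACHE`, and `cached_graph` names below are invented for this example.

```julia
# Illustrative sketch (not the actual Metal.jl code): a process-global,
# lock-protected cache keyed by the structural parameters of a matmul.

# Hypothetical key type covering the parameters this PR keys on:
# shapes, element types, transpose flags, and alpha/beta.
struct MatMulKey
    size_a::Dims
    size_b::Dims
    eltype_a::DataType
    eltype_b::DataType
    eltype_c::DataType
    transpose_a::Bool
    transpose_b::Bool
    alpha::Float64
    beta::Float64
end

const GRAPH_CACHE = Dict{MatMulKey,Any}()   # values would be MPSGraph objects
const GRAPH_CACHE_LOCK = ReentrantLock()

# Return the cached entry for `key`, building it with `build()` on first use.
function cached_graph(build, key::MatMulKey)
    lock(GRAPH_CACHE_LOCK) do
        get!(build, GRAPH_CACHE, key)       # `build` runs only on a cache miss
    end
end

# Usage: graph construction is skipped whenever the same configuration recurs.
key = MatMulKey((128, 64), (64, 32), Float32, Float32, Float32,
                false, false, 1.0, 0.0)
graph = cached_graph(key) do
    # the expensive MPSGraph construction would happen here
    :placeholder_graph
end
```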

Motivation

The Metal.jl 1.6 release notes acknowledged that "for simple operations MPSGraph requires a lot of extra boilerplate without much benefit." This PR addresses that overhead by caching the compiled graphs, making the MPSGraph path viable for repeated operations.

Benchmarks

With caching, MPSGraph matmul performance now matches or exceeds the MPS implementation, even at small matrix sizes. See #566 for a graph of previous performance.
[Benchmark figure: bench_all_1]
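
As a quick way to exercise the cached path locally, the snippet below forces the MPSGraph algorithm using the `Metal.@with Metal.matmul_alg => :MPSGraph` form referenced in the commit notes further down; the sizes are illustrative and exact timings depend on hardware.

```julia
using Metal

A = MtlArray(rand(Float32, 4096, 4096))
B = MtlArray(rand(Float32, 4096, 4096))

# Syntax as referenced in this PR's commit notes.
Metal.@with Metal.matmul_alg => :MPSGraph begin
    C = A * B   # first call for this configuration builds and caches the graph
    C = A * B   # subsequent calls reuse the cached graph
end
```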


codecov bot commented on Dec 5, 2025

Codecov Report

❌ Patch coverage is 96.77419% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 81.28%. Comparing base (67d668c) to head (6baaff0).

Files with missing lines | Patch % | Lines missing
lib/mpsgraphs/matmul.jl | 96.77% | 1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
+ Coverage   80.90%   81.28%   +0.38%     
==========================================
  Files          59       62       +3     
  Lines        2896     2918      +22     
==========================================
+ Hits         2343     2372      +29     
+ Misses        553      546       -7     


github-actions bot left a comment


Metal Benchmarks

Benchmark suite Current: 6baaff0 Previous: 67d668c Ratio
latency/precompile 24946982834 ns 24820843000 ns 1.01
latency/ttfp 2278260208 ns 2257593833 ns 1.01
latency/import 1445704333 ns 1431203750 ns 1.01
integration/metaldevrt 828875 ns 834875 ns 0.99
integration/byval/slices=1 1538875 ns 1525666.5 ns 1.01
integration/byval/slices=3 8816916.5 ns 8498958 ns 1.04
integration/byval/reference 1527833 ns 1538166 ns 0.99
integration/byval/slices=2 2577416 ns 2552562 ns 1.01
kernel/indexing 591208 ns 593833 ns 1.00
kernel/indexing_checked 617417 ns 575750 ns 1.07
kernel/launch 11375 ns 11250 ns 1.01
kernel/rand 555375 ns 557187.5 ns 1.00
array/construct 6000 ns 6000 ns 1
array/broadcast 588708 ns 591209 ns 1.00
array/random/randn/Float32 782958 ns 836917 ns 0.94
array/random/randn!/Float32 610375 ns 619542 ns 0.99
array/random/rand!/Int64 543812.5 ns 548834 ns 0.99
array/random/rand!/Float32 567459 ns 593333 ns 0.96
array/random/rand/Int64 763708 ns 735667 ns 1.04
array/random/rand/Float32 585312.5 ns 631792 ns 0.93
array/accumulate/Int64/1d 1256896 ns 1237125 ns 1.02
array/accumulate/Int64/dims=1 1808541 ns 1795625 ns 1.01
array/accumulate/Int64/dims=2 2113812 ns 2130458 ns 0.99
array/accumulate/Int64/dims=1L 11640833 ns 11609562.5 ns 1.00
array/accumulate/Int64/dims=2L 9605167 ns 9610834 ns 1.00
array/accumulate/Float32/1d 1123500 ns 1111187.5 ns 1.01
array/accumulate/Float32/dims=1 1527959 ns 1518146 ns 1.01
array/accumulate/Float32/dims=2 1825479 ns 1836167 ns 0.99
array/accumulate/Float32/dims=1L 9787667 ns 9757375 ns 1.00
array/accumulate/Float32/dims=2L 7206166 ns 7203562.5 ns 1.00
array/reductions/reduce/Int64/1d 1506312.5 ns 1498333 ns 1.01
array/reductions/reduce/Int64/dims=1 1067334 ns 1076542 ns 0.99
array/reductions/reduce/Int64/dims=2 1133791 ns 1129417 ns 1.00
array/reductions/reduce/Int64/dims=1L 2001604 ns 2002083.5 ns 1.00
array/reductions/reduce/Int64/dims=2L 4194708 ns 4214895.5 ns 1.00
array/reductions/reduce/Float32/1d 1000584 ns 991375 ns 1.01
array/reductions/reduce/Float32/dims=1 799125 ns 827000 ns 0.97
array/reductions/reduce/Float32/dims=2 832209 ns 833917 ns 1.00
array/reductions/reduce/Float32/dims=1L 1316416 ns 1305125 ns 1.01
array/reductions/reduce/Float32/dims=2L 1773750 ns 1788375 ns 0.99
array/reductions/mapreduce/Int64/1d 1512375 ns 1549292 ns 0.98
array/reductions/mapreduce/Int64/dims=1 1079042 ns 1085333 ns 0.99
array/reductions/mapreduce/Int64/dims=2 1120375 ns 1201959 ns 0.93
array/reductions/mapreduce/Int64/dims=1L 1987333 ns 2019583 ns 0.98
array/reductions/mapreduce/Int64/dims=2L 3604917 ns 3628521 ns 0.99
array/reductions/mapreduce/Float32/1d 960000 ns 1036542 ns 0.93
array/reductions/mapreduce/Float32/dims=1 810375 ns 819667 ns 0.99
array/reductions/mapreduce/Float32/dims=2 833083 ns 843917 ns 0.99
array/reductions/mapreduce/Float32/dims=1L 1287000 ns 1280500 ns 1.01
array/reductions/mapreduce/Float32/dims=2L 1849625 ns 1784500 ns 1.04
array/private/copyto!/gpu_to_gpu 628584 ns 635375 ns 0.99
array/private/copyto!/cpu_to_gpu 776270.5 ns 786625 ns 0.99
array/private/copyto!/gpu_to_cpu 787334 ns 773833 ns 1.02
array/private/iteration/findall/int 1594541.5 ns 1620458 ns 0.98
array/private/iteration/findall/bool 1440625 ns 1430125 ns 1.01
array/private/iteration/findfirst/int 2034104 ns 2024937.5 ns 1.00
array/private/iteration/findfirst/bool 1992459 ns 2010916 ns 0.99
array/private/iteration/scalar 4540541 ns 5600375 ns 0.81
array/private/iteration/logical 2527250 ns 2504521 ns 1.01
array/private/iteration/findmin/1d 2193666 ns 2209917 ns 0.99
array/private/iteration/findmin/2d 1511791 ns 1498584 ns 1.01
array/private/copy 553084 ns 558312.5 ns 0.99
array/shared/copyto!/gpu_to_gpu 83125 ns 82042 ns 1.01
array/shared/copyto!/cpu_to_gpu 81750 ns 79750 ns 1.03
array/shared/copyto!/gpu_to_cpu 81500 ns 82125 ns 0.99
array/shared/iteration/findall/int 1589250 ns 1600354 ns 0.99
array/shared/iteration/findall/bool 1435000 ns 1452458 ns 0.99
array/shared/iteration/findfirst/int 1627875 ns 1621520.5 ns 1.00
array/shared/iteration/findfirst/bool 1621104.5 ns 1607916.5 ns 1.01
array/shared/iteration/scalar 199604.5 ns 202916 ns 0.98
array/shared/iteration/logical 2428666 ns 2386416.5 ns 1.02
array/shared/iteration/findmin/1d 1798041 ns 1799396 ns 1.00
array/shared/iteration/findmin/2d 1510125 ns 1500416.5 ns 1.01
array/shared/copy 243167 ns 230791 ns 1.05
array/permutedims/4d 2338854.5 ns 2358000 ns 0.99
array/permutedims/2d 1125917 ns 1133208 ns 0.99
array/permutedims/3d 1644250 ns 1645604 ns 1.00
metal/synchronization/stream 17500 ns 18500 ns 0.95
metal/synchronization/context 18458 ns 19625 ns 0.94

This comment was automatically generated by workflow using github-action-benchmark.

@christiangnrd (Member) left a comment

This is going to be great! Thanks for taking the time to write this up.

christiangnrd added the performance label on Dec 8, 2025
KaanKesginLW and others added 7 commits January 1, 2026 15:14
MPSGraph construction takes ~2ms per call, which dominated matmul
latency for the MPSGraph path. This adds a thread-safe cache keyed
by structural parameters (shapes, types, transpose flags, alpha/beta).

Performance impact by use case:

FASTER (3-7x improvement on subsequent calls):
- Large matrices (>6000x6000 Float32, >2000x2000 Integer)
- Mixed-precision matmul (Int8->Float32, Float16->Float32)
- Matrix-vector multiplication with supported types
- Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage
- Batched matrix multiplication (3D+ arrays)

UNCHANGED (uses MPS path, not affected):
- Small/medium Float32 matrices (<=6000x6000 on Apple9+ GPUs)
- Small Integer matrices (<=2000x2000)
- Most typical ML inference workloads

SLIGHTLY SLOWER on first call only:
- First matmul of each unique shape/type adds cache lookup overhead
- Negligible compared to the ~2ms saved on all subsequent calls

The cache is process-global and grows with unique configurations.
Typical ML workloads use few distinct shapes, so memory overhead
is minimal (each cached graph is ~1-2KB).
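As a rough worked example of that memory claim: a workload touching 50 distinct (shape, type, transpose, alpha/beta) configurations would hold roughly 50 × 2 KB ≈ 100 KB of cached graphs for the life of the process.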
- Reorder struct fields for consistency (alpha/beta before transpose, place_c before place_a)
- Remove dead code (unused get_batch_size helper function)
- Revert inadvertent change to broadcast logic (Na==1 vs nBatchA==1)
- Update speedup claim in comment to be less specific
christiangnrd force-pushed the perf/mpsgraph-caching branch from 194af11 to 6baaff0 on January 1, 2026 20:06
christiangnrd enabled auto-merge (squash) on January 1, 2026 22:50
christiangnrd merged commit 36873f1 into JuliaGPU:main on Jan 1, 2026
15 of 17 checks passed

Labels: performance (Gotta go fast.)
Projects: none yet
2 participants