feat: add Metal support #126

Open
haakon-e wants to merge 1 commit into main from he/feat-add-metal-support

Conversation

haakon-e (Member) commented Feb 3, 2026

Purpose

This pull request adds support for Apple's Metal GPU backend to ClimaComms.

Content

The main changes include introducing the MetalDevice type, implementing Metal-specific device and kernel launch logic, updating device selection and backend loading, and expanding tests to cover Metal. Additionally, the kernel launch parameter computation logic is refactored and centralized for use by both CUDA and Metal backends.

Metal backend support:

  • Added MetalDevice type and documentation to src/devices.jl to support Apple GPU acceleration.
  • Implemented Metal-specific device methods, kernel launching, and integration in new extension file ext/ClimaCommsMetalExt.jl.

Device selection and backend loading:

  • Updated device selection logic to recognize Metal as a valid value for CLIMACOMMS_DEVICE, with appropriate error handling and backend loading in src/devices.jl and src/loading.jl.
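
As a rough illustration of environment-based device selection, the sketch below validates a device name read from CLIMACOMMS_DEVICE. It is not the code in src/devices.jl: the function name `parse_device_name`, the `KNOWN_DEVICES` tuple, and the fallback to "CPU" when the variable is unset are all assumptions made for this example.

```julia
# Hypothetical, simplified model of device selection; not ClimaComms' implementation.
const KNOWN_DEVICES =
    ("CPU", "CPUSingleThreaded", "CPUMultiThreaded", "CUDA", "Metal")

function parse_device_name(env = ENV)
    # Assumption: fall back to "CPU" when the variable is unset.
    name = get(env, "CLIMACOMMS_DEVICE", "CPU")
    if !(name in KNOWN_DEVICES)
        valid = join(KNOWN_DEVICES, ", ")
        error("Invalid CLIMACOMMS_DEVICE: $name (expected one of: $valid)")
    end
    return Symbol(name)
end

# A Dict stands in for ENV so the example has no side effects.
selected = parse_device_name(Dict("CLIMACOMMS_DEVICE" => "Metal"))
```

Passing the environment as an argument keeps the check easy to exercise without mutating the real `ENV`.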

Kernel launch parameter computation:

  • Centralized and refactored kernel launch parameter calculation into helper functions (_compute_launch_params_simple and _compute_launch_params_coarsened) in src/devices.jl, now used by both CUDA and Metal backends.

CUDA backend improvements:

  • Refactored CUDA device and kernel launching code in ext/ClimaCommsCUDAExt.jl for consistency with Metal backend, using the new launch parameter helpers and simplifying function definitions.

Testing and compatibility:

  • Expanded test coverage to include Metal, with conditional handling for unsupported features (e.g., skipping Float64 tests for Metal) and updating array creation to use Float32 for compatibility.
  • Note: Since we do not have access to remote Metal-supported GPUs, support cannot be continuously verified in CI. As such, this PR is to be considered experimental.
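
One way to express the conditional Float64 handling is to pick the tested float types per device. The sketch below is an assumption about how such a switch could look, not the code in test/runtests.jl; `supported_float_types` and the `Symbol` device names are hypothetical.

```julia
using Test

# Hypothetical helper: the Metal backend is tested with Float32 only,
# while other devices also exercise Float64.
supported_float_types(device_name::Symbol) =
    device_name === :Metal ? (Float32,) : (Float32, Float64)

@testset "sum over supported float types" begin
    for FT in supported_float_types(:Metal)
        x = ones(FT, 4)          # plain CPU array keeps the sketch runnable anywhere
        @test eltype(x) === FT
        @test sum(x) == FT(4)
    end
end
```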

  • I have read and checked the items on the review checklist.

haakon-e force-pushed the he/feat-add-metal-support branch 2 times, most recently from 4c59755 to e85e43a (February 3, 2026 19:52)
haakon-e force-pushed the he/feat-add-metal-support branch from e85e43a to 6b22195 (February 19, 2026 22:11)
haakon-e force-pushed the he/feat-add-metal-support branch from 6b22195 to e51b5b6 (March 24, 2026 00:50)
haakon-e (Member, Author) commented

Run tests from your local machine with:

# from ClimaComms.jl root

julia --project=test -e 'using Pkg; Pkg.develop(;path="."); Pkg.add("Metal"); Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true)'

# run tests

CLIMACOMMS_DEVICE=Metal julia --project=test test/runtests.jl

Output (all tests pass):

┌ Info: Running test
│   context = ClimaComms.SingletonCommsContext{ClimaComms.MetalDevice}(ClimaComms.MetalDevice())
│   device = ClimaComms.MetalDevice()
└   AT = Metal.MtlArray
  0.260102 seconds (1.08 M CPU allocations: 64.569 MiB) (18 GPU allocations: 34.333 MiB, 0.34% memmgmt time)
  0.004672 seconds (515 CPU allocations: 11.501 MiB) (18 GPU allocations: 34.333 MiB, 15.53% memmgmt time)
Test Summary: | Total  Time
macro hygiene |     0  8.9s
Test Summary: | Pass  Total  Time
tree test ()  |    1      1  1.5s
Test Summary:  | Total  Time
linear test () |     0  0.0s
Test Summary: | Pass  Total  Time
gather        |    2      2  1.0s
Test Summary:            | Pass  Total  Time
reduce/reduce!/allreduce |    6      6  0.5s
Test Summary: | Pass  Total  Time
bcast         |    4      4  0.0s
Test Summary: | Pass  Total  Time
allowscalar   |    1      1  0.0s
Test Summary: | Pass  Total  Time
threaded      |    5      5  1.4s
Test Summary:                | Pass  Total  Time
threaded with lazy iterators |    7      7  2.1s
Test Summary:                    | Pass  Total  Time
threaded with multiple iterators |    3      3  2.1s
Test Summary: | Pass  Total  Time
Adapt         |    2      2  0.0s
Context: SingletonCommsContext
Device: ClimaComms.MetalDevice
Test Summary: | Pass  Total  Time
Logging       |   12     12  0.3s
[ Info: Benchmarking n-th derivative along first axis of a 100×10000 matrix
[ Info: reference identity copy, n = 0:
[ Info:     Latency = 0.558 s
[ Info:     Time = 0.000623 s
[ Info: @threaded identity copy, n = 0:
[ Info:     Latency = 0.000554 s
[ Info:     Time = 0.000226 s
[ Info: reference derivative (broadcast over matrix of indices), n = 0:
[ Info:     Latency = 0.236 s
[ Info:     Time = 0.000499 s
[ Info: reference derivative (broadcast over matrix of indices), n = 2:
[ Info:     Time = 0.0005 s
[ Info: reference derivative (broadcast over matrix of indices), n = 6:
[ Info:     Time = 0.00106 s
[ Info: @threaded derivative (no shared memory), n = 0:
[ Info:     Latency = 0.283 s
[ Info:     Time = 0.000855 s
[ Info: @threaded derivative (no shared memory), n = 2:
[ Info:     Time = 0.000444 s
[ Info: @threaded derivative (no shared memory), n = 6:
[ Info:     Time = 0.00107 s

haakon-e force-pushed the he/feat-add-metal-support branch from e51b5b6 to 17fe38d (March 24, 2026 19:00)
haakon-e (Member, Author) commented Mar 24, 2026

TODO: Add method for

shared_memory(::Device, ::Type{T}, dims...) where {T} = Metal.Shared[Static/Dynamic]Array

haakon-e force-pushed the he/feat-add-metal-support branch from 17fe38d to 7806711 (March 24, 2026 21:21)
haakon-e (Member, Author) commented

TODO: Add method for

shared_memory(::Device, ::Type{T}, dims...) where {T} = Metal.Shared[Static/Dynamic]Array

If that's OK, I'll add this in a follow-up PR, once I understand what's going on in ClimaCore a little better.

haakon-e (Member, Author) left a comment

Comments to assist review

Comment thread ext/ClimaCommsCUDAExt.jl

The updates in this function are mostly formatting changes, with a few exceptions I'll outline below.

Comment thread ext/ClimaCommsCUDAExt.jl
max_blocks = grid_size_limit(kernel)
max_threads_in_block = block_size_limit(block_size, kernel)

params = ClimaComms._compute_launch_params_simple(

Defined _compute_launch_params_simple in src/devices.jl for reuse in the Metal extension. It should be equivalent to the existing code.
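
For readers unfamiliar with this kind of helper, here is a minimal sketch of a one-item-per-thread launch computation. It is an assumption about what such a helper does, not the contents of src/devices.jl: the name `launch_params_simple`, its arguments, and the `needs_coarsening` flag are hypothetical.

```julia
# Hypothetical stand-in; the real helper's signature and return value may differ.
function launch_params_simple(
    n_items::Int,
    max_blocks::Int,
    max_threads_in_block::Int,
)
    n_items > 0 || throw(ArgumentError("n_items must be positive"))
    # One item per thread: never request more threads than there are items.
    threads = min(max_threads_in_block, n_items)
    # Round up so the last, partially filled block is still launched.
    blocks = cld(n_items, threads)
    # Flag the case where the grid would exceed the device's block limit.
    needs_coarsening = blocks > max_blocks
    return (; threads, blocks, needs_coarsening)
end

# Illustrative limits only; real values come from the compiled kernel.
params = launch_params_simple(10_000, 65_535, 256)
```

With these illustrative limits, 10,000 items fit in 40 blocks of 256 threads, so no coarsening is needed.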

Comment thread ext/ClimaCommsCUDAExt.jl
threads_in_block = min(max_threads_in_block, max_required_threads)
blocks = cld(n_items, items_in_thread * threads_in_block)
kernel(; blocks, threads = threads_in_block)
params = ClimaComms._compute_launch_params_coarsened(

Defined _compute_launch_params_coarsened in src/devices.jl for reuse in the Metal extension.
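
The coarsened computation can be sketched from the removed lines quoted in this thread (`threads_in_block = min(...)`, `blocks = cld(...)`). The wrapper below is hypothetical: the name `launch_params_coarsened`, its arguments, and the definition of `max_required_threads` as `cld(n_items, items_in_thread)` are assumptions, not the actual helper.

```julia
# Hypothetical stand-in; the real helper's signature may differ.
function launch_params_coarsened(
    n_items::Int,
    items_in_thread::Int,
    max_threads_in_block::Int,
)
    # Assumption: each thread handles up to `items_in_thread` items.
    max_required_threads = cld(n_items, items_in_thread)
    threads_in_block = min(max_threads_in_block, max_required_threads)
    # Each block covers items_in_thread * threads_in_block items.
    blocks = cld(n_items, items_in_thread * threads_in_block)
    return (; threads = threads_in_block, blocks)
end

# Illustrative numbers only.
params = launch_params_coarsened(1_000_000, 4, 256)
```

Here each block covers 4 × 256 = 1024 items, so one million items need 977 blocks.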

Comment thread ext/ClimaCommsMetalExt.jl

If you open this file next to ext/ClimaCommsCUDAExt.jl, the two should look nearly identical apart from CUDA being renamed to Metal.

Comment thread test/hygiene.jl

Changed all float types to Float32 so that this test file is compatible with the Metal backend.

Comment thread test/runtests.jl

Changed float types to Float32, or skipped the Float64 tests, so that this test file is compatible with the Metal backend.
