Skip to content

Commit e85e43a

Browse files
committed
feat: add Metal support
1 parent b4fc0bd commit e85e43a

File tree

10 files changed

+247
-72
lines changed

10 files changed

+247
-72
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ steps:
1111
key: "initialize"
1212
command:
1313
- echo "--- Instantiate project"
14-
- julia --project=test -e 'using Pkg; Pkg.develop(;path="."); Pkg.add("CUDA"); Pkg.add("MPI"); Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true)'
14+
- julia --project=test -e 'using Pkg; Pkg.develop(;path="."); Pkg.add([PackageSpec("CUDA"), PackageSpec("MPI"), PackageSpec("Metal")]); Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true)'
1515
# force the initialization of the CUDA runtime as it is lazily loaded by default
1616
- "julia --project=test -e 'using CUDA; CUDA.precompile_runtime()'"
1717
- "julia --project=test -e 'using Pkg; Pkg.status()'"

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ ClimaComms.jl Release Notes
44
main
55
-------
66

7+
v0.6.11
8+
- Added Metal support [PR 126](https://github.com/CliMA/ClimaComms.jl/pull/126)
9+
- NOTE: This is considered experimental as it is not continuously tested in CI.
10+
711
v0.6.10
812
-------
913
- fixed logging interoperability with `GPUCompiler.jl` [PR 119](https://github.com/CliMA/ClimaComms.jl/pull/119)

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "ClimaComms"
22
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
3-
authors = ["Kiran Pamnany <clima-software@caltech.edu>", "Simon Byrne <simonbyrne@caltech.edu>", "Charles Kawczynski <charliek@caltech.edu>", "Sriharsha Kandala <Sriharsha.kvs@gmail.com>", "Jake Bolewski <clima-software@caltech.edu>", "Gabriele Bozzola <gbozzola@caltech.edu>"]
4-
version = "0.6.10"
3+
authors = ["Kiran Pamnany <clima-software@caltech.edu>", "Simon Byrne <simonbyrne@caltech.edu>", "Charles Kawczynski <charliek@caltech.edu>", "Sriharsha Kandala <Sriharsha.kvs@gmail.com>", "Jake Bolewski <clima-software@caltech.edu>", "Gabriele Bozzola <gbozzola@caltech.edu>", "Haakon Ludvig Langeland Ervik <45243236+haakon-e@users.noreply.github.com>"]
4+
version = "0.6.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
99
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
10+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1011

1112
[weakdeps]
1213
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -15,11 +16,13 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1516
[extensions]
1617
ClimaCommsCUDAExt = "CUDA"
1718
ClimaCommsMPIExt = "MPI"
19+
ClimaCommsMetalExt = "Metal"
1820

1921
[compat]
2022
CUDA = "3, 4, 5"
2123
Adapt = "3, 4"
2224
Logging = "1.9.4"
2325
LoggingExtras = "1.1.0"
2426
MPI = "0.20.18"
27+
Metal = "1"
2528
julia = "1.9"

ext/ClimaCommsCUDAExt.jl

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,13 @@ function Base.summary(io::IO, ::CUDADevice)
1818
return "$name ($uuid)"
1919
end
2020

21-
function ClimaComms.device_functional(::CUDADevice)
22-
return CUDA.functional()
23-
end
21+
ClimaComms.device_functional(::CUDADevice) = CUDA.functional()
2422

25-
function Adapt.adapt_structure(
26-
to::Type{<:CUDA.CuArray},
27-
ctx::ClimaComms.AbstractCommsContext,
28-
)
29-
return ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx)))
30-
end
23+
Adapt.adapt_structure(to::Type{<:CUDA.CuArray}, ctx::ClimaComms.AbstractCommsContext) =
24+
ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx)))
3125

32-
Adapt.adapt_structure(
33-
::Type{<:CUDA.CuArray},
34-
device::ClimaComms.AbstractDevice,
35-
) = ClimaComms.CUDADevice()
26+
Adapt.adapt_structure(::Type{<:CUDA.CuArray}, device::ClimaComms.AbstractDevice) =
27+
ClimaComms.CUDADevice()
3628

3729
ClimaComms.array_type(::CUDADevice) = CUDA.CuArray
3830
ClimaComms.free_memory(::CUDADevice) = CUDA.free_memory()
@@ -56,57 +48,44 @@ ClimaComms.assert(::CUDADevice, cond::C, text::T) where {C, T} =
5648
threads_in_kernel() = CUDA.blockDim().x * CUDA.gridDim().x
5749

5850
# The index of the calling thread, which is between 1 and threads_in_kernel().
59-
thread_index() =
60-
(CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
51+
thread_index() = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
6152

6253
# The maximum number of blocks that can fit on the GPU used for this kernel.
6354
grid_size_limit(kernel) = CUDA.attribute(
64-
CUDA.device(kernel.fun.mod.ctx),
65-
CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
55+
CUDA.device(kernel.fun.mod.ctx), CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
6656
)
6757

6858
# Either the first value if it is available, or the maximum number of threads
6959
# that can fit in one block of this kernel (cuOccupancyMaxPotentialBlockSize).
7060
# With enough blocks, the latter value will maximize the occupancy of the GPU.
7161
block_size_limit(max_threads_in_block::Int, _) = max_threads_in_block
72-
block_size_limit(::Val{:auto}, kernel) =
73-
CUDA.launch_configuration(kernel.fun).threads
62+
block_size_limit(::Val{:auto}, kernel) = CUDA.launch_configuration(kernel.fun).threads
7463

75-
function ClimaComms.run_threaded(
76-
f::F,
77-
::CUDADevice,
78-
::Val,
79-
itr;
80-
block_size,
81-
) where {F}
64+
function ClimaComms.run_threaded(f::F, ::CUDADevice, ::Val, itr; block_size) where {F}
8265
n_items = length(itr)
8366
n_items > 0 || return nothing
8467

8568
function call_f_from_thread()
8669
item_index = thread_index()
87-
item_index <= n_items &&
88-
@inbounds f(itr[firstindex(itr) + item_index - 1])
70+
item_index <= n_items && @inbounds f(itr[firstindex(itr) + item_index - 1])
8971
return nothing
9072
end
9173
kernel = CUDA.@cuda always_inline=true launch=false call_f_from_thread()
9274
max_blocks = grid_size_limit(kernel)
9375
max_threads_in_block = block_size_limit(block_size, kernel)
9476

77+
params = ClimaComms._compute_launch_params_simple(
78+
n_items, max_blocks, max_threads_in_block,
79+
)
9580
# If there are too many items, coarsen by the smallest possible amount.
96-
n_items <= max_blocks * max_threads_in_block ||
81+
isnothing(params) &&
9782
return ClimaComms.run_threaded(f, CUDADevice(), 1, itr; block_size)
9883

99-
threads_in_block = min(max_threads_in_block, n_items)
100-
blocks = cld(n_items, threads_in_block)
101-
kernel(; blocks, threads = threads_in_block)
84+
kernel(; params.blocks, threads = params.threads_in_block)
10285
end
10386

10487
function ClimaComms.run_threaded(
105-
f::F,
106-
::CUDADevice,
107-
min_items_in_thread::Int,
108-
itr;
109-
block_size,
88+
f::F, ::CUDADevice, min_items_in_thread::Int, itr; block_size,
11089
) where {F}
11190
min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive"))
11291
n_items = length(itr)
@@ -122,16 +101,10 @@ function ClimaComms.run_threaded(
122101
max_blocks = grid_size_limit(kernel)
123102
max_threads_in_block = block_size_limit(block_size, kernel)
124103

125-
# If there are too many items to use the specified coarsening, increase it
126-
# by the smallest possible amount.
127-
max_required_threads = cld(n_items, min_items_in_thread)
128-
items_in_thread =
129-
max_required_threads <= max_blocks * max_threads_in_block ?
130-
min_items_in_thread : cld(n_items, max_blocks * max_threads_in_block)
131-
132-
threads_in_block = min(max_threads_in_block, max_required_threads)
133-
blocks = cld(n_items, items_in_thread * threads_in_block)
134-
kernel(; blocks, threads = threads_in_block)
104+
params = ClimaComms._compute_launch_params_coarsened(
105+
n_items, max_blocks, max_threads_in_block, min_items_in_thread,
106+
)
107+
kernel(; params.blocks, threads = params.threads_in_block)
135108
end
136109

137110
end

ext/ClimaCommsMetalExt.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
module ClimaCommsMetalExt
2+
3+
import Metal
4+
5+
import Adapt
6+
import ClimaComms
7+
import ClimaComms: MetalDevice
8+
9+
# Metal automatically manages device assignment, so this is a no-op
10+
ClimaComms._assign_device(::MetalDevice, rank_number) = nothing
11+
12+
function Base.summary(io::IO, ::MetalDevice)
13+
dev = Metal.device()
14+
name = dev.name
15+
return "$name (Metal)"
16+
end
17+
18+
ClimaComms.device_functional(::MetalDevice) = !isempty(Metal.devices())
19+
20+
Adapt.adapt_structure(to::Type{<:Metal.MtlArray}, ctx::ClimaComms.AbstractCommsContext) =
21+
ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx)))
22+
23+
Adapt.adapt_structure(::Type{<:Metal.MtlArray}, device::ClimaComms.AbstractDevice) =
24+
ClimaComms.MetalDevice()
25+
26+
ClimaComms.array_type(::MetalDevice) = Metal.MtlArray
27+
ClimaComms.free_memory(::MetalDevice) = Metal.device().currentAllocatedSize
28+
ClimaComms.total_memory(::MetalDevice) = Metal.device().maxBufferLength
29+
ClimaComms.allowscalar(f, ::MetalDevice, args...; kwargs...) =
30+
Metal.@allowscalar f(args...; kwargs...)
31+
32+
# Extending ClimaComms methods that operate on expressions (cannot use dispatch here)
33+
ClimaComms.sync(f::F, ::MetalDevice, args...; kwargs...) where {F} =
34+
Metal.@sync f(args...; kwargs...)
35+
ClimaComms.cuda_sync(f::F, ::MetalDevice, args...; kwargs...) where {F} = # TODO: Rename to `device_sync` to unify `Metal` and `CUDA`
36+
Metal.@sync f(args...; kwargs...)
37+
ClimaComms.time(f::F, ::MetalDevice, args...; kwargs...) where {F} =
38+
Metal.@time f(args...; kwargs...)
39+
ClimaComms.elapsed(f::F, ::MetalDevice, args...; kwargs...) where {F} =
40+
Metal.@elapsed f(args...; kwargs...)
41+
ClimaComms.assert(::MetalDevice, cond::C, text::T) where {C,T} =
42+
isnothing(text) ? (Metal.@assert cond()) : (Metal.@assert cond() text())
43+
44+
# The number of threads in the kernel being executed by the calling thread.
45+
threads_in_kernel() = Metal.threads_per_grid_1d()
46+
47+
# The index of the calling thread, which is between 1 and threads_in_kernel().
48+
thread_index() = Metal.thread_position_in_grid_1d()
49+
50+
# The maximum number of blocks that can fit on the GPU used for this kernel.
51+
# Metal doesn't have a direct equivalent to CUDA's max grid dim, so we use a reasonable default
52+
grid_size_limit(kernel) = 65535
53+
54+
# Either the first value if it is available, or the maximum number of threads
55+
# that can fit in one block of this kernel.
56+
# With enough blocks, the latter value will maximize the occupancy of the GPU.
57+
block_size_limit(max_threads_in_block::Int, _) = max_threads_in_block
58+
block_size_limit(::Val{:auto}, kernel) = Int(kernel.pipeline.maxTotalThreadsPerThreadgroup)
59+
60+
function ClimaComms.run_threaded(f::F, ::MetalDevice, ::Val, itr; block_size) where {F}
61+
n_items = length(itr)
62+
n_items > 0 || return nothing
63+
64+
function call_f_from_thread()
65+
item_index = thread_index()
66+
item_index <= n_items && @inbounds f(itr[firstindex(itr)+item_index-1])
67+
return nothing
68+
end
69+
kernel = Metal.@metal launch = false call_f_from_thread()
70+
max_blocks = grid_size_limit(kernel)
71+
max_threads_in_block = block_size_limit(block_size, kernel)
72+
73+
params = ClimaComms._compute_launch_params_simple(
74+
n_items, max_blocks, max_threads_in_block,
75+
)
76+
# If there are too many items, coarsen by the smallest possible amount.
77+
isnothing(params) &&
78+
return ClimaComms.run_threaded(f, MetalDevice(), 1, itr; block_size)
79+
80+
Metal.@sync kernel(; threads = params.threads_in_block, groups = params.blocks)
81+
end
82+
83+
function ClimaComms.run_threaded(
84+
f::F, ::MetalDevice, min_items_in_thread::Int, itr; block_size
85+
) where {F}
86+
min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive"))
87+
n_items = length(itr)
88+
n_items > 0 || return nothing
89+
90+
# Maximize memory coalescing with a "grid-stride loop"; for reference, see
91+
# https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops
92+
call_f_from_thread() =
93+
for item_index in thread_index():threads_in_kernel():n_items
94+
@inbounds f(itr[firstindex(itr)+item_index-1])
95+
end
96+
kernel = Metal.@metal launch = false call_f_from_thread()
97+
max_blocks = grid_size_limit(kernel)
98+
max_threads_in_block = block_size_limit(block_size, kernel)
99+
100+
params = ClimaComms._compute_launch_params_coarsened(
101+
n_items, max_blocks, max_threads_in_block, min_items_in_thread,
102+
)
103+
Metal.@sync kernel(; threads = params.threads_in_block, groups = params.blocks)
104+
end
105+
106+
end

src/devices.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ Use NVIDIA GPU accelerator
3737
"""
3838
struct CUDADevice <: AbstractDevice end
3939

40+
"""
41+
MetalDevice()
42+
43+
Use Apple GPU accelerator (Metal)
44+
"""
45+
struct MetalDevice <: AbstractDevice end
46+
4047
"""
4148
ClimaComms.device_functional(device)
4249
@@ -57,6 +64,8 @@ function device_type()
5764
return :CPUMultiThreaded
5865
elseif env_var == "CUDA"
5966
return :CUDADevice
67+
elseif env_var == "Metal"
68+
return :MetalDevice
6069
else
6170
error("Invalid CLIMACOMMS_DEVICE: $env_var")
6271
end
@@ -71,7 +80,8 @@ Allowed values:
7180
- `CPU`, single-threaded or multi-threaded depending on the number of threads;
7281
- `CPUSingleThreaded`,
7382
- `CPUMultiThreaded`,
74-
- `CUDA`.
83+
- `CUDA`,
84+
- `Metal`.
7585
7686
The default is `CPU`.
7787
"""
@@ -82,6 +92,11 @@ function device()
8292
"Loading CUDA.jl is required to use CUDADevice. You might want to call ClimaComms.@import_required_backends",
8393
)
8494
end
95+
if target_device == :MetalDevice && !metal_ext_is_loaded()
96+
error(
97+
"Loading Metal.jl is required to use MetalDevice. You might want to call ClimaComms.@import_required_backends",
98+
)
99+
end
85100
DeviceConstructor = getproperty(ClimaComms, target_device)
86101
return DeviceConstructor()
87102
end
@@ -742,3 +757,49 @@ Base.@propagate_inbounds function Base.getindex(
742757
end
743758

744759
# TODO: Check whether conversion of every Int to Int32 improves GPU performance.
760+
761+
# Internal helpers for GPU kernel launch parameters
762+
763+
"""
764+
_compute_launch_params_simple(n_items, max_blocks, max_threads_in_block)
765+
766+
Compute kernel launch parameters (`blocks`, `threads_in_block`) for a simple (1 item per thread)
767+
execution strategy. Returns `nothing` if the number of items exceeds the GPU's capacity for
768+
this strategy (requires coarsening).
769+
770+
Used by `ClimaCommsCUDAExt` and `ClimaCommsMetalExt` in `run_threaded`.
771+
"""
772+
function _compute_launch_params_simple(n_items, max_blocks, max_threads_in_block)
773+
if n_items <= max_blocks * max_threads_in_block
774+
threads_in_block = min(max_threads_in_block, n_items)
775+
blocks = cld(n_items, threads_in_block)
776+
return (; blocks, threads_in_block)
777+
else
778+
return nothing
779+
end
780+
end
781+
782+
"""
783+
_compute_launch_params_coarsened(n_items, max_blocks, max_threads_in_block, min_items_in_thread)
784+
785+
Compute kernel launch parameters (`blocks`, `threads_in_block`) for a coarsened execution strategy,
786+
where each thread processes at least `min_items_in_thread`. This strategy maximizes GPU occupancy
787+
when `n_items` is large.
788+
789+
Used by `ClimaCommsCUDAExt` and `ClimaCommsMetalExt` in `run_threaded`.
790+
"""
791+
function _compute_launch_params_coarsened(
792+
n_items, max_blocks, max_threads_in_block, min_items_in_thread,
793+
)
794+
# If there are too many items to use the specified coarsening, increase it
795+
# by the smallest possible amount.
796+
max_required_threads = cld(n_items, min_items_in_thread)
797+
items_in_thread =
798+
max_required_threads <= max_blocks * max_threads_in_block ?
799+
min_items_in_thread :
800+
cld(n_items, max_blocks * max_threads_in_block)
801+
802+
threads_in_block = min(max_threads_in_block, max_required_threads)
803+
blocks = cld(n_items, items_in_thread * threads_in_block)
804+
return (; blocks, threads_in_block)
805+
end

0 commit comments

Comments
 (0)