-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add Metal support #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,21 +18,15 @@ function Base.summary(io::IO, ::CUDADevice) | |
| return "$name ($uuid)" | ||
| end | ||
|
|
||
| function ClimaComms.device_functional(::CUDADevice) | ||
| return CUDA.functional() | ||
| end | ||
|
|
||
| function Adapt.adapt_structure( | ||
| to::Type{<:CUDA.CuArray}, | ||
| ctx::ClimaComms.AbstractCommsContext, | ||
| ) | ||
| return ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx))) | ||
| end | ||
| ClimaComms.device_functional(::CUDADevice) = CUDA.functional() | ||
|
|
||
| Adapt.adapt_structure( | ||
| ::Type{<:CUDA.CuArray}, | ||
| device::ClimaComms.AbstractDevice, | ||
| ) = ClimaComms.CUDADevice() | ||
| to::Type{<:CUDA.CuArray}, ctx::ClimaComms.AbstractCommsContext, | ||
| ) = | ||
| ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx))) | ||
|
|
||
| Adapt.adapt_structure(::Type{<:CUDA.CuArray}, ::ClimaComms.AbstractDevice) = | ||
| ClimaComms.CUDADevice() | ||
|
|
||
| ClimaComms.array_type(::CUDADevice) = CUDA.CuArray | ||
| ClimaComms.free_memory(::CUDADevice) = CUDA.free_memory() | ||
|
|
@@ -61,8 +55,7 @@ thread_index() = | |
|
|
||
| # The maximum number of blocks that can fit on the GPU used for this kernel. | ||
| grid_size_limit(kernel) = CUDA.attribute( | ||
| CUDA.device(kernel.fun.mod.ctx), | ||
| CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, | ||
| CUDA.device(kernel.fun.mod.ctx), CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, | ||
| ) | ||
|
|
||
| # Either the first value if it is available, or the maximum number of threads | ||
|
|
@@ -73,11 +66,7 @@ block_size_limit(::Val{:auto}, kernel) = | |
| CUDA.launch_configuration(kernel.fun).threads | ||
|
|
||
| function ClimaComms.run_threaded( | ||
| f::F, | ||
| ::CUDADevice, | ||
| ::Val, | ||
| itr; | ||
| block_size, | ||
| f::F, ::CUDADevice, ::Val, itr; block_size, | ||
| ) where {F} | ||
| n_items = length(itr) | ||
| n_items > 0 || return nothing | ||
|
|
@@ -92,21 +81,18 @@ function ClimaComms.run_threaded( | |
| max_blocks = grid_size_limit(kernel) | ||
| max_threads_in_block = block_size_limit(block_size, kernel) | ||
|
|
||
| params = ClimaComms._compute_launch_params_simple( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Defined |
||
| n_items, max_blocks, max_threads_in_block, | ||
| ) | ||
| # If there are too many items, coarsen by the smallest possible amount. | ||
| n_items <= max_blocks * max_threads_in_block || | ||
| isnothing(params) && | ||
| return ClimaComms.run_threaded(f, CUDADevice(), 1, itr; block_size) | ||
|
|
||
| threads_in_block = min(max_threads_in_block, n_items) | ||
| blocks = cld(n_items, threads_in_block) | ||
| kernel(; blocks, threads = threads_in_block) | ||
| kernel(; params.blocks, threads = params.threads_in_block) | ||
| end | ||
|
|
||
| function ClimaComms.run_threaded( | ||
| f::F, | ||
| ::CUDADevice, | ||
| min_items_in_thread::Int, | ||
| itr; | ||
| block_size, | ||
| f::F, ::CUDADevice, min_items_in_thread::Int, itr; block_size, | ||
| ) where {F} | ||
| min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive")) | ||
| n_items = length(itr) | ||
|
|
@@ -122,16 +108,10 @@ function ClimaComms.run_threaded( | |
| max_blocks = grid_size_limit(kernel) | ||
| max_threads_in_block = block_size_limit(block_size, kernel) | ||
|
|
||
| # If there are too many items to use the specified coarsening, increase it | ||
| # by the smallest possible amount. | ||
| max_required_threads = cld(n_items, min_items_in_thread) | ||
| items_in_thread = | ||
| max_required_threads <= max_blocks * max_threads_in_block ? | ||
| min_items_in_thread : cld(n_items, max_blocks * max_threads_in_block) | ||
|
|
||
| threads_in_block = min(max_threads_in_block, max_required_threads) | ||
| blocks = cld(n_items, items_in_thread * threads_in_block) | ||
| kernel(; blocks, threads = threads_in_block) | ||
| params = ClimaComms._compute_launch_params_coarsened( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Defined |
||
| n_items, max_blocks, max_threads_in_block, min_items_in_thread, | ||
| ) | ||
| kernel(; params.blocks, threads = params.threads_in_block) | ||
| end | ||
|
|
||
| end | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you open this file next to the |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| module ClimaCommsMetalExt | ||
|
|
||
| import Metal | ||
|
|
||
| import Adapt | ||
| import ClimaComms | ||
| import ClimaComms: MetalDevice | ||
|
|
||
| # Metal automatically manages device assignment, so this is a no-op | ||
| ClimaComms._assign_device(::MetalDevice, rank_number) = nothing | ||
|
|
||
| function Base.summary(io::IO, ::MetalDevice) | ||
| dev = Metal.device() | ||
| name = dev.name | ||
| return "$name (Metal)" | ||
| end | ||
|
|
||
| ClimaComms.device_functional(::MetalDevice) = !isempty(Metal.devices()) | ||
|
|
||
| Adapt.adapt_structure( | ||
| to::Type{<:Metal.MtlArray}, ctx::ClimaComms.AbstractCommsContext, | ||
| ) = | ||
| ClimaComms.context(Adapt.adapt(to, ClimaComms.device(ctx))) | ||
|
|
||
| Adapt.adapt_structure(::Type{<:Metal.MtlArray}, ::ClimaComms.AbstractDevice) = | ||
| ClimaComms.MetalDevice() | ||
|
|
||
| ClimaComms.array_type(::MetalDevice) = Metal.MtlArray | ||
| ClimaComms.free_memory(::MetalDevice) = Metal.device().currentAllocatedSize | ||
| ClimaComms.total_memory(::MetalDevice) = Metal.device().maxBufferLength | ||
| ClimaComms.allowscalar(f, ::MetalDevice, args...; kwargs...) = | ||
| Metal.@allowscalar f(args...; kwargs...) | ||
|
|
||
| # Extending ClimaComms methods that operate on expressions (cannot use dispatch here) | ||
| ClimaComms.sync(f::F, ::MetalDevice, args...; kwargs...) where {F} = | ||
| Metal.@sync f(args...; kwargs...) | ||
| ClimaComms.cuda_sync(f::F, ::MetalDevice, args...; kwargs...) where {F} = # TODO: Rename to `device_sync` to unify `Metal` and `CUDA` | ||
| Metal.@sync f(args...; kwargs...) | ||
| ClimaComms.time(f::F, ::MetalDevice, args...; kwargs...) where {F} = | ||
| Metal.@time f(args...; kwargs...) | ||
| ClimaComms.elapsed(f::F, ::MetalDevice, args...; kwargs...) where {F} = | ||
| Metal.@elapsed f(args...; kwargs...) | ||
| ClimaComms.assert(::MetalDevice, cond::C, text::T) where {C, T} = | ||
| isnothing(text) ? (Metal.@assert cond()) : (Metal.@assert cond() text()) | ||
|
|
||
| # The number of threads in the kernel being executed by the calling thread. | ||
| threads_in_kernel() = Metal.threads_per_grid_1d() | ||
|
|
||
| # The index of the calling thread, which is between 1 and threads_in_kernel(). | ||
| thread_index() = Metal.thread_position_in_grid_1d() | ||
|
|
||
| # The maximum number of blocks that can fit on the GPU used for this kernel. | ||
| # Metal doesn't have a direct equivalent to CUDA's max grid dim, so we use a reasonable default | ||
| grid_size_limit(kernel) = 65535 | ||
|
|
||
| # Either the first value if it is available, or the maximum number of threads | ||
| # that can fit in one block of this kernel. | ||
| # With enough blocks, the latter value will maximize the occupancy of the GPU. | ||
| block_size_limit(max_threads_in_block::Int, _) = max_threads_in_block | ||
| block_size_limit(::Val{:auto}, kernel) = | ||
| Int(kernel.pipeline.maxTotalThreadsPerThreadgroup) | ||
|
|
||
| function ClimaComms.run_threaded( | ||
| f::F, ::MetalDevice, ::Val, itr; block_size, | ||
| ) where {F} | ||
| n_items = length(itr) | ||
| n_items > 0 || return nothing | ||
|
|
||
| function call_f_from_thread() | ||
| item_index = thread_index() | ||
| item_index <= n_items && | ||
| @inbounds f(itr[firstindex(itr) + item_index - 1]) | ||
| return nothing | ||
| end | ||
| kernel = Metal.@metal launch = false call_f_from_thread() | ||
| max_blocks = grid_size_limit(kernel) | ||
| max_threads_in_block = block_size_limit(block_size, kernel) | ||
|
|
||
| params = ClimaComms._compute_launch_params_simple( | ||
| n_items, max_blocks, max_threads_in_block, | ||
| ) | ||
| # If there are too many items, coarsen by the smallest possible amount. | ||
| isnothing(params) && | ||
| return ClimaComms.run_threaded(f, MetalDevice(), 1, itr; block_size) | ||
|
|
||
| Metal.@sync kernel(; | ||
| threads = params.threads_in_block, groups = params.blocks, | ||
| ) | ||
| end | ||
|
|
||
| function ClimaComms.run_threaded( | ||
| f::F, ::MetalDevice, min_items_in_thread::Int, itr; block_size, | ||
| ) where {F} | ||
| min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive")) | ||
| n_items = length(itr) | ||
| n_items > 0 || return nothing | ||
|
|
||
| # Maximize memory coalescing with a "grid-stride loop"; for reference, see | ||
| # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops | ||
| call_f_from_thread() = | ||
| for item_index in thread_index():threads_in_kernel():n_items | ||
| @inbounds f(itr[firstindex(itr) + item_index - 1]) | ||
| end | ||
| kernel = Metal.@metal launch = false call_f_from_thread() | ||
| max_blocks = grid_size_limit(kernel) | ||
| max_threads_in_block = block_size_limit(block_size, kernel) | ||
|
|
||
| params = ClimaComms._compute_launch_params_coarsened( | ||
| n_items, max_blocks, max_threads_in_block, min_items_in_thread, | ||
| ) | ||
| Metal.@sync kernel(; | ||
| threads = params.threads_in_block, groups = params.blocks, | ||
| ) | ||
| end | ||
|
|
||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The updates in this function is mostly formatting changes, with a few exceptions I'll outline below