Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.6.9"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -17,9 +18,10 @@ ClimaCommsCUDAExt = "CUDA"
ClimaCommsMPIExt = "MPI"

[compat]
CUDA = "3, 4, 5"
Adapt = "3, 4"
CUDA = "3, 4, 5"
Logging = "1.9.4"
LoggingExtras = "1.1.0"
MPI = "0.20.18"
StaticArrays = "1.9"
julia = "1.9"
29 changes: 23 additions & 6 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.11.0"
julia_version = "1.11.5"
manifest_format = "2.0"
project_hash = "d60839f726bd9115791d1a0807a21b61938765a9"

Expand All @@ -19,13 +19,11 @@ deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.1.1"
weakdeps = ["StaticArrays"]

[deps.Adapt.extensions]
AdaptStaticArraysExt = "StaticArrays"

[deps.Adapt.weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.2"
Expand All @@ -39,10 +37,10 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
version = "1.11.0"

[[deps.ClimaComms]]
deps = ["Adapt", "Logging", "LoggingExtras"]
deps = ["Adapt", "Logging", "LoggingExtras", "StaticArrays"]
path = ".."
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.5"
version = "0.6.9"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand Down Expand Up @@ -361,6 +359,25 @@ version = "1.11.0"
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
version = "1.11.0"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "cbea8a6bd7bed51b1619658dec70035e07b8502f"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.14"

[deps.StaticArrays.extensions]
StaticArraysChainRulesCoreExt = "ChainRulesCore"
StaticArraysStatisticsExt = "Statistics"

[deps.StaticArrays.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.StaticArraysCore]]
git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.3"

[[deps.StyledStrings]]
uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
version = "1.11.0"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ makedocs(
pages = Any[
"Home" => "index.md",
"Developing with `ClimaComms`" => "internals.md",
"Threaded" => "threaded.md",
"Logging" => "logging.md",
"Frequently Asked Questions" => "faqs.md",
"APIs" => "apis.md",
Expand Down
25 changes: 24 additions & 1 deletion docs/src/apis.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,40 @@ ClimaComms.CUDADevice
ClimaComms.device
ClimaComms.device_functional
ClimaComms.array_type
ClimaComms.free_memory
ClimaComms.total_memory
ClimaComms.allowscalar
ClimaComms.@time
ClimaComms.@elapsed
ClimaComms.@assert
ClimaComms.@sync
ClimaComms.@cuda_sync
Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractDevice)
```

## Threading

```@docs
ClimaComms.@threaded
ClimaComms.@shmem_threaded
ClimaComms.threaded
ClimaComms.threadable
ClimaComms.ThreadableWrapper
ClimaComms.shareable
ClimaComms.set_metadata
ClimaComms.disable_auto_sync
ClimaComms.auto_sync!
ClimaComms.sync_shmem_threads!
ClimaComms.shmem_array
ClimaComms.@unique_shmem_thread
ClimaComms.unique_shmem_thread
ClimaComms.shmem_reduce!
ClimaComms.shmem_mapreduce!
ClimaComms.shmem_any!
ClimaComms.shmem_all!
ClimaComms.shmem_sum!
ClimaComms.shmem_prod!
ClimaComms.shmem_maximum!
ClimaComms.shmem_minimum!
```

## Contexts
Expand Down
123 changes: 123 additions & 0 deletions docs/src/threaded.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Writing `@threaded` Kernels

**This documentation needs to be updated for the new API.**

In this section, we will give a tutorial for writing device-agnostic kernels
using `ClimaComms.@threaded`. As an illustrative example, we will show how to
approximate the `n`-th derivative of a matrix in an efficient GPU kernel.

We can compute the first derivative of an `array` along each column using
second-order finite differences with periodic boundary conditions as

```julia
# Apply a centered three-point stencil along the first axis of `array`, with
# periodic wraparound at both ends of axis 1. The trailing `indices...` select
# the position along the remaining axes (e.g. the column of a matrix).
# NOTE(review): the surrounding text calls this a second-order first
# derivative, but `(f[i+1] - 2f[i] + f[i-1])` is the centered second-difference
# numerator; a second-order first derivative would be `(f[i+1] - f[i-1]) / 2`.
# Confirm the intended formula when this page is updated for the new API.
Base.@propagate_inbounds function axis1_deriv(array, index1, indices...)
    n_rows = size(array, 1)
    # mod1 wraps index 0 to n_rows and index n_rows + 1 to 1 (periodic BCs).
    prev_index1 = mod1(index1 - 1, n_rows)
    next_index1 = mod1(index1 + 1, n_rows)
    stencil =
        array[next_index1, indices...] - 2 * array[index1, indices...] +
        array[prev_index1, indices...]
    return stencil / 2
end
```

This can be recursively extended to the `n`-th derivative as
```julia
# Recursively apply the centered three-point stencil `n` times along the first
# axis of `array`, with periodic wraparound at the axis-1 boundaries. The base
# case `n == 1` reads `array` directly; larger `n` recurses at `n - 1` on each
# of the three stencil points, so `array` is accessed O(3^n) times per call.
# NOTE(review): requires `n >= 1`; `n <= 0` falls into the recursive branch and
# never reaches the base case.
Base.@propagate_inbounds function axis1_nth_deriv(n, array, index1, indices...)
    n_rows = size(array, 1)
    # mod1 wraps index 0 to n_rows and index n_rows + 1 to 1 (periodic BCs).
    prev_index1 = mod1(index1 - 1, n_rows)
    next_index1 = mod1(index1 + 1, n_rows)
    if n == 1
        stencil =
            array[next_index1, indices...] - 2 * array[index1, indices...] +
            array[prev_index1, indices...]
    else
        stencil =
            axis1_nth_deriv(n - 1, array, next_index1, indices...) -
            2 * axis1_nth_deriv(n - 1, array, index1, indices...) +
            axis1_nth_deriv(n - 1, array, prev_index1, indices...)
    end
    return stencil / 2
end
```

The simplest way to parallelize this function over the indices of a `matrix` is
to directly evaluate the function at every point:
```julia
# Naive device-agnostic parallelization: launch one `@threaded` work item per
# entry of `matrix` and evaluate the full `n`-th derivative stencil directly
# from global memory, writing into `result[i, j]`.
# NOTE(review): each entry re-reads `matrix` O(3^n) times (see the recursion in
# `axis1_nth_deriv`), which the rest of this tutorial progressively reduces.
columnwise_nth_deriv!(result, matrix, n, device) =
ClimaComms.@threaded device for i in axes(matrix, 1), j in axes(matrix, 2)
@inbounds result[i, j] = axis1_nth_deriv(n, matrix, i, j)
end
```

However, this implementation results in a very large number of global memory
reads: `matrix` is accessed `O(3^n)` times, which becomes prohibitively large as
`n` grows. For example, here is the performance for a `100 × 1000` matrix with
`n = 1` and `n = 10`:
```julia
# TODO: Add performance benchmarks.
```

To reduce the number of global memory reads, we can move each column of the
`matrix` into shared memory before using it to evaluate the `n`-th derivative.
This makes each row index `i in axes(matrix, 1)` interdependent with every other
row index, which we need to explicitly declare:
```julia
# Shared-memory variant: stage each column of `matrix` in a statically-sized
# shared memory array before applying the stencil, so global memory is read
# only O(N_rows) times per column instead of O(N_rows * 3^n) times.
# `N_rows` is passed as a `Val` because the shared memory array must be
# statically sized; it is assumed to equal `size(matrix, 1)` — TODO confirm
# whether a mismatch is checked elsewhere.
columnwise_nth_deriv!(result, matrix, n, device, ::Val{N_rows}) where {N_rows} =
ClimaComms.@threaded device begin
# The row index `i` is declared interdependent: once the column is staged in
# shared memory, every row's stencil reads every neighboring row.
for i in @interdependent(axes(matrix, 1)), j in axes(matrix, 2)
T = eltype(matrix)
matrix_col =
ClimaComms.static_shared_memory_array(device, T, N_rows)
@inbounds begin
# Stage the column, then evaluate the stencil from shared memory; every
# use of the interdependent `i` is wrapped in `@sync_interdependent`.
ClimaComms.@sync_interdependent matrix_col[i] = matrix[i, j]
ClimaComms.@sync_interdependent result[i, j] =
axis1_nth_deriv(n, matrix_col, i)
end
end
end
```
The shared memory array used to store each column is statically-sized, so the
number of rows in each column must be passed as a static parameter. Also, every
use of the interdependent variable `i` must occur inside a
`@sync_interdependent` expression.

In this implementation, there are `O(N_rows)` global memory reads and
`O(N_rows * 3^n)` shared memory reads, which improves the performance:
```julia
# TODO: Add performance benchmarks.
```

We can further improve runtime by splitting the `O(N_rows * 3^n)` shared memory
reads into two sets of `O(N_rows * 3^(n ÷ 2))` shared memory reads:
```julia
# Two-stage shared-memory variant: split the `n`-th derivative into an
# `n ÷ 2`-th derivative followed by an `(n - n ÷ 2)`-th derivative, staging the
# intermediate result in a second shared memory array. This reduces shared
# memory reads from O(N_rows * 3^n) to two sets of O(N_rows * 3^(n ÷ 2)).
# NOTE(review): requires `n >= 2` — for `n == 1`, `n ÷ 2 == 0` and
# `axis1_nth_deriv(0, ...)` never reaches its `n == 1` base case.
columnwise_nth_deriv!(result, matrix, n, device, ::Val{N_rows}) where {N_rows} =
ClimaComms.@threaded device begin
for i in @interdependent(axes(matrix, 1)), j in axes(matrix, 2)
T = eltype(matrix)
# One statically-sized shared array for the staged column, one for the
# intermediate (half-order) derivative.
matrix_col =
ClimaComms.static_shared_memory_array(device, T, N_rows)
intermediate_result_col =
ClimaComms.static_shared_memory_array(device, T, N_rows)
@inbounds begin
ClimaComms.@sync_interdependent matrix_col[i] = matrix[i, j]
ClimaComms.@sync_interdependent intermediate_result_col[i] =
axis1_nth_deriv(n ÷ 2, matrix_col, i)
ClimaComms.@sync_interdependent result[i, j] =
axis1_nth_deriv(n - n ÷ 2, intermediate_result_col, i)
end
end
end
```

This implementation has the following performance:
```julia
# TODO: Add performance benchmarks.
```

We can continue to reduce the number of shared memory reads by adding more
intermediate results, though for large values of `n` we will hit another
performance barrier in thread synchronization time. The number of intermediate
results that gives the best performance will depend on the value of `n` and the
characteristics of the GPU.
84 changes: 3 additions & 81 deletions ext/ClimaCommsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ClimaCommsCUDAExt
import CUDA

import Adapt
import StaticArrays
import ClimaComms
import ClimaComms: CUDADevice

Expand Down Expand Up @@ -52,86 +53,7 @@ ClimaComms.elapsed(f::F, ::CUDADevice, args...; kwargs...) where {F} =
ClimaComms.assert(::CUDADevice, cond::C, text::T) where {C, T} =
isnothing(text) ? (CUDA.@cuassert cond()) : (CUDA.@cuassert cond() text())

# The number of threads in the kernel being executed by the calling thread.
# NOTE(review): only the x dimension is used, i.e. this assumes a 1-D launch
# configuration — which matches how `run_threaded` below launches kernels.
threads_in_kernel() = CUDA.blockDim().x * CUDA.gridDim().x

# The index of the calling thread, which is between 1 and threads_in_kernel().
# (1-based global linear index over blocks and threads in the x dimension.)
thread_index() =
(CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x

# The maximum number of blocks that can fit on the GPU used for this kernel.
# Queried from the device owning the compiled kernel's module context.
grid_size_limit(kernel) = CUDA.attribute(
CUDA.device(kernel.fun.mod.ctx),
CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
)

# Either the first value if it is available, or the maximum number of threads
# that can fit in one block of this kernel (cuOccupancyMaxPotentialBlockSize).
# With enough blocks, the latter value will maximize the occupancy of the GPU.
# Dispatch: an explicit `Int` block size is used verbatim, while `Val(:auto)`
# defers to CUDA's occupancy API for the already-compiled kernel.
block_size_limit(max_threads_in_block::Int, _) = max_threads_in_block
block_size_limit(::Val{:auto}, kernel) =
CUDA.launch_configuration(kernel.fun).threads

# Run `f` over each element of `itr` on the GPU with (ideally) one thread per
# item. The kernel is compiled with `launch=false` first so the grid/block
# limits can be queried before choosing a launch configuration. Returns
# `nothing` for an empty `itr`; otherwise returns the result of the kernel
# launch.
function ClimaComms.run_threaded(
f::F,
::CUDADevice,
::Val,
itr;
block_size,
) where {F}
n_items = length(itr)
n_items > 0 || return nothing

# Each thread handles at most one item; threads past `n_items` do nothing.
# Indexing is offset from `firstindex` so non-1-based iterators also work.
function call_f_from_thread()
item_index = thread_index()
item_index <= n_items &&
@inbounds f(itr[firstindex(itr) + item_index - 1])
return nothing
end
kernel = CUDA.@cuda always_inline=true launch=false call_f_from_thread()
max_blocks = grid_size_limit(kernel)
max_threads_in_block = block_size_limit(block_size, kernel)

# If there are too many items, coarsen by the smallest possible amount.
# (Delegates to the `Int`-coarsening method below, which uses a
# grid-stride loop and grows the coarsening factor as needed.)
n_items <= max_blocks * max_threads_in_block ||
return ClimaComms.run_threaded(f, CUDADevice(), 1, itr; block_size)

threads_in_block = min(max_threads_in_block, n_items)
blocks = cld(n_items, threads_in_block)
kernel(; blocks, threads = threads_in_block)
end

# Run `f` over each element of `itr` on the GPU with at least
# `min_items_in_thread` items handled per thread (a "coarsened" launch).
# Throws `ArgumentError` if the coarsening factor is not positive; returns
# `nothing` for an empty `itr`.
function ClimaComms.run_threaded(
f::F,
::CUDADevice,
min_items_in_thread::Int,
itr;
block_size,
) where {F}
min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive"))
n_items = length(itr)
n_items > 0 || return nothing

# Maximize memory coalescing with a "grid-stride loop"; for reference, see
# https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops
# Each thread starts at its global index and strides by the total thread
# count, so consecutive threads touch consecutive items on every pass.
call_f_from_thread() =
for item_index in thread_index():threads_in_kernel():n_items
@inbounds f(itr[firstindex(itr) + item_index - 1])
end
kernel = CUDA.@cuda always_inline=true launch=false call_f_from_thread()
max_blocks = grid_size_limit(kernel)
max_threads_in_block = block_size_limit(block_size, kernel)

# If there are too many items to use the specified coarsening, increase it
# by the smallest possible amount.
max_required_threads = cld(n_items, min_items_in_thread)
items_in_thread =
max_required_threads <= max_blocks * max_threads_in_block ?
min_items_in_thread : cld(n_items, max_blocks * max_threads_in_block)

threads_in_block = min(max_threads_in_block, max_required_threads)
blocks = cld(n_items, items_in_thread * threads_in_block)
kernel(; blocks, threads = threads_in_block)
end
include("cuda_threaded.jl")
include("cuda_shmem.jl")

end
Loading
Loading