Commit bf3d95f

move CUDA support to a package extension

1 parent c850df5

13 files changed: +106 −67 lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
```diff
@@ -23,6 +23,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
+
 [compat]
 Adapt = "3.0"
 CUDA = "3"
@@ -41,13 +45,18 @@ StatsBase = "0.33"
 Zygote = "0.6.49"
 julia = "1.6"
 
+[extensions]
+CUDAExt = ["CUDA", "NNlibCUDA"]
+
 [extras]
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "CUDA", "NNlibCUDA", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
```
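For context (not part of the diff): the `[weakdeps]` and `[extensions]` entries tell Julia 1.9+ to load the `CUDAExt` module automatically once both weak dependencies are loaded alongside Flux; Flux itself no longer installs or loads CUDA. A minimal sketch of the user-side workflow under that assumption:

```julia
using Pkg
Pkg.add(["CUDA", "NNlibCUDA"])   # weak dependencies are not installed by Flux itself

using Flux
using CUDA, NNlibCUDA            # loading both weak deps triggers the CUDAExt extension

# Optional sanity check, available on Julia >= 1.9:
if Base.get_extension(Flux, :CUDAExt) === nothing
    @warn "CUDAExt is not active; gpu(x) will error until CUDA and NNlibCUDA are loaded"
end
```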

ext/CUDAExt/CUDAExt.jl

Lines changed: 23 additions & 0 deletions
```julia
module CUDAExt

using CUDA
import NNlib, NNlibCUDA

using Flux
import Flux: adapt_storage, _gpu, FluxCPUAdaptor

using Adapt
using ChainRulesCore
using Random
using Zygote

const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

include("utils.jl")
include("functor.jl")

include("layers/normalise.jl")

include("cudnn.jl")

end # module
```
File renamed without changes.

ext/CUDAExt/functor.jl

Lines changed: 46 additions & 0 deletions
```julia
struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
if VERSION >= v"1.7"
  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
else
  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
end
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

# TODO: figure out the correct design for OneElement
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))


adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()

function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
end

function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
end


function _gpu(x)
  check_use_cuda()
  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
end

function check_use_cuda()
  if use_cuda[] === nothing
    use_cuda[] = CUDA.functional()
    if use_cuda[] && !CUDA.has_cudnn()
      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
    end
    if !(use_cuda[])
      @info """The GPU function is being called but the GPU is not accessible.
               Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
    end
  end
end
ChainRulesCore.@non_differentiable check_use_cuda()
```
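As a rough usage sketch (not part of the commit, and assuming a functional CUDA device with the extension loaded), this adaptor pair is what backs the public `gpu`/`cpu` round trip:

```julia
using Flux, CUDA, NNlibCUDA

m  = Dense(3 => 2)                    # CPU model with Array parameters
gm = gpu(m)                           # FluxCUDAAdaptor: parameters become CuArrays
x  = gpu(rand(Float32, 3))

grads = gradient(model -> sum(model(x)), gm)   # gradients are computed on the GPU
cm = cpu(gm)                          # FluxCPUAdaptor: parameters back to Arrays
```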

ext/CUDAExt/layers/normalise.jl

Lines changed: 3 additions & 0 deletions
```julia
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
```
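A usage sketch (not part of the commit, assuming a working GPU and that these methods extend `Flux.dropout_mask` the way the pre-extension code in src/layers/normalise.jl did): dropout on a `CuArray` is meant to draw its mask from CUDA's RNG, while a CPU RNG paired with a `CuArray` hits the error path above.

```julia
using Flux, CUDA, NNlibCUDA, Random

x = CUDA.rand(Float32, 4, 4)
y = Flux.dropout(x, 0.5)                   # RNG picked via rng_from_array(::CuArray) below
Flux.dropout(MersenneTwister(1), x, 0.5)   # intended to throw the ArgumentError above
```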

ext/CUDAExt/utils.jl

Lines changed: 1 addition & 0 deletions
```julia
rng_from_array(::CuArray) = CUDA.default_rng()
```

src/Flux.jl

Lines changed: 4 additions & 4 deletions
```diff
@@ -39,9 +39,6 @@ include("train.jl")
 using .Train
 # using .Train: setup, @train_autodiff
 
-using CUDA
-const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
-
 using Adapt, Functors, OneHotArrays
 include("utils.jl")
 include("functor.jl")
@@ -67,6 +64,9 @@ using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
 
 include("deprecations.jl")
 
-include("cuda/cuda.jl")
+# If package extensions are not supported in this Julia version
+if !isdefined(Base, :get_extension)
+  include("../ext/CUDAExt/CUDAExt.jl")
+end
 
 end # module
```

src/cuda/cuda.jl

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/functor.jl

Lines changed: 12 additions & 42 deletions
````diff
@@ -90,45 +90,20 @@ end
 
 # Allows caching of the parameters when params is called within gradient() to fix #2040.
 # @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054
-# That speeds up implicit use, and silently breaks explicit use. 
+# That speeds up implicit use, and silently breaks explicit use.
 # From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
 Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing
 
-struct FluxCUDAAdaptor end
-adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
-adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
-if VERSION >= v"1.7"
-  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
-else
-  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
-end
-adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
-adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
-  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")
-
-# TODO: figure out the correct design for OneElement
-adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
-
 struct FluxCPUAdaptor end
 
 # define rules for handling structured arrays
 adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x
 adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
-adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
 adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
 adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
-adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
-function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
-  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
-end
-
-function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
-end
-
 # CPU/GPU movement conveniences
 
 """
@@ -166,8 +141,12 @@ _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
 """
     gpu(x)
 
+Requires CUDA and NNlibCUDA to be loaded.
+```julia-repl
+julia> using Flux, CUDA, NNlibCUDA
+```
 Moves `m` to the current GPU device, if available. It is a no-op otherwise.
-See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/) 
+See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
 to help identify the current device.
 
 This works for functions, and any struct marked with [`@functor`](@ref).
@@ -187,23 +166,14 @@ CuArray{Float32, 2}
 ```
 """
 function gpu(x)
-  check_use_cuda()
-  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
-end
-
-function check_use_cuda()
-  if use_cuda[] === nothing
-    use_cuda[] = CUDA.functional()
-    if use_cuda[] && !CUDA.has_cudnn()
-      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
-    end
-    if !(use_cuda[])
-      @info """The GPU function is being called but the GPU is not accessible.
-               Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
-    end
+  if hasmethod(_gpu, Tuple{Any})
+    _gpu(x)
+  else
+    error("CUDA not loaded. Load `CUDA` and `NNlibCUDA` to access GPU functionality")
   end
 end
-ChainRulesCore.@non_differentiable check_use_cuda()
+
+function _gpu end
 
 # Precision
 
````
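The `hasmethod` check together with the empty `function _gpu end` stub is a common way to let a package extension supply the real implementation later. A self-contained toy of the same pattern (hypothetical names, not Flux API):

```julia
module Parent
    function _accelerate end                 # stub; an extension adds the real method

    function accelerate(x)
        if hasmethod(_accelerate, Tuple{Any})
            return _accelerate(x)
        else
            error("accelerator backend not loaded")
        end
    end
end

module Ext                                   # stands in for a package extension
    import ..Parent
    Parent._accelerate(x) = map(y -> 2y, x)  # the backend-specific implementation
end

Parent.accelerate([1, 2, 3])                 # == [2, 4, 6] once Ext is defined
```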
src/layers/normalise.jl

Lines changed: 2 additions & 5 deletions
```diff
@@ -36,9 +36,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
 dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 function _dropout_mask(rng, x, p; dims=:)
   realfptype = float(real(eltype(x)))
@@ -56,9 +53,9 @@ ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
 Dropout layer.
 
 While training, for each input, this layer either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the 
+`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
 `dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
-(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during 
+(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
 training.
 
 In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
```
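A small usage sketch of the `dims` behaviour described in the docstring above (not part of the diff, standard Flux API):

```julia
using Flux

x = rand(Float32, 8, 8, 16, 4)   # WHCN: width × height × channels × batch
d = Dropout(0.25; dims = 3)      # 2D dropout: whole channels are zeroed together
Flux.trainmode!(d)               # force training behaviour outside of a training loop
y = d(x)                         # dropped channels appear as all-zero slices
```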
