Commit bf3d95f
move CUDA support to a package extension
IanButterworth committed Dec 9, 2022
1 parent c850df5 commit bf3d95f
Showing 13 changed files with 106 additions and 67 deletions.
11 changes: 10 additions & 1 deletion Project.toml
@@ -23,6 +23,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

[compat]
Adapt = "3.0"
CUDA = "3"
@@ -41,13 +45,18 @@ StatsBase = "0.33"
Zygote = "0.6.49"
julia = "1.6"

[extensions]
CUDAExt = ["CUDA", "NNlibCUDA"]

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "CUDA", "NNlibCUDA", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
23 changes: 23 additions & 0 deletions ext/CUDAExt/CUDAExt.jl
@@ -0,0 +1,23 @@
module CUDAExt

using CUDA
import NNlib, NNlibCUDA

using Flux
import Flux: adapt_storage, _gpu, FluxCPUAdaptor, _isleaf, rng_from_array, dropout_mask, _dropout_mask

using Adapt
using ChainRulesCore
using Random
using Zygote

const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

include("utils.jl")
include("functor.jl")

include("layers/normalise.jl")

include("cudnn.jl")

end # module
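The extension module gathers the pieces that previously lived under `src/cuda/` and in `src/functor.jl`. A quick way to confirm it is active, assuming a Julia version that supports package extensions (`Base.get_extension` is not available on older versions):

```julia
# Check whether the CUDAExt extension module was loaded (Julia 1.9+ API).
using Flux, CUDA, NNlibCUDA

ext = Base.get_extension(Flux, :CUDAExt)
if ext === nothing
  @warn "CUDAExt was not loaded; check that CUDA and NNlibCUDA are both in the environment"
else
  @info "CUDAExt active" ext
end
```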
File renamed without changes.
46 changes: 46 additions & 0 deletions ext/CUDAExt/functor.jl
@@ -0,0 +1,46 @@
struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
if VERSION >= v"1.7"
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
else
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
end
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

# TODO: figure out the correct design for OneElement
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))


adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()

function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
end

function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
end


function _gpu(x)
  check_use_cuda()
  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
end

function check_use_cuda()
  if use_cuda[] === nothing
    use_cuda[] = CUDA.functional()
    if use_cuda[] && !CUDA.has_cudnn()
      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
    end
    if !(use_cuda[])
      @info """The GPU function is being called but the GPU is not accessible.
               Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
    end
  end
end
ChainRulesCore.@non_differentiable check_use_cuda()
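The adaptor pair above is what `gpu`/`cpu` apply to every leaf array via `fmap`. A rough illustration of the resulting behaviour (assumes a functional GPU; `FluxCUDAAdaptor` and `FluxCPUAdaptor` stay internal, so only the public entry points are shown):

```julia
using Flux, CUDA, NNlibCUDA

m  = Dense(2 => 3)                  # weights are ordinary Arrays
gm = gpu(m)                         # adapt(FluxCUDAAdaptor(), _) maps them to CuArrays
y  = gm(CUDA.rand(Float32, 2, 5))   # forward pass runs on the GPU
cm = cpu(gm)                        # adapt(FluxCPUAdaptor(), _) brings everything back to Array
```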
3 changes: 3 additions & 0 deletions ext/CUDAExt/layers/normalise.jl
@@ -0,0 +1,3 @@
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
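These two methods make the RNG/array pairing explicit: a `CuArray` may only be masked with a `CUDA.RNG`. A hedged sketch of the behaviour (`dropout_mask` is internal to Flux, and a working GPU is assumed):

```julia
using Flux, CUDA, NNlibCUDA, Random

x = CUDA.rand(Float32, 4, 4)
Flux.dropout_mask(CUDA.default_rng(), x, 0.5)     # GPU path, delegates to _dropout_mask
Flux.dropout_mask(Random.default_rng(), x, 0.5)   # throws ArgumentError: CPU RNG with a CuArray
```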
1 change: 1 addition & 0 deletions ext/CUDAExt/utils.jl
@@ -0,0 +1 @@
rng_from_array(::CuArray) = CUDA.default_rng()
8 changes: 4 additions & 4 deletions src/Flux.jl
@@ -39,9 +39,6 @@ include("train.jl")
using .Train
# using .Train: setup, @train_autodiff

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

using Adapt, Functors, OneHotArrays
include("utils.jl")
include("functor.jl")
@@ -67,6 +64,9 @@ using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12

include("deprecations.jl")

include("cuda/cuda.jl")
# If package extensions are not supported in this Julia version
if !isdefined(Base, :get_extension)
include("../ext/CUDAExt/CUDAExt.jl")
end

end # module
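The guarded `include` is the backward-compatibility path: on Julia versions without package extensions, the same source file is compiled into Flux as an ordinary submodule. A hypothetical helper (not part of this commit) showing how downstream code could locate the extension module on either path:

```julia
using Flux

# Hypothetical convenience function, illustrating the two loading paths.
function cuda_ext()
  if isdefined(Base, :get_extension)
    return Base.get_extension(Flux, :CUDAExt)                   # loaded by Pkg as an extension
  else
    return isdefined(Flux, :CUDAExt) ? Flux.CUDAExt : nothing   # included as a plain submodule
  end
end
```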
11 changes: 0 additions & 11 deletions src/cuda/cuda.jl

This file was deleted.

54 changes: 12 additions & 42 deletions src/functor.jl
@@ -90,45 +90,20 @@ end

# Allows caching of the parameters when params is called within gradient() to fix #2040.
# @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054
# That speeds up implicit use, and silently breaks explicit use.
# From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing

struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
if VERSION >= v"1.7"
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
else
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
end
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

# TODO: figure out the correct design for OneElement
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))

struct FluxCPUAdaptor end

# define rules for handling structured arrays
adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x
adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x

function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
end

function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
end

# CPU/GPU movement conveniences

"""
@@ -166,8 +141,12 @@ _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
"""
gpu(x)
Requires CUDA and NNlibCUDA to be loaded:
```julia-repl
julia> using Flux, CUDA, NNlibCUDA
```
Moves `x` to the current GPU device, if available. It is a no-op otherwise.
See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
to help identify the current device.
This works for functions, and any struct marked with [`@functor`](@ref).
@@ -187,23 +166,14 @@ CuArray{Float32, 2}
```
"""
function gpu(x)
  check_use_cuda()
  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
end

function check_use_cuda()
  if use_cuda[] === nothing
    use_cuda[] = CUDA.functional()
    if use_cuda[] && !CUDA.has_cudnn()
      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
    end
    if !(use_cuda[])
      @info """The GPU function is being called but the GPU is not accessible.
               Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
    end
  if hasmethod(_gpu, Tuple{Any})
    _gpu(x)
  else
    error("CUDA not loaded. Load `CUDA` and `NNlibCUDA` to access GPU functionality")
  end
end
ChainRulesCore.@non_differentiable check_use_cuda()

function _gpu end

# Precision

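The `function _gpu end` stub is the hook that ties the two halves together: Flux owns the generic function, the extension contributes its only method, and `gpu` uses `hasmethod` to give a readable error when the extension is absent. A minimal, self-contained sketch of that pattern (module and function names here are illustrative, not Flux's API):

```julia
module Parent
# Stub with no methods; an extension is expected to add one.
function _accelerate end

function accelerate(x)
  hasmethod(_accelerate, Tuple{Any}) || error("backend not loaded")
  return _accelerate(x)
end
end

module BackendExt   # stands in for a package extension
import ..Parent
Parent._accelerate(x) = x .* 2   # the extension supplies the implementation
end

Parent.accelerate([1, 2, 3])     # works once BackendExt is loaded: returns [2, 4, 6]
```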
7 changes: 2 additions & 5 deletions src/layers/normalise.jl
@@ -36,9 +36,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
realfptype = float(real(eltype(x)))
@@ -56,9 +53,9 @@ ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
Dropout layer.
While training, for each input, this layer either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
training.
In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
1 change: 0 additions & 1 deletion src/losses/Losses.jl
@@ -5,7 +5,6 @@ using Zygote
using Zygote: @adjoint
using ChainRulesCore
using ..Flux: ofeltype, epseltype
using CUDA
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
import Base.Broadcast: broadcasted

5 changes: 2 additions & 3 deletions src/utils.jl
@@ -44,7 +44,6 @@ The current defaults are:
- Julia version is >= 1.7: `Random.default_rng()`
"""
rng_from_array(::AbstractArray) = default_rng_value()
rng_from_array(::CuArray) = CUDA.default_rng()

@non_differentiable rng_from_array(::Any)

@@ -226,7 +225,7 @@ ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
"""
truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.
@@ -393,7 +392,7 @@ Has the following behaviour
* 2D: An identity matrix (useful for an identity matrix multiplication)
* More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)
Some caveats:
* Not all layers will be identity mapping when used with this init. Exceptions
include recurrent layers and normalization layers.
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -6,7 +6,10 @@ using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
using Zygote

# Both required to trigger CUDAExt
using CUDA
using NNlibCUDA

Random.seed!(0)
