-
-
Notifications
You must be signed in to change notification settings - Fork 604
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move CUDA support to a package extension
- Loading branch information
1 parent
c850df5
commit bf3d95f
Showing
13 changed files
with
106 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
module CUDAExt | ||
|
||
using CUDA | ||
import NNlib, NNlibCUDA | ||
|
||
using Flux | ||
import Flux: adapt_storage, _gpu, FluxCPUAdaptor | ||
|
||
using Adapt | ||
using ChainRulesCore | ||
using Random | ||
using Zygote | ||
|
||
const use_cuda = Ref{Union{Nothing,Bool}}(nothing) | ||
|
||
include("utils.jl") | ||
include("functor.jl") | ||
|
||
include("layers/normalise.jl") | ||
|
||
include("cudnn.jl") | ||
|
||
end # module |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
struct FluxCUDAAdaptor end | ||
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) | ||
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) | ||
if VERSION >= v"1.7" | ||
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() | ||
else | ||
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng() | ||
end | ||
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x | ||
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = | ||
error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().") | ||
|
||
# TODO: figure out the correct design for OneElement | ||
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) | ||
|
||
|
||
adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) | ||
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng() | ||
|
||
function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray) | ||
Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),) | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray) | ||
adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),) | ||
end | ||
|
||
|
||
function _gpu(x) | ||
check_use_cuda() | ||
use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x | ||
end | ||
|
||
function check_use_cuda() | ||
if use_cuda[] === nothing | ||
use_cuda[] = CUDA.functional() | ||
if use_cuda[] && !CUDA.has_cudnn() | ||
@warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." | ||
end | ||
if !(use_cuda[]) | ||
@info """The GPU function is being called but the GPU is not accessible. | ||
Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1 | ||
end | ||
end | ||
end | ||
ChainRulesCore.@non_differentiable check_use_cuda() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) | ||
dropout_mask(rng, x::CuArray, p; kwargs...) = | ||
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays.")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
rng_from_array(::CuArray) = CUDA.default_rng() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters