diff --git a/src/functor.jl b/src/functor.jl index 7e4d552753..0c254bb1e1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -341,6 +341,7 @@ function gpu(::FluxCUDAAdaptor, x) `CUDA.jl` must be loaded to access it. Add `using CUDA` or `import CUDA` to your code. """ maxlog=1 + return x end end @@ -361,6 +362,7 @@ function gpu(::FluxAMDAdaptor, x) `AMDGPU.jl` must be loaded to access it. Add `using AMDGPU` or `import AMDGPU` to your code. """ maxlog=1 + return x end end @@ -380,6 +382,7 @@ function gpu(::FluxMetalAdaptor, x) The Metal functionality is being called but `Metal.jl` must be loaded to access it. """ maxlog=1 + return x end end diff --git a/test/functors.jl b/test/functors.jl new file mode 100644 index 0000000000..9919567b56 --- /dev/null +++ b/test/functors.jl @@ -0,0 +1,4 @@ +x = rand(Float32, 10, 10) +if !(Flux.CUDA_LOADED[] || Flux.AMD_LOADED[] || Flux.METAL_LOADED[]) + @test x === gpu(x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 50c9ea8b01..90bafb67ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,9 @@ Random.seed!(0) include("outputsize.jl") end + @testset "functors" begin + include("functors.jl") + end if get(ENV, "FLUX_TEST_CUDA", "false") == "true" using CUDA