When running a GPU-accelerated neural network model with Lux.jl, LuxCUDA.jl, and Zygote.jl, the program throws an ArgumentError as soon as Zygote.gradient is called on a GPU input. The error reports a device mismatch, specifically an incompatibility between CPUDevice and CUDADevice.
The error arises when differentiating a Chain with GPU-resident parameters via Zygote; the forward pass itself runs without issue.
Reproducible Example:
using Lux
using LuxCUDA
using CUDA
using ComponentArrays
using Random
using Zygote
# Setup
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)
# Neural network definition
nn = Chain(
    Dense(3, 20, σ),
    Dense(20, 10, σ),
    Dense(10, 1, tanh)
)
# Initialize parameters
parameters, layer_states = Lux.setup(rng, nn)
gpu_parameters = parameters |> ComponentArray |> gpud
# GPU function
gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]
# Data points
points = rand(rng, Float32, 3, 10)
gpu_points = CuArray(points)
# Gradient computation
CUDA.allowscalar() do
    for kk in axes(gpu_points, 2)
        r = gpu_points[:, kk]
        φ = gpu_NN(r)[1]
        ∇φ = Zygote.gradient(s -> gpu_NN(s)[1], r)[1] # Fails here
    end
end
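Judging from the stacktrace below, the cotangent that reaches LuxLib's dense pullback is a Base.ReshapedArray wrapping a ChainRules.OneElement, i.e. a lazy CPU array produced by the scalar indexing gpu_NN(s)[1] inside the differentiated closure, while the layer output is a CuArray. A possible workaround (only a sketch, I have not verified it avoids every affected code path) is to replace the indexing with a reduction; the network output is a 1-element array, so sum returns the same scalar:

# Possible workaround (untested sketch): avoid scalar indexing inside the
# differentiated closure, so Zygote does not produce a OneElement cotangent.
# The network returns a 1-element array, so `sum` yields the same value.
∇φ = Zygote.gradient(s -> sum(gpu_NN(s)), r)[1]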
Error Output:
ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
[1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
@ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:127
[2] macro expansion
@ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:205 [inlined]
[3] unrolled_mapreduce
@ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:192 [inlined]
[4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
@ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:183
[5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
@ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:155
[6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
@ MLDataDevices C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\public.jl:388
[7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}})
@ LuxLib C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\traits.jl:210
[8] ∇activation(Δ::Base.ReshapedArray{…}, out::CuArray{…}, act::typeof(tanh_fast), x::LuxLib.Utils.NotaNumber)
@ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\activation.jl:107
[9] (::LuxLib.Impl.var"#78#81"{…})(Δ::Base.ReshapedArray{…})
@ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:51
[10] ZBack
@ C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:212 [inlined]
[11] fused_dense
@ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:11 [inlined]
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[13] fused_dense_bias_activation
@ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\api\dense.jl:35 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[15] Dense
@ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\basic.jl:343 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[17] apply
@ C:\Users\aligu\.julia\packages\LuxCore\SN4dl\src\LuxCore.jl:155 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[19] applychain
@ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:0 [inlined]
[20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[21] Chain
@ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:480 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[23] gpu_NN
@ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:162 [inlined]
[24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRules.OneElement{Float32, 1, Tuple{…}, Tuple{…}})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[25] #16
@ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173 [inlined]
[26] (::Zygote.Pullback{Tuple{var"#16#18", CuArray{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}})(Δ::Float32)
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[27] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:91
[28] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
@ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:148
[29] (::var"#15#17")()
@ Main d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173
[30] task_local_storage(body::var"#15#17", key::Symbol, val::GPUArraysCore.ScalarIndexing)
@ Base .\task.jl:297
[31] allowscalar(f::Function)
@ GPUArraysCore C:\Users\aligu\.julia\packages\GPUArraysCore\GMsgk\src\GPUArraysCore.jl:183
[32] top-level scope
@ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:169
Some type information was truncated. Use `show(err)` to see complete types.
ExceptionStack output:
ExceptionStack.txt
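If I read the trace correctly, the mismatch is raised inside MLDataDevices: internal_operation_mode hands the CPU-resident ReshapedArray/OneElement cotangent together with the CuArray activation output to get_device_type, and combine_devices refuses to merge CPUDevice with CUDADevice. Assuming that reading is right, the device check itself should be reproducible without Lux by passing any mixed CPU/GPU tuple (a hypothetical minimal trigger, not verified here):

using MLDataDevices, CUDA
# Hypothetical minimal trigger (assumption, not verified): a plain CPU array
# paired with a CuArray should hit the same
# combine_devices(CPUDevice, CUDADevice) error path.
MLDataDevices.get_device_type((zeros(Float32, 3), CUDA.zeros(Float32, 3)))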
Environment:
julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 48 × AMD Ryzen Threadripper PRO 5965WX 24-Cores
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 48 default, 0 interactive, 24 GC (on 48 virtual cores)
Environment:
JULIA_EDITOR = code
JULIA_NUM_THREADS = 48
Package versions:
Lux v1.2.3
LuxCUDA v0.3.3
CUDA v5.5.2
Zygote v0.6.73
ComponentArrays v0.15.17
julia> CUDA.versioninfo()
CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 556.18.0
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+556.18
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7
2 devices:
0: NVIDIA RTX A4000 (sm_86, 11.093 GiB / 15.992 GiB available)
1: NVIDIA RTX A4000 (sm_86, 11.094 GiB / 15.992 GiB available)