
ArgumentError in Lux.jl with GPU-Accelerated Neural Network Using LuxCUDA and Zygote #1094

@aligurbu

Description


When running a GPU-accelerated neural network with Lux.jl, LuxCUDA.jl, and Zygote.jl, Zygote.gradient throws an ArgumentError when computing gradients with respect to GPU inputs. The error reports a device mismatch between CPUDevice and CUDADevice.

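The error arises when a Chain with GPU parameters is differentiated with respect to a GPU input vector. The stack trace points to the pullback of the last Dense layer. The incoming cotangent Δ is a Base.ReshapedArray wrapping a ChainRules.OneElement (frames [12] and [24]), which comes from the scalar indexing gpu_NN(s)[1]. MLDataDevices' get_device_type classifies this wrapper as CPUDevice while the layer output is a CuArray, so the device check in LuxLib's internal_operation_mode fails (frames [1] to [8]).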

Reproducible Example:

using Lux
using LuxCUDA
using CUDA
using ComponentArrays
using Random
using Zygote

# Setup
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

# Neural network definition
nn = Chain(
    Dense(3, 20, σ),
    Dense(20, 10, σ),
    Dense(10, 1, tanh)
)

# Initialize parameters
parameters, layer_states = Lux.setup(rng, nn)
gpu_parameters = parameters |> ComponentArray |> gpud

# GPU function
gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]

# Data points
points = rand(rng, Float32, 3, 10)
gpu_points = CuArray(points)

# Gradient computation
CUDA.allowscalar() do
    for kk in axes(gpu_points, 2)
        r = gpu_points[:, kk]
        φ = gpu_NN(r)[1]
        ∇φ = Zygote.gradient(s -> gpu_NN(s)[1], r)[1] # Fails here
    end
end

Error Output:

ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
  [1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:127
  [2] macro expansion
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:205 [inlined]
  [3] unrolled_mapreduce
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:192 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:183
  [5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:155
  [6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\public.jl:388
  [7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ LuxLib C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\traits.jl:210
  [8] ∇activation(Δ::Base.ReshapedArray{…}, out::CuArray{…}, act::typeof(tanh_fast), x::LuxLib.Utils.NotaNumber)
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\activation.jl:107
  [9] (::LuxLib.Impl.var"#78#81"{…})(Δ::Base.ReshapedArray{…})
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:51  
 [10] ZBack
    @ C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:212 [inlined]
 [11] fused_dense
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:11 [inlined]    
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [13] fused_dense_bias_activation
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\api\dense.jl:35 [inlined]     
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [15] Dense
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\basic.jl:343 [inlined]    
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [17] apply
    @ C:\Users\aligu\.julia\packages\LuxCore\SN4dl\src\LuxCore.jl:155 [inlined]     
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [19] applychain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:0 [inlined] 
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [21] Chain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:480 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [23] gpu_NN
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:162 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRules.OneElement{Float32, 1, Tuple{…}, Tuple{…}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [25] #16
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173 [inlined]
 [26] (::Zygote.Pullback{Tuple{var"#16#18", CuArray{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:148
 [29] (::var"#15#17")()
    @ Main d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173
 [30] task_local_storage(body::var"#15#17", key::Symbol, val::GPUArraysCore.ScalarIndexing)
    @ Base .\task.jl:297
 [31] allowscalar(f::Function)
    @ GPUArraysCore C:\Users\aligu\.julia\packages\GPUArraysCore\GMsgk\src\GPUArraysCore.jl:183
 [32] top-level scope
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:169    
Some type information was truncated. Use `show(err)` to see complete types.

ExceptionStack output:
ExceptionStack.txt
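A possible workaround, offered as a sketch rather than a confirmed fix. The trace shows the CPU-classified cotangent originating from the scalar indexing gpu_NN(s)[1]. Reducing the network output with sum instead should give Zygote a device-array cotangent. This assumes Zygote's sum rule for GPU arrays, which builds its cotangent with similar on the input. Evaluating all points as one 3×N batch also removes the need for CUDA.allowscalar. The sketch has not been verified against the exact versions listed in the Environment section. The names net, ps_dev, st_dev and x are illustrative, and the parameters are kept as plain NamedTuples instead of a ComponentArray for brevity.

```julia
# Workaround sketch (assumption: not verified on Lux v1.2.3 / Zygote v0.6.73).
using Lux, LuxCUDA, Random, Zygote

# gpu_device() falls back to CPUDevice (with a warning) if no functional GPU is found.
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

# Same architecture as in the reproduction above.
nn = Chain(
    Dense(3, 20, σ),
    Dense(20, 10, σ),
    Dense(10, 1, tanh)
)

ps, st = Lux.setup(rng, nn)
ps_dev = ps |> gpud   # hypothetical name: parameters moved to the device
st_dev = st |> gpud   # hypothetical name: (empty) layer states moved to the device

# Hypothetical helper: network output for a 3×N batch, returned as a 1×N array.
net(x) = first(nn(x, ps_dev, st_dev))

points = rand(rng, Float32, 3, 10)
x = points |> gpud

# φ at all points at once (1×N), no scalar indexing needed.
φ = net(x)

# sum(...) instead of [1]: the cotangent fed back into the network is a device
# array rather than a ChainRules.OneElement. Each output column depends only on
# the matching input column, so column k of ∇φ is the gradient at points[:, k].
∇φ = only(Zygote.gradient(s -> sum(net(s)), x))
```

Because every Dense layer acts on each column independently, column k of ∇φ matches what the per-point loop in the reproduction computes for r = points[:, k]. That equivalence would no longer hold for layers that mix information across columns, such as BatchNorm in training mode.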

Environment:

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 48 × AMD Ryzen Threadripper PRO 5965WX 24-Cores
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 48 default, 0 interactive, 24 GC (on 48 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 48

Lux v1.2.3
LuxCUDA v0.3.3
CUDA v5.5.2
Zygote v0.6.73
ComponentArrays v0.15.17

julia> CUDA.versioninfo()
CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 556.18.0

CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+556.18

Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

2 devices:
  0: NVIDIA RTX A4000 (sm_86, 11.093 GiB / 15.992 GiB available)
  1: NVIDIA RTX A4000 (sm_86, 11.094 GiB / 15.992 GiB available)
