Open
Description
Using `leakyrelu` causes a GPU compilation error when differentiating through the following gradient penalty loss. The same code works on the CPU, and also works on the GPU with other activations such as `elu` or `relu`.
using Flux, Zygote, CUDA
"""
    gradient_penalty(m, x)

Return the sum of squared entries of the gradient of `sum(m(x))` with respect to the
input `x`.

The inner gradient is taken with the implicit-parameter API (`params(x)`), so this
function can itself be differentiated with respect to the parameters of `m`
(nested / second-order differentiation).

If the inner pullback yields no gradient for `x` (`nothing`), the penalty is zero.
"""
function gradient_penalty(m, x)
    y, back = Flux.pullback(() -> sum(m(x)), params(x))
    # Seed with one(y) instead of a hard-coded 1.0f0 so the seed matches the loss type
    # (a Float32 loss still gets exactly 1.0f0; a Float64 loss gets 1.0).
    grads = back(one(y))[x]
    # Zygote's implicit-parameter Grads hold `nothing` for a parameter that received no
    # gradient; `nothing .^ 2` would throw, so treat that case as a zero penalty.
    isnothing(grads) && return zero(eltype(x))
    return sum(grads .^ 2)
end
# Reproduction: a Dense(1, 1) layer followed by a broadcasted activation is passed
# through `gradient_penalty`, and the penalty is differentiated w.r.t. the model params.
x = randn(Float32, 1, 4) # dims, batch
# CPU model with leakyrelu (negative slope 0.2f0): differentiating the penalty works.
m₁ = Chain(Dense(1, 1), x -> leakyrelu.(x, 0.2f0))
l, b = Flux.pullback(() -> gradient_penalty(m₁, x), params(m₁)) # Ok
# The same input, moved to the GPU.
cx = x |> gpu
# GPU model using elu instead of leakyrelu: differentiating the penalty works.
cm₂ = Chain(Dense(1, 1), x -> elu.(x)) |> gpu
l, b = Flux.pullback(() -> gradient_penalty(cm₂, cx), params(cm₂)) # Ok
# GPU model using leakyrelu: this call fails with a KernelError ("passing and using
# non-bitstype argument"). Per the trace below, the failing broadcast maps a closure that
# captures a Zygote.Context over ForwardDiff.Dual values, and that closure is not isbits.
cm₁ = Chain(Dense(1, 1), x -> leakyrelu.(x, 0.2f0)) |> gpu
l, b = Flux.pullback(() -> gradient_penalty(cm₁, cx), params(cm₁)) # Fails to compile
The last call (GPU model with `leakyrelu`) throws:
ERROR: LoadError: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float32, typeof(∂(#1122))}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
.f is of type Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"} which is not isbits.
.cx is of type Zygote.Context which is not isbits.
.cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/validation.jl:66
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/1Ajz2/src/driver.jl:325 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
[4] macro expansion
@ ~/.julia/packages/GPUCompiler/1Ajz2/src/driver.jl:324 [inlined]
[5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
@ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/utils.jl:64
[6] cufunction_compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:326
[7] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/cache.jl:90
[8] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float32, typeof(∂(#1122))}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ CUDA ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:297
[9] cufunction
@ ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:291 [inlined]
[10] macro expansion
@ ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:102 [inlined]
[11] #launch_heuristic#270
@ ~/.julia/packages/CUDA/bki2w/src/gpuarrays.jl:17 [inlined]
[12] copyto!
@ ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:65 [inlined]
[13] copyto!
@ ./broadcast.jl:913 [inlined]
[14] copy
@ ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:47 [inlined]
[15] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer}}})
@ Base.Broadcast ./broadcast.jl:860
[16] map(::Function, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
@ GPUArrays ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:90
[17] ∇map(cx::Zygote.Context, f::Zygote.var"#1122#1126", args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:197
[18] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:223 [inlined]
[19] _pullback(__context__::Zygote.Context, 541::typeof(map), f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
[20] _pullback
@ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:241 [inlined]
[21] _pullback(::Zygote.Context, ::typeof(Zygote.broadcast_forward), ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[22] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[23] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
[24] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[25] _pullback
@ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:265 [inlined]
[26] _pullback(::Zygote.Context, ::typeof(ZygoteRules.adjoint), ::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{2}, ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[27] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[28] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
[29] _pullback (repeats 2 times)
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[30] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{2}, ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[31] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[32] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:189 [inlined]
[33] _pullback(::Zygote.Context, ::typeof(Core._apply), ::Function, ::Tuple{Zygote.Context, typeof(Base.Broadcast.broadcasted)}, ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
[34] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[35] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
[36] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[37] _pullback
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
[38] _pullback(::Zygote.Context, ::typeof(ZygoteRules.adjoint), ::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(Base.Broadcast.broadcasted), ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[39] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[40] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
[41] _pullback (repeats 2 times)
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[42] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(Base.Broadcast.broadcasted), ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[43] _pullback
@ ./broadcast.jl:1303 [inlined]
[44] _pullback
@ ~/ws/msc/scratch/gpsmaller.jl:13 [inlined]
[45] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#15#16", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[46] _pullback
@ ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:47 [inlined]
--- the last 2 lines are repeated 1 more time ---
[49] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[50] _pullback
@ ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:49 [inlined]
[51] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[52] _pullback
@ ~/ws/msc/scratch/gpsmaller.jl:6 [inlined]
[53] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#13#14"{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[54] _pullback
@ ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:352 [inlined]
[55] _pullback(::Zygote.Context, ::typeof(pullback), ::var"#13#14"{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Params)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[56] _pullback
@ ~/ws/msc/scratch/gpsmaller.jl:6 [inlined]
[57] _pullback(::Zygote.Context, ::typeof(gradient_penalty), ::Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[58] _pullback
@ ~/ws/msc/scratch/gpsmaller.jl:21 [inlined]
[59] _pullback(::Zygote.Context, ::var"#23#24")
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
[60] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:352
in expression starting at /home/vincent/ws/msc/scratch/gpsmaller.jl:21