Open
Description
Hi, I've been using a UNet-like architecture that accepts different encoders.
And when passing EfficientNet as an encoder (which contains BatchNorm in its MBConv blocks), it crashes during gradient computation, but only when running on the GPU.
Not 100% sure the issue is with this library, but here's the stacktrace:
ERROR: LoadError: MethodError: no method matching
∇batchnorm(::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=0.001f0, training=true)
Closest candidates are:
∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, 2}, ::CuArray{T, 2}, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
[1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
[2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
[3] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
[4] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[5] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
[6] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
[7] (::typeof(∂(applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
[9] (::typeof(∂(applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
[11] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[12] Pullback
@ ./operators.jl:858 [inlined]
[13] (::typeof(∂(|>)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[14] Pullback
@ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:109 [inlined]
[15] (::typeof(∂(#_#7)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:100 [inlined]
[17] (::typeof(∂(Any##kw)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[18] Pullback
@ ~/.julia/packages/EfficientNet/NKvyu/src/model.jl:125 [inlined]
[19] (::typeof(∂(λ)))(Δ::Vector{Union{Nothing, CuArray{Float32, 4}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[20] Pullback
@ ~/projects/Segmentation.jl/src/Segmentation.jl:27 [inlined]
[21] (::typeof(∂(λ)))(Δ::CuArray{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[22] Pullback
@ ./operators.jl:858 [inlined]
[23] (::typeof(∂(|>)))(Δ::CuArray{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[24] Pullback
@ ~/projects/Segmentation.jl/example/comma.jl:185 [inlined]
[25] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[26] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
[27] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
[28] test_grads()
@ Main ~/projects/Segmentation.jl/example/comma.jl:184
[29] top-level scope
@ ~/projects/Segmentation.jl/example/comma.jl:199
in expression starting at /home/pxl-th/projects/Segmentation.jl/example/comma.jl:199