Skip to content

BatchNorm causes error during gradient computation #514

Open
@pxl-th

Description

@pxl-th

Hi, I've been using a UNet-like architecture that accepts different encoders.
When I pass EfficientNet as the encoder (it contains BatchNorm in its MBConv blocks), the gradient computation crashes, but only on the GPU.

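I'm not 100% sure the issue is in this library, but here's the stack trace. Note that the gradient `Δ` reaching the BatchNorm pullback is a `FillArrays.Ones{Float32, 4}` rather than a `CuArray`. As a result, none of the `∇batchnorm` methods listed as candidates match, because they all require `CuArray` arguments: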

ERROR: LoadError: MethodError: no method matching
  ∇batchnorm(::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=0.001f0, training=true)
Closest candidates are:
  ∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
  ∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, 2}, ::CuArray{T, 2}, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
  [1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
  [2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
  [3] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
  [4] (::typeof((λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
  [6] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
  [7] (::typeof((applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
  [9] (::typeof((applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
 [11] (::typeof((λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./operators.jl:858 [inlined]
 [13] (::typeof((|>)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:109 [inlined]
 [15] (::typeof((#_#7)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:100 [inlined]
 [17] (::typeof((Any##kw)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/model.jl:125 [inlined]
 [19] (::typeof((λ)))(Δ::Vector{Union{Nothing, CuArray{Float32, 4}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/projects/Segmentation.jl/src/Segmentation.jl:27 [inlined]
 [21] (::typeof((λ)))(Δ::CuArray{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./operators.jl:858 [inlined]
 [23] (::typeof((|>)))(Δ::CuArray{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/projects/Segmentation.jl/example/comma.jl:185 [inlined]
 [25] (::typeof((λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#90#91"{Zygote.Params, typeof((λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [27] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [28] test_grads()
    @ Main ~/projects/Segmentation.jl/example/comma.jl:184
 [29] top-level scope
    @ ~/projects/Segmentation.jl/example/comma.jl:199
in expression starting at /home/pxl-th/projects/Segmentation.jl/example/comma.jl:199

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions