
No method matching BatchNorm when loading model  #37

Open
@tclements

I am trying to load the ArcFace model from ONNX.

Package versions:

```
(v1.3) pkg> st
    Status `~/.julia/environments/v1.3/Project.toml`
  [c5f51814] CUDAdrv v6.0.0
  [be33ccc6] CUDAnative v2.10.2
  [3a865a2d] CuArrays v1.7.2
  [587475ba] Flux v0.10.1
  [d0dd6a25] ONNX v0.1.1
```

The beginning of the generated `model.jl` looks like this:

```julia
using Statistics
Mul(a,b,c) = b .* reshape(c, (1,1,size(c)[a],1))
Add(axis, A ,B) = A .+ reshape(B, (1,1,size(B)[1],1))
begin
    c_1 = BatchNorm(identity, weights["fc1_beta"], weights["fc1_gamma"], broadcast(Float32, weights["fc1_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["fc1_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_2 = BatchNorm(identity, weights["bn1_beta"], weights["bn1_gamma"], broadcast(Float32, weights["bn1_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["bn1_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_3 = BatchNorm(identity, weights["stage4_unit3_bn3_beta"], weights["stage4_unit3_bn3_gamma"], broadcast(Float32, weights["stage4_unit3_bn3_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["stage4_unit3_bn3_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_4 = CrossCor(weights["stage4_unit3_conv2_weight"], Float32[0.0], relu, var"stride=(1, 1)", var"pad=(1, 1, 1, 1)", var"dilation=(1, 1)")
```
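The two helpers at the top of `model.jl` broadcast a per-channel vector across a 4-D array in Flux's WHCN layout by reshaping it to size `(1, 1, C, 1)`. Below is a small standalone check of that behaviour. The input sizes and values are made up, and `a = 1` is an assumption: for a vector `c`, `size(c)[a]` only works when `a == 1`, and the actual call sites in the generated file are not shown above. Note that `Add` ignores its `axis` argument entirely.

```julia
# Helpers copied from the generated model.jl (spacing tidied).
Mul(a, b, c) = b .* reshape(c, (1, 1, size(c)[a], 1))
Add(axis, A, B) = A .+ reshape(B, (1, 1, size(B)[1], 1))   # `axis` is unused

# Toy activation in WHCN layout: width 2, height 2, 3 channels, batch size 1.
x = ones(Float32, 2, 2, 3, 1)
scale = Float32[1, 2, 3]      # hypothetical per-channel scale
shift = Float32[10, 20, 30]   # hypothetical per-channel shift

# Channel k ends up as 1 * scale[k] + shift[k].
y = Add(1, Mul(1, x, scale), shift)

println(size(y))   # (2, 2, 3, 1)
```

Channel 2 of `y` is therefore filled with `1 * 2 + 20 = 22`, which matches a per-channel affine transform applied after normalisation.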

When I try to load it with

```julia
using Flux, ONNX
weights = ONNX.load_weights("weights.bson")
model = include("model.jl")
```

I get an error from the `BatchNorm` call:

```
model = include("model.jl")
ERROR: LoadError: MethodError: no method matching BatchNorm(::typeof(identity), ::Base.ReinterpretArray{Float32,1,Float32,Array{Float32,1}}, ::Base.ReinterpretArray{Float32,1,Float32,Array{Float32,1}}, ::Array{Float32,1}, ::Array{Float32,1}, ::Float32, ::Float32, ::Bool)
Closest candidates are:
  BatchNorm(::F, ::V, ::V, ::W, ::W, ::N, ::N) where {F, V, W, N} at /home/timclements/.julia/packages/Flux/2i5P1/src/layers/normalise.jl:123
  BatchNorm(::Integer, ::Any; initβ, initγ, ϵ, momentum) at /home/timclements/.julia/packages/Flux/2i5P1/src/layers/normalise.jl:133
```

This looks very similar to #17, but I am using ONNX v0.1.1, which should already include that fix. The error is thrown from Flux (`normalise.jl`), so the ONNX-specific `BatchNorm` method that takes `identity` as its first argument is not being called.
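The closest Flux candidate in the error takes seven positional arguments, while the generated call passes eight, ending in `false`. The sketch below reproduces just that arity mismatch without Flux, using a hypothetical `FakeBatchNorm` struct that mirrors the `BatchNorm(::F, ::V, ::V, ::W, ::W, ::N, ::N)` signature. The field names are assumptions, and the sketch says nothing about what #17 changed or which method ONNX.jl is meant to define. It only shows that argument types following the same pattern bind fine with seven arguments, and that appending a `Bool` yields a `MethodError`.

```julia
# Hypothetical stand-in mirroring the 7-argument candidate from the error
# (Flux v0.10.1, normalise.jl:123). Field names are assumptions; this is
# NOT Flux's type, only a struct with the same parametric layout.
mutable struct FakeBatchNorm{F,V,W,N}
    λ::F          # activation function
    β::V          # shift; must share type V with γ
    γ::V          # scale
    μ::W          # running mean; must share type W with σ²
    σ²::W         # running variance
    ϵ::N
    momentum::N
end

# Arguments following the same type pattern as the failing call:
# β and γ are ReinterpretArrays, μ and σ² are plain Vector{Float32}.
β  = reinterpret(Float32, zeros(UInt32, 4))      # bit pattern of 0.0f0
γ  = reinterpret(Float32, fill(0x3f800000, 4))   # bit pattern of 1.0f0
μ  = zeros(Float32, 4)
σ² = ones(Float32, 4)

# Seven arguments: F, V, W and N all bind, so the default constructor matches.
ok = FakeBatchNorm(identity, β, γ, μ, σ², 2.0f-5, 0.9f0)

# Eight arguments, with the trailing Bool seen in the generated model.jl.
argtypes = Tuple{typeof(identity), typeof(β), typeof(γ), typeof(μ),
                 typeof(σ²), Float32, Float32, Bool}
println(hasmethod(FakeBatchNorm, argtypes))   # false

err = try
    FakeBatchNorm(identity, β, γ, μ, σ², 2.0f-5, 0.9f0, false)
    nothing
catch e
    e
end
println(err isa MethodError)                  # true
```

In a live session, `methods(BatchNorm)` lists every `BatchNorm` method that is actually visible, which is a quick way to check whether an eight-argument overload from ONNX.jl was loaded at all.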
