
Hessian vector products with moderately complex models #1813

Open

colinxs opened this issue Dec 4, 2019 · 2 comments

colinxs commented Dec 4, 2019

I'm attempting to compute Hessian-vector products (HVPs) for use with RL algorithms like Natural Policy Gradient and TRPO, but so far without success.

Following FluxML/Zygote.jl#115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and discussions elsewhere, I was able to compute HVPs for simple models parameterized by a single Array, but the code below appears to have trouble inferring the element type of the Dual seeds.

Any help would be greatly appreciated! :)
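
For reference, one way to set up an HVP when the model is a single parameter Array (a sketch of the Dual-seeding approach, similar in spirit to SparseDiffTools; simple_hvp, f, x, and v are stand-in names, not code from the references above):

using ForwardDiff, Zygote

# H*v = d/dε ∇f(x + ε·v) at ε = 0: seed x with Duals whose single partial is v,
# take the reverse-mode gradient at the dual point, then read off the partials.
function simple_hvp(f, x::AbstractVector, v::AbstractVector)
    xd = ForwardDiff.Dual{Nothing}.(x, v)   # x + ε·v
    gd = first(Zygote.gradient(f, xd))      # ∇f(x + ε·v), Dual-valued
    ForwardDiff.partials.(gd, 1)            # the ε-component, i.e. H*v
end

# e.g. for f(x) = sum(abs2, x) / 2 the Hessian is the identity, so simple_hvp(f, x, v) ≈ v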

# Zygote v0.4.1, Flux v0.10.0, ForwardDiff v0.10.7, DiffRules 0.1.0, ZygoteRules 0.2.0

# Julia Version 1.3.0
# Commit 46ce4d7933 (2019-11-26 06:09 UTC)
# Platform Info:
#   OS: Linux (x86_64-pc-linux-gnu)
#   CPU: Intel(R) Core(TM) i9-7960X CPU @ 2.80GHz
#   WORD_SIZE: 64
#   LIBM: libopenlibm
#   LLVM: libLLVM-6.0.1 (ORCJIT, skylake)

using Flux, ForwardDiff, Zygote
using LinearAlgebra

# A Gaussian policy with diagonal covariance
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M
    logstd::L
end

Flux.@functor DiagGaussianPolicy

(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    ll = -length(P.logstd) * log(2pi) / 2
    for i = 1:length(action)
        ll -= ((meanact[i] - action[i]) / exp(P.logstd[i]))^2 / 2
        ll -= P.logstd[i]
    end
    ll
end

# Flatten the gradient of f with respect to ps into a single vector
function flatgrad(f, ps)
    gs = Zygote.gradient(f, ps)
    vcat([vec(gs[p]) for p in ps]...)
end

# Type piracy so Zygote.forward_jacobian can treat Params like a flat vector
Base.length(ps::Params) = 228 #sum(length, ps)
Base.size(ps::Params) = (228, ) #(length(ps), )
Base.eltype(ps::Params) = Float32

# Forward-over-reverse: H*v is the forward-mode derivative of ps -> ∇f(ps)⋅v
function hessian_vector_product(f, ps, v)
    g = let f=f
        ps -> flatgrad(f, ps)::Vector{Float32}
    end
    gvp = let g=g, v=v
        ps -> dot(g(ps), v)
    end
    Zygote.forward_jacobian(gvp, ps)[2]
end

function test()
    policy = Flux.paramtype(Float32, DiagGaussianPolicy(Flux.Chain(Dense(4, 32), Dense(32, 2)), zeros(2)))
    ps = Flux.params(policy)
    v = rand(Float32, sum(length, ps))
    feat = rand(Float32, 4)
    act = rand(Float32, 2)
    f = let policy=policy, feat=feat, act=act
        () -> loglikelihood(policy, feat, act)
    end
    hessian_vector_product(f, ps, v)
end

Calling test() yields:

ERROR: Cannot create a dual over scalar type Array{Float32,2}. If the type behaves as a scalar, define ForwardDiff.can_dual.
Stacktrace:
 [1] throw_cannot_dual(::Type) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:36
 [2] ForwardDiff.Dual{Nothing,Any,12}(::Array{Float32,2}, ::ForwardDiff.Partials{12,Any}) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:18
 [3] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:55 [inlined]
 [4] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:62 [inlined]
 [5] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:68 [inlined]
 [6] (::Zygote.var"#1565#1567"{12,Int64})(::Array{Float32,2}, ::Int64) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:8
 [7] (::Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}})(::Tuple{Array{Float32,2},Int64}) at ./generator.jl:36
 [8] iterate at ./generator.jl:47 [inlined]
 [9] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Params,UnitRange{Int64}}},Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}}}) at ./array.jl:622
 [10] map at ./abstractarray.jl:2155 [inlined]
 [11] seed at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:7 [inlined] (repeats 2 times)
 [12] forward_jacobian(::var"#340#342"{var"#339#341"{var"#343#344"{DiagGaussianPolicy{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1}},Array{Float32,1},Array{Float32,1}}},Array{Float32,1}}, ::Params, ::Val{12}) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:23
 [13] forward_jacobian(::Function, ::Params) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:40
 [14] hessian_vector_product(::Function, ::Params, ::Array{Float32,1}) at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:52
 [15] test() at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:64
 [16] top-level scope at REPL[41]:1

colinxs commented Dec 6, 2019

With some more fiddling, and after realizing that ForwardDiff only works with AbstractArray inputs, I was able to get the above working (not yet checked for correctness) through a combination of RecursiveArrayTools and Flux.fmap. I also had to rewrite the loglikelihood to avoid the .^2 broadcast, which appears to be related to FluxML/Zygote.jl#405.

The solution itself isn't too ugly :).

I was also able to make this work without the custom seeding, by explicitly calculating the vector product instead, but that solution is messier and slower so I won't bother posting it.

I'll circle back once I've cleaned things up a bit and verified that the result is correct, but until then, if anyone has suggestions do let me know! This is a fairly critical capability for anyone doing research in ML, RL, etc.

using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward

# A Gaussian policy with diagonal covariance
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M
    logstd::L
end

Flux.@functor DiagGaussianPolicy

(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    # broken (possibly related to https://github.com/FluxML/Zygote.jl/issues/405)
    #zs = ((meanact .- action) ./ exp.(P.logstd)) .^ 2
    # works
    zs = (meanact .- action) ./ exp.(P.logstd)
    zs = zs .* zs

    ll = -sum(zs)/2 - sum(P.logstd) - length(P.logstd) * log(2pi) / 2
    ll
end

flatgrad(gs::Grads, ps::Params) = ArrayPartition((gs[p] for p in ps if !isnothing(gs[p]))...)

# HVP via "reverse over forward": seed each trainable array with Duals whose
# partials are the matching block of `vs`, take the reverse-mode gradient of
# the loglikelihood at the dual point, then read off the partials.
function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
    ps = Flux.params(policy)

    # rebuild the model with Dual-valued parameters (one partial per parameter)
    i = 1
    dualpol = Flux.fmap(policy) do p
        if p in ps.params
            p = ForwardDiff.Dual{Nothing}.(p, vs.x[i])
            i += 1
        end
        p
    end
    dualps = params(dualpol)

    G = let feat=feat, act=act
        function (ps)
            gs = gradient(() -> loglikelihood(dualpol, feat, act), ps)
            flatgrad(gs, ps)
        end
    end

    ForwardDiff.partials.(G(dualps), 1)
end


function test_flathvp(T::DataType=Float32)
    Random.seed!(1)

    dobs, dact = 4, 2
    policy = DiagGaussianPolicy(Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact)), zeros(dact))
    policy = Flux.paramtype(T, policy)

    v = ArrayPartition((rand(T, size(p)...) for p in params(policy))...)
    feat = rand(T, dobs)
    act = rand(T, dact)

    @time flat_hessian_vector_product(feat, act, policy, v)
end
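
A rough way to sanity-check this, once things are cleaned up, would be to compare against a dense Hessian over a flattened copy of the parameters. This is an untested sketch; it assumes Flux.destructure is available and flattens parameters in the same order as Flux.params(policy):

# Hypothetical check: dense Hessian over a flat parameter vector, then H*v,
# compared against flat_hessian_vector_product. Assumes Flux.destructure
# exists and orders parameters the same way as Flux.params(policy).
function check_flathvp(T::DataType=Float32)
    Random.seed!(1)
    dobs, dact = 4, 2
    policy = Flux.paramtype(T, DiagGaussianPolicy(Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact)), zeros(dact)))
    feat, act = rand(T, dobs), rand(T, dact)
    v = ArrayPartition((rand(T, size(p)...) for p in params(policy))...)

    hv = flat_hessian_vector_product(feat, act, policy, v)

    θ, re = Flux.destructure(policy)
    H = ForwardDiff.hessian(θ -> loglikelihood(re(θ), feat, act), θ)

    # ArrayPartition is an AbstractVector, so collect flattens block-by-block
    isapprox(collect(hv), H * collect(v); rtol=1e-3)
end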

ToucheSir transferred this issue from FluxML/Zygote.jl on Dec 22, 2021
@YichengDWu

Can you try using Lux and ComponentArrays?
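
Something along these lines might work (untested sketch; it assumes a recent Lux and ComponentArrays and reuses the Dual-seeding trick from above; the names are stand-ins):

using Lux, ComponentArrays, ForwardDiff, Zygote, Random

model = Chain(Dense(4 => 32), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), model)
p = ComponentArray(ps)                        # flat parameter vector with named blocks

x = randn(Float32, 4)
loss(p) = sum(abs2, first(model(x, p, st)))   # stand-in scalar loss

# seed p with Duals carrying v, take the reverse-mode gradient, read off the partials
function hvp(f, p::ComponentVector, v::ComponentVector)
    pd = ForwardDiff.Dual{Nothing}.(p, v)
    gd = first(Zygote.gradient(f, pd))
    ForwardDiff.partials.(gd, 1)
end

v = ComponentArray(randn(Float32, length(p)), getaxes(p)...)
hvp(loss, p, v)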
