Hessian vector products with moderately complex models #1813
So with some more fiddling, and realizing that ForwardDiff only works with certain input types, I got something working. While I haven't checked the result for correctness, the solution itself isn't too ugly :). I was able to make this work without the custom seeding by explicitly calculating the vector product instead, but that solution is messier/slower so I won't bother posting it. I'll circle back once I've cleaned things up a bit and verified the result is correct, but until then, if anyone has suggestions do let me know! This is a fairly critical capability for anyone doing research in ML, RL, etc. using Flux.

```julia
using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward
# A Gaussian policy with diagonal covariance
struct DiagGaussianPolicy{M,L<:AbstractVector}
meanNN::M
logstd::L
end
Flux.@functor DiagGaussianPolicy
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)
# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
meanact = P(feature)
# broken (possibly related to https://github.com/FluxML/Zygote.jl/issues/405)
#zs = ((meanact .- action) ./ exp.(P.logstd)) .^ 2
# works
zs = (meanact .- action) ./ exp.(P.logstd)
zs = zs .* zs
ll = -sum(zs)/2 - sum(P.logstd) - length(P.logstd) * log(2pi) / 2
ll
end
flatgrad(gs::Grads, ps::Params) = ArrayPartition((gs[p] for p in ps if !isnothing(gs[p]))...)
function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
ps = Flux.params(policy)
i = 1
# seed each trainable array with the matching block of v as Dual partials
dualpol = Flux.fmap(policy) do p
if p in ps.params
p = ForwardDiff.Dual{Nothing}.(p, vs.x[i])
i += 1
end
p
end
dualps = params(dualpol)
G = let feat=feat, act=act
function (ps)
gs = gradient(() -> loglikelihood(dualpol, feat, act), ps)
flatgrad(gs, ps)
end
end
# the dual partials of the flat gradient are the entries of H*v
ForwardDiff.partials.(G(dualps), 1)
end
function test_flathvp(T::DataType=Float32)
Random.seed!(1)
dobs, dact = 4, 2
policy = DiagGaussianPolicy(Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact)), zeros(dact))
policy = Flux.paramtype(T, policy)
v = ArrayPartition((rand(size(p)...) for p in params(policy))...)
feat = rand(T, 4)
act = rand(T, 2)
@time flat_hessian_vector_product(feat, act, policy, v)
end
```
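For reference, here is the same forward-over-reverse trick in miniature, on a plain function of one flat vector (the `hvp` helper and test function are my own, not from the thread): seed the input with `Dual` numbers whose partials are `v`, take a reverse-mode gradient through the seeded input, and read the partials back off — that yields `H*v` in a single combined pass.

```julia
using ForwardDiff, Zygote

# f has Hessian diagm(6 .* x), so H*v is easy to check by hand.
f(x) = sum(x .^ 3)

# Forward-over-reverse: seed x with partials v, differentiate in reverse mode,
# then extract the forward-mode partials of the resulting gradient.
function hvp(f, x, v)
    xd = ForwardDiff.Dual{Nothing}.(x, v)
    g = Zygote.gradient(f, xd)[1]
    ForwardDiff.partials.(g, 1)
end

x = [1.0, 2.0]
v = [1.0, 0.0]
hvp(f, x, v)  # should be ≈ [6.0, 0.0], the first column of diagm(6 .* x)
```

The `fmap`-based version in the issue is doing exactly this, except the seeding has to walk the model structure because the parameters are not stored as one flat vector.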
Can you try using Lux and ComponentArrays?
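An untested sketch of that suggestion (API details from memory and not verified against current Lux/ComponentArrays releases; the toy loss and `hvp` helper are mine): with all parameters held in a single flat `ComponentArray`, the same Dual-seeding trick applies directly, with no `fmap` traversal or manual bookkeeping of which block of `v` goes with which array.

```julia
using Lux, ComponentArrays, ForwardDiff, Zygote, Random

rng = Random.default_rng()
model = Chain(Dense(4 => 32, tanh), Dense(32 => 2))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)           # one flat vector view over all parameters

x = rand(rng, Float32, 4)
loss(p) = sum(abs2, first(model(x, p, st)))

# Seed the flat parameter vector, reverse-differentiate, read off partials.
function hvp(loss, p, v)
    pd = ForwardDiff.Dual{Nothing}.(p, v)
    g = Zygote.gradient(loss, pd)[1]
    ForwardDiff.partials.(g, 1)
end

v = randn(Float32, length(ps))
hvp(loss, ps, v)
```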
I'm attempting to compute Hessian vector products for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.
Following FluxML/Zygote.jl#115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and elsewhere, I was able to compute HVPs for simple models parameterized by a single `Array`, but the following appears to have issues inferring the type of `Dual`. Any help would be greatly appreciated! :)
Calling `test()` yields: