diff --git a/Project.toml b/Project.toml index 9702cfde1..671a2914c 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7" AdvancedMH = "0.8" AdvancedPS = "0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.3.1" BangBang = "0.4.2" Bijectors = "0.14, 0.15" Compat = "4.15.0" diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 189d3f700..e44923721 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -1,50 +1,170 @@ + module Variational -using DistributionsAD: DistributionsAD -using DynamicPPL: DynamicPPL -using StatsBase: StatsBase -using StatsFuns: StatsFuns -using LogDensityProblems: LogDensityProblems +using DynamicPPL +using ADTypes using Distributions +using LinearAlgebra +using LogDensityProblems +using Random -using Random: Random +import ..Turing: DEFAULT_ADTYPE, PROGRESS import AdvancedVI import Bijectors # Reexports -using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad -export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad - -""" - make_logjoint(model::Model; weight = 1.0) -Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). -The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to -use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`. -## Notes -- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. -""" -function make_logjoint(model::DynamicPPL.Model; weight=1.0) - # setup +using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG +export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG + +export meanfield_gaussian, fullrank_gaussian + +include("bijectors.jl") + +function make_logdensity(model::DynamicPPL.Model) + weight = 1.0 ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) - return Base.Fix1(LogDensityProblems.logdensity, f) + return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) end -# objectives -function (elbo::ELBO)( +function initialize_gaussian_scale( rng::Random.AbstractRNG, - alg::AdvancedVI.VariationalInference, - q, model::DynamicPPL.Model, - num_samples; - weight=1.0, + location::AbstractVector, + scale::AbstractMatrix; + num_samples::Int=10, + num_max_trials::Int=10, + reduce_factor=one(eltype(scale)) / 2, +) + prob = make_logdensity(model) + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + varinfo = DynamicPPL.VarInfo(model) + + n_trial = 0 + while true + q = AdvancedVI.MvLocationScale(location, scale, Normal()) + b = Bijectors.bijector(model; varinfo=varinfo) + q_trans = Bijectors.transformed(q, Bijectors.inverse(b)) + energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples))) + + if isfinite(energy) + return scale + elseif n_trial == num_max_trials + error("Could not find an initial") + end + + scale = reduce_factor * scale + n_trial += 1 + end +end + +function meanfield_gaussian( + rng::Random.AbstractRNG, + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing, + kwargs..., +) + varinfo = DynamicPPL.VarInfo(model) + # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. + varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) + + μ = if isnothing(location) + zeros(num_params) + else + @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." + location + end + + L = if isnothing(scale) + initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...) + else + @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." + L = scale + end + + q = AdvancedVI.MeanFieldGaussian(μ, L) + b = Bijectors.bijector(model; varinfo=varinfo) + return Bijectors.transformed(q, Bijectors.inverse(b)) +end + +function meanfield_gaussian( + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) - return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...) + return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...) +end + +function fullrank_gaussian( + rng::Random.AbstractRNG, + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:LowerTriangular}=nothing, + kwargs..., +) + varinfo = DynamicPPL.VarInfo(model) + # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. + varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) + + μ = if isnothing(location) + zeros(num_params) + else + @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." + location + end + + L = if isnothing(scale) + L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) + initialize_gaussian_scale(rng, model, μ, L0; kwargs...) + else + @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." + scale + end + + q = AdvancedVI.FullRankGaussian(μ, L) + b = Bijectors.bijector(model; varinfo=varinfo) + return Bijectors.transformed(q, Bijectors.inverse(b)) end -# VI algorithms -include("advi.jl") +function fullrank_gaussian( + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:LowerTriangular}=nothing, + kwargs..., +) + return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...) +end + +function vi( + model::DynamicPPL.Model, + q::Bijectors.TransformedDistribution, + n_iterations::Int; + objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + show_progress::Bool=PROGRESS[], + optimizer=AdvancedVI.DoWG(), + averager=AdvancedVI.PolynomialAveraging(), + operator=AdvancedVI.ProximalLocationScaleEntropy(), + adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, + kwargs..., +) + return AdvancedVI.optimize( + make_logdensity(model), + objective, + q, + n_iterations; + show_progress=show_progress, + adtype, + optimizer, + averager, + operator, + kwargs..., + ) +end end diff --git a/src/variational/advi.jl b/src/variational/advi.jl deleted file mode 100644 index ec3e6552e..000000000 --- a/src/variational/advi.jl +++ /dev/null @@ -1,140 +0,0 @@ -# TODO: Move to Bijectors.jl if we find further use for this. -""" - wrap_in_vec_reshape(f, in_size) - -Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces -a vector of length `prod(Bijectors.output(f, in_size))`. -""" -function wrap_in_vec_reshape(f, in_size) - vec_in_length = prod(in_size) - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) - out_size = Bijectors.output_size(f, in_size) - vec_out_length = prod(out_size) - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) - return reshape_outer ∘ f ∘ reshape_inner -end - -""" - bijector(model::Model[, sym2ranges = Val(false)]) - -Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` -denoting the dimensionality of the latent variables. -""" -function Bijectors.bijector( - model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) -) where {sym2ranges} - num_params = sum([ - size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) - ]) - - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) - - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] - end - - bs = map(tuple(dists...)) do d - b = Bijectors.bijector(d) - if d isa Distributions.UnivariateDistribution - b - else - wrap_in_vec_reshape(b, size(d)) - end - end - - if sym2ranges - return ( - Bijectors.Stacked(bs, ranges), - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), - ) - else - return Bijectors.Stacked(bs, ranges) - end -end - -""" - meanfield([rng, ]model::Model) - -Creates a mean-field approximation with multivariate normal as underlying distribution. -""" -meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model) -function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) - # Setup. - varinfo = DynamicPPL.VarInfo(model) - # Use linked `varinfo` to determine the correct number of parameters. - # TODO: Replace with `length` once this is implemented for `VarInfo`. - varinfo_linked = DynamicPPL.link(varinfo, model) - num_params = length(varinfo_linked[:]) - - # initial params - μ = randn(rng, num_params) - σ = StatsFuns.softplus.(randn(rng, num_params)) - - # Construct the base family. - d = DistributionsAD.TuringDiagMvNormal(μ, σ) - - # Construct the bijector constrained → unconstrained. - b = Bijectors.bijector(model; varinfo=varinfo) - - # We want to transform from unconstrained space to constrained, - # hence we need the inverse of `b`. - return Bijectors.transformed(d, Bijectors.inverse(b)) -end - -# Overloading stuff from `AdvancedVI` to specialize for Turing -function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ) - return DistributionsAD.TuringDiagMvNormal(μ, σ) -end -function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...) - return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform) -end -function AdvancedVI.update( - td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}, - θ::AbstractArray, -) - # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, - # so we need to use the length of the underlying distribution `td.dist` here. - # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends. - μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] - return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad() -) - q = meanfield(model) - return AdvancedVI.vi(model, alg, q; optimizer=optimizer) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, - alg::AdvancedVI.ADVI, - q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}; - optimizer=AdvancedVI.TruncatedADAGrad(), -) - # Initial parameters for mean-field approx - μ, σs = StatsBase.params(q) - θ = vcat(μ, StatsFuns.invsoftplus.(σs)) - - # Optimize - AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer) - - # Return updated `Distribution` - return AdvancedVI.update(q, θ) -end diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl new file mode 100644 index 000000000..86078efaa --- /dev/null +++ b/src/variational/bijectors.jl @@ -0,0 +1,72 @@ + +# TODO: Move to Bijectors.jl if we find further use for this. +""" + wrap_in_vec_reshape(f, in_size) + +Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces +a vector of length `prod(Bijectors.output(f, in_size))`. +""" +function wrap_in_vec_reshape(f, in_size) + vec_in_length = prod(in_size) + reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) + out_size = Bijectors.output_size(f, in_size) + vec_out_length = prod(out_size) + reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) + return reshape_outer ∘ f ∘ reshape_inner +end + +""" + bijector(model::Model[, sym2ranges = Val(false)]) + +Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` +denoting the dimensionality of the latent variables. +""" +function Bijectors.bijector( + model::DynamicPPL.Model, + (::Val{sym2ranges})=Val(false); + varinfo=DynamicPPL.VarInfo(model), +) where {sym2ranges} + num_params = sum([ + size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) + ]) + + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) + + num_ranges = sum([ + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) + ]) + ranges = Vector{UnitRange{Int}}(undef, num_ranges) + idx = 0 + range_idx = 1 + + # ranges might be discontinuous => values are vectors of ranges rather than just ranges + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() + for sym in keys(varinfo.metadata) + sym_lookup[sym] = Vector{UnitRange{Int}}() + for r in varinfo.metadata[sym].ranges + ranges[range_idx] = idx .+ r + push!(sym_lookup[sym], ranges[range_idx]) + range_idx += 1 + end + + idx += varinfo.metadata[sym].ranges[end][end] + end + + bs = map(tuple(dists...)) do d + b = Bijectors.bijector(d) + if d isa Distributions.UnivariateDistribution + b + else + wrap_in_vec_reshape(b, size(d)) + end + end + + if sym2ranges + return ( + Bijectors.Stacked(bs, ranges), + (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), + ) + else + return Bijectors.Stacked(bs, ranges) + end +end diff --git a/test/Project.toml b/test/Project.toml index 36b7ebdec..f1d829ae1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ AbstractMCMC = "5" AbstractPPL = "0.9, 0.10" AdvancedMH = "0.6, 0.7, 0.8" AdvancedPS = "=0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.3" Aqua = "0.8" BangBang = "0.4" Bijectors = "0.14, 0.15"