diff --git a/HISTORY.md b/HISTORY.md index 77675f424e..1265fafabc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,19 @@ # Release 0.39.0 +## Update to the AdvancedVI interface + +Turing's variational inference interface was updated to match version 0.4 of AdvancedVI.jl. + +AdvancedVI v0.4 introduces various new features: + + - location-scale families with dense scale matrices, + - parameter-free stochastic optimization algorithms like `DoG` and `DoWG`, + - proximal operators for stable optimization, + - the sticking-the-landing control variate for faster convergence, and + - the score gradient estimator for non-differentiable targets. + +Please see the [Turing API documentation](https://turinglang.org/Turing.jl/stable/api/#Variational-inference) and [AdvancedVI's documentation](https://turinglang.org/AdvancedVI.jl/stable/) for more details. + ## Removal of Turing.Essential The Turing.Essential module has been removed. diff --git a/Project.toml b/Project.toml index a82e5f437b..4954782948 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7" AdvancedMH = "0.8" AdvancedPS = "0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.4" BangBang = "0.4.2" Bijectors = "0.14, 0.15" Compat = "4.15.0" diff --git a/docs/make.jl b/docs/make.jl index 978e5881b3..af24e7b1ec 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,8 +23,11 @@ makedocs(; pages=[ "Home" => "index.md", "API" => "api.md", - "Submodule APIs" => - ["Inference" => "api/Inference.md", "Optimisation" => "api/Optimisation.md"], + "Submodule APIs" => [ + "Inference" => "api/Inference.md", + "Optimisation" => "api/Optimisation.md", + "Variational" => "api/Variational.md", + ], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index 1633be75d2..0b8351eb3b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -77,12 +77,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu ### Variational inference -See the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a walkthrough on how to use these. - -| Exported symbol | Documentation | Description | -|:--------------- |:---------------------------- |:--------------------------------------- | -| `vi` | [`AdvancedVI.vi`](@extref) | Perform variational inference | -| `ADVI` | [`AdvancedVI.ADVI`](@extref) | Construct an instance of a VI algorithm | +See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.
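+
+For example, the sketch below fits a mean-field Gaussian approximation with the new interface (the model and data are illustrative):
+
+```julia
+using Turing
+
+@model function gdemo(x, y)
+    s ~ InverseGamma(2, 3)
+    m ~ Normal(0, sqrt(s))
+    x ~ Normal(m, sqrt(s))
+    y ~ Normal(m, sqrt(s))
+end
+
+model = gdemo(1.5, 2.0)
+q_init = q_meanfield_gaussian(model)
+_, q_avg, _, _ = vi(model, q_init, 1000)
+samples = rand(q_avg, 100)  # draws from the fitted approximation
+```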
+ +| Exported symbol | Documentation | Description | +|:---------------------- |:------------------------------------------------- |:---------------------------------------------------------------------------------------- | +| `vi` | [`Turing.vi`](@ref) | Perform variational inference | +| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family | +| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family | +| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family | ### Automatic differentiation types diff --git a/docs/src/api/Variational.md b/docs/src/api/Variational.md new file mode 100644 index 0000000000..382efe7e18 --- /dev/null +++ b/docs/src/api/Variational.md @@ -0,0 +1,6 @@ +# API: `Turing.Variational` + +```@autodocs +Modules = [Turing.Variational] +Order = [:type, :function] +``` diff --git a/src/Turing.jl b/src/Turing.jl index 8362e0c97d..1ff2310174 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -39,8 +39,6 @@ function setprogress!(progress::Bool) @info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally" PROGRESS[] = progress AbstractMCMC.setprogress!(progress; silent=true) - # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3 - AdvancedVI.turnprogress(progress) return progress end @@ -118,6 +116,9 @@ export # Variational inference - AdvancedVI vi, ADVI, + q_locationscale, + q_meanfield_gaussian, + q_fullrank_gaussian, # ADTypes AutoForwardDiff, AutoReverseDiff, diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 189d3f7001..b9428af112 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -1,50 +1,329 @@ + module Variational -using DistributionsAD: DistributionsAD -using DynamicPPL: DynamicPPL -using StatsBase: StatsBase -using StatsFuns: StatsFuns -using LogDensityProblems: LogDensityProblems +using DynamicPPL +using ADTypes using Distributions +using LinearAlgebra +using LogDensityProblems +using Random -using Random: Random +import ..Turing: DEFAULT_ADTYPE, PROGRESS import AdvancedVI import Bijectors -# Reexports -using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad -export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad +export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian + +include("deprecated.jl") + +function make_logdensity(model::DynamicPPL.Model) + weight = 1.0 + ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) + return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) +end """ - make_logjoint(model::Model; weight = 1.0) -Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). -The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to -use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`. -## Notes -- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. 
+ q_initialize_scale( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model, + location::AbstractVector, + scale::AbstractMatrix, + basedist::Distributions.UnivariateDistribution; + num_samples::Int = 10, + num_max_trials::Int = 10, + reduce_factor::Real = one(eltype(scale)) / 2 + ) + +Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of the log-densities of `model` taken over `q` is finite. +If the log-densities are still not finite after `num_max_trials` trials, an error is thrown. + +For reference, a location-scale distribution \$q\$ formed by `location`, `scale`, and `basedist` is a distribution whose sampling process \$z \\sim q\$ can be represented as +```julia +u = rand(basedist, d) +z = scale * u + location +``` + +# Arguments +- `model`: The target `DynamicPPL.Model`. +- `location`: The location parameter of the initialization. +- `scale`: The scale parameter of the initialization. +- `basedist`: The base distribution of the location-scale family. + +# Keyword Arguments +- `num_samples`: Number of samples used to compute the average log-density at each trial. +- `num_max_trials`: Maximum number of trials before an error is thrown. +- `reduce_factor`: Factor for shrinking the scale. After `n` trials, the scale becomes `scale * reduce_factor^n`. + +# Returns +- `scale_adj`: The adjusted scale matrix matching the type of `scale`. """ -function make_logjoint(model::DynamicPPL.Model; weight=1.0) - # setup - ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) - return Base.Fix1(LogDensityProblems.logdensity, f) +function q_initialize_scale( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + location::AbstractVector, + scale::AbstractMatrix, + basedist::Distributions.UnivariateDistribution; + num_samples::Int=10, + num_max_trials::Int=10, + reduce_factor::Real=one(eltype(scale)) / 2, +) + prob = make_logdensity(model) + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + varinfo = DynamicPPL.VarInfo(model) + + n_trial = 0 + while true + q = AdvancedVI.MvLocationScale(location, scale, basedist) + b = Bijectors.bijector(model; varinfo=varinfo) + q_trans = Bijectors.transformed(q, Bijectors.inverse(b)) + energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples))) + + if isfinite(energy) + return scale + elseif n_trial == num_max_trials + error("Could not find an initial scale matrix yielding a finite average log-density after $(num_max_trials) trials.") + end + + scale = reduce_factor * scale + n_trial += 1 + end end -# objectives -function (elbo::ELBO)( +""" + q_locationscale( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector} = nothing, + scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing, + meanfield::Bool = true, + basedist::Distributions.UnivariateDistribution = Normal() + ) + +Find a numerically non-degenerate variational distribution `q` for approximating the target `model` within the location-scale variational family formed by the type of `scale` and `basedist`. + +The distribution can be manually specified by setting `location`, `scale`, and `basedist`. +Otherwise, a standard Gaussian is used by default. +If `scale` is not provided, it is initialized via `q_initialize_scale` so that the log-densities of `model` are finite over the samples from `q`. +If `meanfield` is set to `true`, the scale of `q` is restricted to be a diagonal matrix and only the diagonal of `scale` is used.
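+
+For example, under the default arguments, the following two calls are equivalent (`model` stands for any `DynamicPPL.Model`):
+
+```julia
+q = q_locationscale(model; meanfield=false)
+q = q_fullrank_gaussian(model)
+```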
+ +For reference, a location-scale distribution \$q\$ formed by `location`, `scale`, and `basedist` is a distribution whose sampling process \$z \\sim q\$ can be represented as +```julia +u = rand(basedist, d) +z = scale * u + location +``` + +# Arguments +- `model`: The target `DynamicPPL.Model`. + +# Keyword Arguments +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. +- `meanfield`: Whether to use the mean-field approximation. If `true`, `scale` is converted into a `Diagonal` matrix. Otherwise, it is converted into a `LowerTriangular` matrix. +- `basedist`: The base distribution of the location-scale family. + +The remaining keywords are passed to `q_initialize_scale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: An `AdvancedVI.MvLocationScale` distribution transformed to match the support of `model`. +""" +function q_locationscale( + rng::Random.AbstractRNG, + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal,<:LowerTriangular}=nothing, + meanfield::Bool=true, + basedist::Distributions.UnivariateDistribution=Normal(), + kwargs..., +) + varinfo = DynamicPPL.VarInfo(model) + # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. + varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) + + μ = if isnothing(location) + zeros(num_params) + else + @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match the dimension of the target distribution, $(num_params)." + location + end + + L = if isnothing(scale) + if meanfield + q_initialize_scale(rng, model, μ, Diagonal(ones(num_params)), basedist; kwargs...) + else + L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) + q_initialize_scale(rng, model, μ, L0, basedist; kwargs...) + end + else + @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), do not match the dimension of the target distribution, $(num_params)." + if meanfield + Diagonal(diag(scale)) + else + LowerTriangular(Matrix(scale)) + end + end + q = AdvancedVI.MvLocationScale(μ, L, basedist) + b = Bijectors.bijector(model; varinfo=varinfo) + return Bijectors.transformed(q, Bijectors.inverse(b)) +end + +function q_locationscale(model::DynamicPPL.Model; kwargs...) + return q_locationscale(Random.default_rng(), model; kwargs...) +end + +""" + q_meanfield_gaussian( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector} = nothing, + scale::Union{Nothing,<:Diagonal} = nothing, + kwargs... + ) + +Find a numerically non-degenerate mean-field Gaussian `q` for approximating the target `model`. + +# Arguments +- `model`: The target `DynamicPPL.Model`. + +# Keyword Arguments +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. + +The remaining keyword arguments are passed to `q_locationscale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: An `AdvancedVI.MvLocationScale` distribution transformed to match the support of `model`.
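+
+# Example
+
+A minimal sketch, where `model` stands for any `DynamicPPL.Model`:
+
+```julia
+q = q_meanfield_gaussian(model)
+z = rand(q)  # a draw in the support of `model`
+```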
+""" +function q_meanfield_gaussian( + rng::Random.AbstractRNG, + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing, + kwargs..., +) + return q_locationscale( + rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs... + ) +end + +function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...) + return q_meanfield_gaussian(Random.default_rng(), model; kwargs...) +end + +""" + q_fullrank_gaussian( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector} = nothing, + scale::Union{Nothing,<:LowerTriangular} = nothing, + kwargs... + ) + +Find a numerically non-degenerate Gaussian `q` with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target `model`. + +# Arguments +- `model`: The target `DynamicPPL.Model`. + +# Keyword Arguments +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. + +The remaining keyword arguments are passed to `q_locationscale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +""" +function q_fullrank_gaussian( + rng::Random.AbstractRNG, + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:LowerTriangular}=nothing, + kwargs..., +) + return q_locationscale( + rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs... + ) +end + +function q_fullrank_gaussian(model::DynamicPPL.Model; kwargs...) + return q_fullrank_gaussian(Random.default_rng(), model; kwargs...) +end + +""" + vi( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + q, + n_iterations::Int; + objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO( + 10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient() + ), + show_progress::Bool = Turing.PROGRESS[], + optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(), + averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(), + operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(), + adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE, + kwargs... + ) + +Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`. +This is a thin wrapper around `AdvancedVI.optimize`. + +# Arguments +- `model`: The target `DynamicPPL.Model`. +- `q`: The initial variational approximation. +- `n_iterations`: Number of optimization steps. + +# Keyword Arguments +- `objective`: Variational objective to be optimized. +- `show_progress`: Whether to show the progress bar. +- `optimizer`: Optimization algorithm. +- `averager`: Parameter averaging strategy. +- `operator`: Operator applied after each optimization step. +- `adtype`: Automatic differentiation backend. + +See the docs of `AdvancedVI.optimize` for additional keyword arguments. + +# Returns +- `q`: Variational distribution formed by the last iterate of the optimization run. +- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`. +- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`. +- `info`: Information generated during the optimization run. 
+""" +function vi( rng::Random.AbstractRNG, - alg::AdvancedVI.VariationalInference, - q, model::DynamicPPL.Model, - num_samples; - weight=1.0, + q, + n_iterations::Int; + objective=AdvancedVI.RepGradELBO( + 10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient() + ), + show_progress::Bool=PROGRESS[], + optimizer=AdvancedVI.DoWG(), + averager=AdvancedVI.PolynomialAveraging(), + operator=AdvancedVI.ProximalLocationScaleEntropy(), + adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, kwargs..., ) - return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...) + return AdvancedVI.optimize( + rng, + make_logdensity(model), + objective, + q, + n_iterations; + show_progress=show_progress, + adtype, + optimizer, + averager, + operator, + kwargs..., + ) end -# VI algorithms -include("advi.jl") +function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...) + return vi(Random.default_rng(), model, q, n_iterations; kwargs...) +end end diff --git a/src/variational/advi.jl b/src/variational/advi.jl deleted file mode 100644 index 3819109a09..0000000000 --- a/src/variational/advi.jl +++ /dev/null @@ -1,70 +0,0 @@ -""" - meanfield([rng, ]model::Model) - -Creates a mean-field approximation with multivariate normal as underlying distribution. -""" -meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model) -function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) - # Setup. - varinfo = DynamicPPL.VarInfo(model) - # Use linked `varinfo` to determine the correct number of parameters. - # TODO: Replace with `length` once this is implemented for `VarInfo`. - varinfo_linked = DynamicPPL.link(varinfo, model) - num_params = length(varinfo_linked[:]) - - # initial params - μ = randn(rng, num_params) - σ = StatsFuns.softplus.(randn(rng, num_params)) - - # Construct the base family. - d = DistributionsAD.TuringDiagMvNormal(μ, σ) - - # Construct the bijector constrained → unconstrained. - b = Bijectors.bijector(model; varinfo=varinfo) - - # We want to transform from unconstrained space to constrained, - # hence we need the inverse of `b`. - return Bijectors.transformed(d, Bijectors.inverse(b)) -end - -# Overloading stuff from `AdvancedVI` to specialize for Turing -function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ) - return DistributionsAD.TuringDiagMvNormal(μ, σ) -end -function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...) - return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform) -end -function AdvancedVI.update( - td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}, - θ::AbstractArray, -) - # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, - # so we need to use the length of the underlying distribution `td.dist` here. - # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends. 
- μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] - return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad() -) - q = meanfield(model) - return AdvancedVI.vi(model, alg, q; optimizer=optimizer) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, - alg::AdvancedVI.ADVI, - q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}; - optimizer=AdvancedVI.TruncatedADAGrad(), -) - # Initial parameters for mean-field approx - μ, σs = StatsBase.params(q) - θ = vcat(μ, StatsFuns.invsoftplus.(σs)) - - # Optimize - AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer) - - # Return updated `Distribution` - return AdvancedVI.update(q, θ) -end diff --git a/src/variational/deprecated.jl b/src/variational/deprecated.jl new file mode 100644 index 0000000000..9a9f4777b5 --- /dev/null +++ b/src/variational/deprecated.jl @@ -0,0 +1,61 @@ + +import DistributionsAD +export ADVI + +Base.@deprecate meanfield(model) q_meanfield_gaussian(model) + +struct ADVI{AD} + "Number of samples used to estimate the ELBO in each optimization step." + samples_per_step::Int + "Maximum number of gradient steps." + max_iters::Int + "AD backend used for automatic differentiation." + adtype::AD +end + +function ADVI( + samples_per_step::Int=1, + max_iters::Int=1000; + adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff(), +) + Base.depwarn( + "The type ADVI will be removed in future releases. Please refer to the new interface for `vi`.", + :ADVI; + force=true, + ) + return ADVI{typeof(adtype)}(samples_per_step, max_iters, adtype) +end + +function vi(model::DynamicPPL.Model, alg::ADVI; kwargs...) + Base.depwarn( + "This specialization along with the type `ADVI` will be removed in future releases. Please refer to the new interface for `vi`.", + :vi; + force=true, + ) + q = q_meanfield_gaussian(Random.default_rng(), model) + objective = AdvancedVI.RepGradELBO( + alg.samples_per_step; entropy=AdvancedVI.ClosedFormEntropy() + ) + operator = AdvancedVI.IdentityOperator() + _, q_avg, _, _ = vi(model, q, alg.max_iters; objective, operator, kwargs...) + return q_avg +end + +function vi( + model::DynamicPPL.Model, + alg::ADVI, + q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}; + kwargs..., +) + Base.depwarn( + "This specialization along with the type `ADVI` will be removed in future releases. Please refer to the new interface for `vi`.", + :vi; + force=true, + ) + objective = AdvancedVI.RepGradELBO( + alg.samples_per_step; entropy=AdvancedVI.ClosedFormEntropy() + ) + operator = AdvancedVI.IdentityOperator() + _, q_avg, _, _ = vi(model, q, alg.max_iters; objective, operator, kwargs...)
+ return q_avg +end diff --git a/test/Project.toml b/test/Project.toml index 11694dd064..720fd875c8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -44,7 +44,7 @@ AbstractMCMC = "5" AbstractPPL = "0.9, 0.10, 0.11" AdvancedMH = "0.6, 0.7, 0.8" AdvancedPS = "=0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.4" Aqua = "0.8" BangBang = "0.4" Bijectors = "0.14, 0.15" diff --git a/test/runtests.jl b/test/runtests.jl index 69f8045dd8..cbabd62c41 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,10 +70,6 @@ end end end - @testset "variational optimisers" begin - @timeit_include("variational/optimisers.jl") - end - @testset "stdlib" verbose = true begin @timeit_include("stdlib/distributions.jl") @timeit_include("stdlib/RandomMeasures.jl") diff --git a/test/variational/advi.jl b/test/variational/advi.jl index 642b8b80d2..ed8f745df2 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -1,71 +1,133 @@ + module AdvancedVITests using ..Models: gdemo_default using ..NumericalTests: check_gdemo -import AdvancedVI -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad + +using AdvancedVI using Bijectors: Bijectors using Distributions: Dirichlet, Normal -using LinearAlgebra: I +using LinearAlgebra using MCMCChains: Chains -import Random +using Random +using StableRNGs: StableRNG using Test: @test, @testset using Turing -using DistributionsAD: TuringDiagMvNormal +using Turing.Variational + +@testset "ADVI" begin + @testset "q initialization" begin + m = gdemo_default + d = length(Turing.DynamicPPL.VarInfo(m)[:]) + for q in [q_meanfield_gaussian(m), q_fullrank_gaussian(m)] + rand(q) + end + + μ = ones(d) + q = q_meanfield_gaussian(m; location=μ) + @test mean(q.dist) ≈ μ -@testset "advi.jl" begin - @testset "advi constructor" begin - Random.seed!(0) - N = 500 + q = q_fullrank_gaussian(m; location=μ) + @test mean(q.dist) ≈ μ - s1 = ADVI() - q = vi(gdemo_default, s1) - c1 = rand(q, N) + L = Diagonal(fill(0.1, d)) + q = q_meanfield_gaussian(m; scale=L) + @test cov(q.dist) ≈ L * L' + + L = LowerTriangular(tril(0.01 * ones(d, d) + I)) + q = q_fullrank_gaussian(m; scale=L) + @test cov(q.dist) ≈ L * L' end - @testset "advi inference" begin - @testset for opt in [TruncatedADAGrad(), DecayedADAGrad()] - Random.seed!(1) - N = 500 - - alg = ADVI(10, 5000) - q = vi(gdemo_default, alg; optimizer=opt) - samples = transpose(rand(q, N)) - chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) - # TODO: uhmm, seems like a large `eps` here...
- check_gdemo(chn; atol=0.5) + @testset "default interface" begin + for q0 in [q_meanfield_gaussian(gdemo_default), q_fullrank_gaussian(gdemo_default)] + _, q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[]) + c1 = rand(q, 10) + end end - @testset "advi different interfaces" begin - Random.seed!(1234) - - target = MvNormal(zeros(2), I) - logπ(z) = logpdf(target, z) - advi = ADVI(10, 1000) - - # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) - q = vi(logπ, advi, getq, randn(4)) + @testset "custom interface $name" for (name, objective, operator, optimizer) in [ + ( + "ADVI with closed-form entropy", + AdvancedVI.RepGradELBO(10), + AdvancedVI.ClipScale(), + AdvancedVI.DoG(), + ), + ( + "ADVI with proximal entropy", + AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + AdvancedVI.ProximalLocationScaleEntropy(), + AdvancedVI.DoG(), + ), + ( + "ADVI with STL entropy", + AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()), + AdvancedVI.ClipScale(), + AdvancedVI.DoG(), + ), + ] + T = 1000 + q, q_avg, _, _ = vi( + gdemo_default, + q_meanfield_gaussian(gdemo_default), + T; + objective, + optimizer, + operator, + show_progress=Turing.PROGRESS[], + ) + + N = 1000 + c1 = rand(q_avg, N) + c2 = rand(q, N) + end - xs = rand(target, 10) - @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.07 + @testset "inference $name" for (name, objective, operator, optimizer) in [ + ( + "ADVI with closed-form entropy", + AdvancedVI.RepGradELBO(10), + AdvancedVI.ClipScale(), + AdvancedVI.DoG(), + ), + ( + "ADVI with proximal entropy", + AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + AdvancedVI.ProximalLocationScaleEntropy(), + AdvancedVI.DoG(), + ), + ( + "ADVI with STL entropy", + AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()), + AdvancedVI.ClipScale(), + AdvancedVI.DoG(), + ), + ] + rng = StableRNG(0x517e1d9bf89bf94f) + + T = 1000 + q, q_avg, _, _ = vi( + rng, + gdemo_default, + q_meanfield_gaussian(gdemo_default), + T; + objective, + optimizer, + operator, + show_progress=Turing.PROGRESS[], + ) + + N = 1000 + for q_out in [q_avg, q] + samples = transpose(rand(rng, q_out, N)) + chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) - # OR: implement `update` and pass a `Distribution` - function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[(length(q) + 1):end])) + check_gdemo(chn; atol=0.5) end - - q0 = TuringDiagMvNormal(zeros(2), ones(2)) - q = vi(logπ, advi, q0, randn(4)) - - xs = rand(target, 10) - @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 end # regression test for: # https://github.com/TuringLang/Turing.jl/issues/2065 @testset "simplex bijector" begin + rng = StableRNG(0x517e1d9bf89bf94f) + @model function dirichlet() x ~ Dirichlet([1.0, 1.0]) return x end @@ -81,25 +143,27 @@ using DistributionsAD: TuringDiagMvNormal @test all(x0 .≈ x0_inv) # And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
- q = vi(m, ADVI(10, 1000)) - x = rand(q, 1000) + _, q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000) + x = rand(rng, q, 1000) @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1 end # Ref: https://github.com/TuringLang/Turing.jl/issues/2205 @testset "with `condition` (issue #2205)" begin + rng = StableRNG(0x517e1d9bf89bf94f) + @model function demo_issue2205() x ~ Normal() return y ~ Normal(x, 1) end model = demo_issue2205() | (y=1.0,) - q = vi(model, ADVI(10, 1000)) + _, q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000) # True mean. mean_true = 1 / 2 var_true = 1 / 2 # Check the mean and variance of the posterior. - samples = rand(q, 1000) + samples = rand(rng, q, 1000) mean_est = mean(samples) var_est = var(samples) @test mean_est ≈ mean_true atol = 0.2 diff --git a/test/variational/optimisers.jl b/test/variational/optimisers.jl deleted file mode 100644 index 6f64d5fb1f..0000000000 --- a/test/variational/optimisers.jl +++ /dev/null @@ -1,29 +0,0 @@ -module VariationalOptimisersTests - -using AdvancedVI: DecayedADAGrad, TruncatedADAGrad, apply! -import ForwardDiff -import ReverseDiff -using Test: @test, @testset -using Turing - -function test_opt(ADPack, opt) - θ = randn(10, 10) - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1)) - for t in 1:(10^4) - x = rand(10) - Δ = ADPack.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end -for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - test_opt(ForwardDiff, opt) -end -for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - test_opt(ReverseDiff, opt) -end - -end
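For users migrating from the deprecated `ADVI` interface, the deprecation path in `src/variational/deprecated.jl` above corresponds to the following mapping (a sketch; `model` stands for any `DynamicPPL.Model`):

```julia
# Old (deprecated): returns the averaged variational approximation directly
q = vi(model, ADVI(10, 1000))

# New: equivalent call through the updated interface
q0 = q_meanfield_gaussian(model)
_, q_avg, _, _ = vi(
    model, q0, 1000;
    objective=AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropy()),
    operator=AdvancedVI.IdentityOperator(),
)
```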