Skip to content

Commit bb20f52

Browse files
author
Closed-Limelike-Curves
committed
Various minor+doc improvements
1 parent f2ec718 commit bb20f52

15 files changed

+107
-93
lines changed

src/AbstractCV.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ const CV_DESC = """
1414
estimated using leave-one-out cross validation.
1515
- `:naive_est` contains estimates of the in-sample prediction error.
1616
- `:p_eff` is the effective number of parameters -- a model with a `p_eff` of 2 is
17-
"about as overfit" as a model with 2 parameters and no regularization. It equals the
18-
difference between the previous two estimators, and measures how much your model
19-
tends to overfit the data.
17+
"about as overfit" as a model with 2 parameters and no regularization.
2018
- `pointwise::KeyedArray`: A `KeyedArray` of pointwise estimates with 5 columns --
2119
- `:cv_est` contains the estimated out-of-sample error for this point, as measured
2220
using leave-one-out cross validation.

src/GPD.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function gpdfit(
7777

7878
end
7979

80+
8081
"""
8182
gpd_quantile(p::T, k::T, sigma::T) where {T<:Real} -> T
8283
@@ -95,4 +96,3 @@ A quantile of the Generalized Pareto Distribution.
9596
function gpd_quantile(p, ξ::T, sigma::T) where {T <: Real}
9697
return sigma * expm1(-ξ * log1p(-p)) / ξ
9798
end
98-

src/ImportanceSampling.jl

+3-5
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,9 @@ Implements Pareto-smoothed importance sampling (PSIS).
8585
8686
# Arguments
8787
## Positional Arguments
88-
- `log_ratios::AbstractArray`: A 2d or 3d array of importance ratios on the log scale (for
89-
PSIS-LOO these are *negative* log-likelihood values). Indices must be ordered as
90-
`[data, step, chain]`: `log_ratios[1, 2, 3]` should be the log-likelihood of the first
91-
data point, evaluated at the second step in the third chain. Chain indices can be
92-
left off if there is only one chain, or if keyword argument `chain_index` is provided.
88+
- `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the
89+
log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left
90+
off if there is only one chain, or if keyword argument `chain_index` is provided.
9391
- $R_EFF_DOC
9492
9593
## Keyword Arguments

src/InternalHelpers.jl

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ const CHAIN_INDEX_DOC = """
44
`log_likelihood[:, step]` belongs to the second chain.
55
"""
66

7+
const DATA_ARG = """
8+
`data`: An array of data points used to estimate the parameters of the model.
9+
"""
10+
711
const LIKELIHOOD_FUNCTION_ARG = """
812
`ll_fun::Function`: A function taking a single data point and returning the log-likelihood
913
of that point. This function must take the form `f(θ[1], ..., θ[n], data)`, where `θ` is the

src/LeaveOneOut.jl

+19-17
Original file line numberDiff line numberDiff line change
@@ -109,18 +109,15 @@ score.
109109
- `log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
110110
`[data, step, chain]`. The chain argument can be left off if `chain_index` is provided
111111
or if all posterior samples were drawn from a single chain.
112-
- `args...`: Positional arguments to be passed to [`psis`](@ref).
113-
- `chain_index::Vector`: An optional vector of integers specifying which chain each
114-
step belongs to. For instance, `chain_index[3]` should return `2` if
115-
`log_likelihood[:, 3]` belongs to the second chain.
116-
- `kwargs...`: Keyword arguments to be passed to [`psis`](@ref).
112+
- $ARGS [`psis`](@ref).
113+
- $CHAIN_INDEX_DOC
114+
- $KWARGS [`psis`](@ref).
117115
118116
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
119117
"""
120118
function psis_loo(
121-
log_likelihood::T, args...; kwargs...
122-
) where {F <: Real, T <: AbstractArray{F, 3}}
123-
119+
log_likelihood::AbstractArray{<:Real, 3}, args...; kwargs...
120+
)
124121

125122
dims = size(log_likelihood)
126123
data_size = dims[1]
@@ -139,14 +136,14 @@ function psis_loo(
139136

140137
@tullio pointwise_loo[i] := weights[i, j, k] * exp(log_likelihood[i, j, k]) |> log
141138
@tullio pointwise_naive[i] := exp(log_likelihood[i, j, k] - log_count) |> log
142-
pointwise_overfit = pointwise_naive - pointwise_loo
139+
pointwise_p_eff = pointwise_naive - pointwise_loo
143140
pointwise_mcse = _calc_mcse(weights, log_likelihood, pointwise_loo, r_eff)
144141

145142

146143
pointwise = KeyedArray(
147-
hcat(pointwise_loo, pointwise_naive, pointwise_overfit, pointwise_mcse, ξ);
144+
hcat(pointwise_loo, pointwise_naive, pointwise_p_eff, pointwise_mcse, ξ);
148145
data=1:length(pointwise_loo),
149-
statistic=[:cv_est, :naive_est, :overfit, :mcse, :pareto_k],
146+
statistic=[:cv_est, :naive_est, :p_eff, :mcse, :pareto_k],
150147
)
151148

152149
table = _generate_loo_table(pointwise)
@@ -160,28 +157,28 @@ end
160157

161158

162159
function psis_loo(
163-
log_likelihood::T,
160+
log_likelihood::AbstractMatrix{<:Real},
164161
args...;
165162
chain_index::AbstractVector=ones(size(log_likelihood, 1)),
166163
kwargs...,
167-
) where {F <: Real, T <: AbstractMatrix{F}}
164+
)
168165
new_log_ratios = _convert_to_array(log_likelihood, chain_index)
169166
return psis_loo(new_log_ratios, args...; kwargs...)
170167
end
171168

172169

173-
function _generate_loo_table(pointwise::AbstractArray)
170+
function _generate_loo_table(pointwise::AbstractArray{<:Real})
174171

175172
data_size = size(pointwise, :data)
176173
# create table with the right labels
177174
table = KeyedArray(
178175
similar(NamedDims.unname(pointwise), 3, 4);
179-
criterion=[:cv_est, :naive_est, :overfit],
176+
criterion=[:cv_est, :naive_est, :p_eff],
180177
statistic=[:total, :se_total, :mean, :se_mean],
181178
)
182179

183180
# calculate the sample expectation for the total score
184-
to_sum = pointwise([:cv_est, :naive_est, :overfit])
181+
to_sum = pointwise([:cv_est, :naive_est, :p_eff])
185182
@tullio averages[crit] := to_sum[data, crit] / data_size
186183
averages = reshape(averages, 3)
187184
table(:, :mean) .= averages
@@ -197,6 +194,11 @@ function _generate_loo_table(pointwise::AbstractArray)
197194
# calculate the sample expectation for the standard error in averages
198195
table(:, :se_total) .= se_mean * data_size
199196

197+
if table(:p_eff, :total) < 0
198+
@warn "The calculated effective number of parameters is negative, which should " *
199+
"not be possible. PSIS has failed to approximate the target distribution."
200+
end
201+
200202
return table
201203
end
202204

@@ -211,4 +213,4 @@ function _calc_mcse(weights, log_likelihood, pointwise_loo, r_eff)
211213
# (google "log-normal method of moments" for a proof)
212214
# apply MCMC correlation correction:
213215
return @turbo @. sqrt(pointwise_var / r_eff)
214-
end
216+
end

src/MCMCChainsHelpers.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
using .MCMCChains
22
export pointwise_log_likelihoods
33

4+
5+
const CHAINS_ARG = """
6+
`chains::Chains`: A chain object from MCMCChains.
7+
"""
8+
9+
410
"""
511
pointwise_log_likelihoods(ll_fun::Function, chains::Chains, data)
612
713
Compute the pointwise log likelihoods.
814
915
# Arguments
1016
- $LIKELIHOOD_FUNCTION_ARG
11-
- `chain::Chains`: A chain object from MCMCChains.
12-
- `data`: An array of data points used to estimate the parameters of the model.
17+
- $CHAINS_ARG
18+
- $DATA_ARG
1319
1420
# Returns
1521
- `Array`: a three dimensional array of pointwise log-likelihoods. Dimensions are ordered
@@ -22,6 +28,7 @@ function pointwise_log_likelihoods(
2228
return pointwise_log_likelihoods(ll_fun, samples, data; kwargs...)
2329
end
2430

31+
2532
"""
2633
function psis_loo(
2734
ll_fun::Function,
@@ -37,8 +44,8 @@ score from an MCMCChains object.
3744
# Arguments
3845
3946
- $LIKELIHOOD_FUNCTION_ARG
40-
- `chain::Chain`: A chain object from MCMCChains.
41-
- `data`: A vector of data points used to estimate the parameters of the model.
47+
- $CHAINS_ARG
48+
- $DATA_ARG
4249
- $ARGS [`psis_loo`](@ref).
4350
- $KWARGS [`psis_loo`](@ref).
4451
@@ -57,8 +64,8 @@ Implements Pareto-smoothed importance sampling (PSIS) based on MCMCChain object.
5764
# Arguments
5865
5966
- $LIKELIHOOD_FUNCTION_ARG
60-
- `chain::Chain`: A chain object from MCMCChains.
61-
- `data`: A vector of data points used to estimate the parameters of the model.
67+
- $CHAINS_ARG
68+
- $DATA_ARG
6269
- $ARGS [`psis`](@ref).
6370
- $KWARGS [`psis`](@ref).
6471
@@ -67,4 +74,4 @@ See also: [`psis`](@ref), [`psis_loo`](@ref), [`PsisLoo`](@ref).
6774
function psis(ll_fun::Function, chain::Chains, data::AbstractVector, args...; kwargs...)
6875
pointwise_log_likes = pointwise_log_likelihoods(ll_fun, chain, data)
6976
return psis(-pointwise_log_likes, args...; kwargs...)
70-
end
77+
end

src/ModelComparison.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,14 @@ function loo_compare(
110110
pointwise = KeyedArray(
111111
pointwise;
112112
data=1:size(pointwise, :data),
113-
statistic=[:cv_est, :naive_est, :overfit, :mcse, :pareto_k],
113+
statistic=[:cv_est, :naive_est, :p_eff, :mcse, :pareto_k],
114114
model=model_names,
115115
)
116116

117117
# Subtract the effective number of params and elpd ests; leave mcse+pareto_k the same
118118
base_case = pointwise[data=:, statistic=1:3, model=1]
119119
@inbounds @simd for model_number in axes(pointwise, :model)
120-
@. pointwise[:, 1:3, model_number] =
121-
pointwise[:, 1:3, model_number] - base_case
120+
@. pointwise[:, 1:3, model_number] = pointwise[:, 1:3, model_number] - base_case
122121
end
123122

124123
return ModelComparison(pointwise, table)

src/NaiveLPD.jl

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using LoopVectorization
2+
using Tullio
3+
4+
5+
"""
6+
naive_lpd(log_likelihood::AbstractArray{<:Real, 3})
7+
8+
Calculate the naive (in-sample) estimate of the expected log probability density, otherwise
9+
known as the in-sample Bayes score. Not recommended for most uses.
10+
"""
11+
function naive_lpd(log_likelihood::AbstractArray{<:Real, 3})
12+
13+
dims = size(log_likelihood)
14+
data_size = dims[1]
15+
mcmc_count = dims[2] * dims[3] # total number of samples from posterior
16+
log_count = log(mcmc_count)
17+
18+
@tullio pointwise_naive[i] := exp(log_likelihood[i, j, k] - log_count) |> log
19+
20+
return sum(pointwise_naive)
21+
end

src/ParetoSmooth.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ using DocStringExtensions
44

55
function __init__()
66
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include("TuringHelpers.jl")
7-
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
8-
"MCMCChainsHelpers.jl"
9-
)
7+
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" begin
8+
include("MCMCChainsHelpers.jl")
9+
end
1010
end
1111

1212
include("AbstractCV.jl")
@@ -16,6 +16,7 @@ include("InternalHelpers.jl")
1616
include("ImportanceSampling.jl")
1717
include("LeaveOneOut.jl")
1818
include("ModelComparison.jl")
19+
include("NaiveLPD.jl")
1920
include("PublicHelpers.jl")
2021

2122
end

src/PublicHelpers.jl

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
export pointwise_log_likelihoods
22

3-
const ARRAY_DIMS_WARNING = "The supplied array of mcmc samples indicates you have more
4-
parameters than mcmc samples.This is possible, but highly unusual. Please check that your
5-
array of mcmc samples has the following dimensions: [n_samples,n_params,n_chains]."
3+
const ARRAY_DIMS_WARNING = """
4+
The supplied array of mcmc samples indicates you have more parameters than samples. This is
5+
possible, but highly unusual. Please check that your array has the following dimensions, in
6+
order: [n_samples,n_params,n_chains].
7+
"""
8+
69

710
"""
811
pointwise_log_likelihoods(
9-
ll_fun::Function,
10-
samples::AbstractArray{<:Real,3},
11-
data;
12-
splat::Bool=true
12+
ll_fun::Function, samples::AbstractArray{<:Real,3}, data;
13+
splat::Bool=true[, chain_index::Vector{<:Integer}]
1314
)
1415
15-
Compute the pointwise log likelihood.
16+
Compute the pointwise log likelihoods.
1617
1718
# Arguments
1819
- $LIKELIHOOD_FUNCTION_ARG
1920
- `samples::AbstractArray`: A three dimensional array of MCMC samples. Here, the first
20-
dimension should indicate the iteration of the MCMC ; the second dimension should
21-
indicate the parameter ; and the third dimension represents the chains.
22-
- `data`: A vector of data used to estimate the parameters of the model.
21+
dimension should indicate the step of the MCMC algorithm; the second dimension should
22+
indicate the parameter; and the third should indicate the chain.
23+
- $DATA_ARG
2324
- `splat`: If `true` (default), `f` must be a function of `n` different parameters.
24-
Otherwise, `f` is assumed to be a function of a single parameter vector.
25+
Otherwise, `f` is assumed to be a function of a single parameter vector.
26+
- $CHAIN_INDEX_DOC
2527
2628
# Returns
2729
- `Array`: A three dimensional array of pointwise log-likelihoods.

src/TuringHelpers.jl

+3-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
using .Turing
2-
export pointwise_log_likelihoods
2+
export pointwise_log_likelihoods, psis_loo, psis
33

4-
const TURING_LOOP_WARN = """
5-
**Important Note:** The posterior log-likelihood must be computed with a `for` loop inside a
6-
Turing model; broadcasting will result in all observations being treated as if they are a
7-
single point.
8-
"""
9-
const CHAINS_ARG = """
10-
`chains::Chains`: A chain object from MCMCChains.
11-
"""
124
const TURING_MODEL_ARG = """
135
`model`: A Turing model with data in the form of `model(data)`.
146
"""
@@ -17,9 +9,7 @@ const TURING_MODEL_ARG = """
179
"""
1810
pointwise_log_likelihoods(model::DynamicPPL.Model, chain::Chains)
1911
20-
Compute the pointwise log-likelihoods from a Turing model.
21-
22-
$TURING_LOOP_WARN
12+
Compute the pointwise log-likelihoods from a Turing model.
2313
2414
# Arguments
2515
- $TURING_MODEL_ARG
@@ -30,9 +20,6 @@ $TURING_LOOP_WARN
3020
indexed using `array[data, sample, chain]`.
3121
"""
3222
function pointwise_log_likelihoods(model::DynamicPPL.Model, chain::Chains)
33-
34-
@info TURING_LOOP_WARN
35-
3623
# subset of chain for mcmc samples
3724
chain_params = MCMCChains.get_sections(chain, :parameters)
3825
# compute the pointwise log likelihoods
@@ -57,9 +44,7 @@ end
5744
) -> PsisLoo
5845
5946
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
60-
score from an MCMCChain object and a Turing model.
61-
62-
$TURING_LOOP_WARN
47+
score from a `chains` object and a Turing model.
6348
6449
# Arguments
6550
@@ -71,7 +56,6 @@ $TURING_LOOP_WARN
7156
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
7257
"""
7358
function psis_loo(model::DynamicPPL.Model, chain::Chains, args...; kwargs...)
74-
@info TURING_LOOP_WARN
7559
pointwise_log_likes = pointwise_log_likelihoods(model, chain)
7660
return psis_loo(pointwise_log_likes, args...; kwargs...)
7761
end
@@ -87,8 +71,6 @@ end
8771
8872
Generate samples using Pareto smoothed importance sampling (PSIS).
8973
90-
$TURING_LOOP_WARN
91-
9274
# Arguments
9375
- $TURING_MODEL_ARG
9476
- $CHAINS_ARG
@@ -98,7 +80,6 @@ $TURING_LOOP_WARN
9880
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
9981
"""
10082
function psis(model::DynamicPPL.Model, chain::Chains, args...; kwargs...)
101-
@info TURING_LOOP_WARN
10283
log_ratios = pointwise_log_likelihoods(model, chain)
10384
return psis(-log_ratios, args...; kwargs...)
10485
end

0 commit comments

Comments
 (0)