
Commit 8c50a75

Author: Closed-Limelike-Curves (committed)

Add geometric mean of probability calculation

1 parent 9f80cad · commit 8c50a75

File tree

3 files changed: +26 -27 lines changed

src/GPD.jl  (-1 line)

@@ -1,7 +1,6 @@
 using LinearAlgebra
 using LoopVectorization
 using Statistics
-
 using Tullio
 
 
src/LeaveOneOut.jl  (+2 -2 lines)

@@ -188,7 +188,7 @@ function loo_from_psis(
 end
 
 
-function _generate_loo_table(pointwise::KeyedArray{<:Real})
+function _generate_loo_table(pointwise::AbstractMatrix{<:Real})
 
     data_size = size(pointwise, :data)
     # create table with the right labels
@@ -200,7 +200,7 @@ function _generate_loo_table(pointwise::KeyedArray{<:Real})
 
     # calculate the sample expectation for the total score
     to_sum = pointwise([:cv_elpd, :naive_lpd, :p_eff])
-    @tullio avx=false avgs[statistic] := to_sum[data, statistic] / data_size
+    @tullio avgs[statistic] := to_sum[data, statistic] / data_size
     avgs = reshape(avgs, 3)
     table(:, :mean) .= avgs
 
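As an aside on the second hunk: the `@tullio` line simply averages each statistic's column over the data dimension. A minimal standalone sketch of the same reduction in plain Julia, using a made-up matrix in place of `to_sum` (not the package's actual data structure):

    using Statistics

    # Made-up stand-in for `to_sum`: rows index data points, columns index the
    # statistics :cv_elpd, :naive_lpd, and :p_eff.
    to_sum = randn(100, 3)
    data_size = size(to_sum, 1)

    # Same reduction as `@tullio avgs[statistic] := to_sum[data, statistic] / data_size`:
    avgs = vec(sum(to_sum; dims=1) ./ data_size)
    avgs ≈ vec(mean(to_sum; dims=1))    # true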
src/ModelComparison.jl  (+24 -24 lines)

@@ -27,19 +27,24 @@ A struct containing the results of model comparison.
 + `weight`: A set of Akaike-like weights assigned to each model, which can be used in
   pseudo-Bayesian model averaging.
 - `std_err::NamedTuple`: A named tuple containing the standard error of `cv_elpd`. Note
-  that these estimators (incorrectly) assume that all folds are independent, despite their
-  substantial overlap, which creates a severely biased estimator. In addition, note
-  that LOO-CV differences are *not* asymptotically normal. As a result, these standard
-  errors cannot be used to calculate a confidence interval. These standard errors are
-  included for consistency with R's LOO package, and should not be relied upon.
+  that these estimators (incorrectly) assume all folds are independent, despite their
+  substantial overlap, which creates a downward-biased estimator. LOO-CV differences are
+  *not* asymptotically normal, so these standard errors cannot be used to
+  calculate a confidence interval.
+- `gmp::NamedTuple`: The geometric mean of the posterior probability assigned to each data
+  point by each model. This is equal to `exp(cv_avg)`, i.e. `exp(cv_elpd / n)`, for each
+  model. Exponentiating the average score moves outcomes from the log scale back onto the
+  probability scale, making the results easier to interpret. This measure is only
+  meaningful for classifiers, i.e. models with discrete outcomes; GMP values cannot be
+  interpreted for continuous outcome variables.
 
 See also: [`PsisLoo`](@ref)
 """
-struct ModelComparison
-    pointwise::KeyedArray
-    estimates::KeyedArray
-    std_err::NamedTuple
-    differences::KeyedArray
+struct ModelComparison{RealType<:Real, N}
+    pointwise::KeyedArray{RealType, 3, <:NamedDimsArray, <:Any}
+    estimates::KeyedArray{RealType, 2, <:NamedDimsArray, <:Any}
+    std_err::NamedTuple{<:Any, Tuple{Vararg{RealType, N}}}
+    gmp::NamedTuple{<:Any, Tuple{Vararg{RealType, N}}}
 end
 
 
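To make the new `gmp` field concrete: it is the exponential of each model's average log predictive density, which moves the score from the log scale back onto the probability scale. A minimal sketch with made-up numbers (the scores and model names below are hypothetical, not from the package):

    # Hypothetical total cross-validation scores for two models over n = 100 points.
    cv_elpd = [-61.3, -68.9]
    data_size = 100
    name_tuple = (:model_a, :model_b)

    # Geometric mean probability: exp of the average log-score per data point.
    gmp = exp.(cv_elpd ./ data_size)          # ≈ [0.54, 0.50]
    gmp = NamedTuple{name_tuple}(Tuple(gmp))

    # For a classifier, gmp.model_a ≈ 0.54 reads as "model_a assigns, on (geometric)
    # average, about 54% probability to the outcome that actually occurred."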
@@ -55,10 +60,9 @@ end
 Construct a model comparison table from several [`PsisLoo`](@ref) objects.
 
 # Arguments
-
 - `cv_results`: One or more [`PsisLoo`](@ref) objects to be compared. Alternatively,
-a tuple or named tuple of `PsisLoo` objects can be passed. If a named tuple is passed,
-these names will be used to label each model.
+  a tuple or named tuple of `PsisLoo` objects can be passed. If a named tuple is passed,
+  these names will be used to label each model.
 - $LOO_COMPARE_KWARGS
 
 See also: [`ModelComparison`](@ref), [`PsisLoo`](@ref), [`psis_loo`](@ref)
@@ -109,23 +113,19 @@ function loo_compare(
 
     log_norm = logsumexp(cv_elpd)
     weights = @turbo warn_check_args=false @. exp(cv_elpd - log_norm)
-    @turbo warn_check_args=false avg_elpd = cv_elpd ./ data_size
+
+    gmp = @turbo @. exp(cv_elpd / data_size)
+    gmp = NamedTuple{name_tuple}(gmp)
+
     @turbo warn_check_args=false @. cv_elpd = cv_elpd - cv_elpd[1]
+    @turbo warn_check_args=false avg_elpd = cv_elpd ./ data_size
     total_diffs = KeyedArray(
         hcat(cv_elpd, avg_elpd, weights);
         model=model_names,
         statistic=[:cv_elpd, :cv_avg, :weight],
     )
-
-    all_diffs = @views [
-        cv_results[i].estimates(column=[:total, :mean]) for i in 1:n_models
-    ]
-    base_case = all_diffs[1]
-    all_diffs = cat(all_diffs...; dims=3)
-    @inbounds for slice in eachslice(all_diffs; dims=3)
-        @fastmath @turbo warn_check_args=false slice .-= base_case
-    end
-    return ModelComparison(pointwise_diffs, total_diffs, se_total, all_diffs)
+
+    return ModelComparison(pointwise_diffs, total_diffs, se_total, gmp)
 
 end
 
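For readers tracing the new control flow: the rewritten block derives Akaike-like weights from a log-sum-exp of the total scores, computes the geometric mean probabilities from the per-datum averages, and only then differences the scores against the first model. A rough standalone sketch of those steps in base Julia, with hypothetical inputs (the package itself uses `@turbo` and `KeyedArray`s):

    # Hypothetical total LOO scores for three models over n = 100 data points.
    cv_elpd = [-61.3, -68.9, -75.2]
    data_size = 100

    # Akaike-like weights: a numerically stable softmax of the total scores.
    m = maximum(cv_elpd)
    log_norm = m + log(sum(exp.(cv_elpd .- m)))
    weights = exp.(cv_elpd .- log_norm)

    # Geometric mean probability for each model.
    gmp = exp.(cv_elpd ./ data_size)

    # Differences relative to the first (reference) model, as stored in `total_diffs`.
    cv_elpd_diff = cv_elpd .- cv_elpd[1]
    avg_elpd_diff = cv_elpd_diff ./ data_size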