@@ -109,18 +109,15 @@ score.
109
109
- `log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
110
110
`[data, step, chain]`. The chain argument can be left off if `chain_index` is provided
111
111
or if all posterior samples were drawn from a single chain.
112
- - `args...`: Positional arguments to be passed to [`psis`](@ref).
113
- - `chain_index::Vector`: An optional vector of integers specifying which chain each
114
- step belongs to. For instance, `chain_index[3]` should return `2` if
115
- `log_likelihood[:, 3]` belongs to the second chain.
116
- - `kwargs...`: Keyword arguments to be passed to [`psis`](@ref).
112
+ - $ARGS [`psis`](@ref).
113
+ - $CHAIN_INDEX_DOC
114
+ - $KWARGS [`psis`](@ref).
117
115
118
116
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
119
117
"""
120
118
function psis_loo (
121
- log_likelihood:: T , args... ; kwargs...
122
- ) where {F <: Real , T <: AbstractArray{F, 3} }
123
-
119
+ log_likelihood:: AbstractArray{<:Real, 3} , args... ; kwargs...
120
+ )
124
121
125
122
dims = size (log_likelihood)
126
123
data_size = dims[1 ]
@@ -139,14 +136,14 @@ function psis_loo(
139
136
140
137
@tullio pointwise_loo[i] := weights[i, j, k] * exp (log_likelihood[i, j, k]) |> log
141
138
@tullio pointwise_naive[i] := exp (log_likelihood[i, j, k] - log_count) |> log
142
- pointwise_overfit = pointwise_naive - pointwise_loo
139
+ pointwise_p_eff = pointwise_naive - pointwise_loo
143
140
pointwise_mcse = _calc_mcse (weights, log_likelihood, pointwise_loo, r_eff)
144
141
145
142
146
143
pointwise = KeyedArray (
147
- hcat (pointwise_loo, pointwise_naive, pointwise_overfit , pointwise_mcse, ξ);
144
+ hcat (pointwise_loo, pointwise_naive, pointwise_p_eff , pointwise_mcse, ξ);
148
145
data= 1 : length (pointwise_loo),
149
- statistic= [:cv_est , :naive_est , :overfit , :mcse , :pareto_k ],
146
+ statistic= [:cv_est , :naive_est , :p_eff , :mcse , :pareto_k ],
150
147
)
151
148
152
149
table = _generate_loo_table (pointwise)
@@ -160,28 +157,28 @@ end
160
157
161
158
162
159
function psis_loo (
163
- log_likelihood:: T ,
160
+ log_likelihood:: AbstractMatrix{<:Real} ,
164
161
args... ;
165
162
chain_index:: AbstractVector = ones (size (log_likelihood, 1 )),
166
163
kwargs... ,
167
- ) where {F <: Real , T <: AbstractMatrix{F} }
164
+ )
168
165
new_log_ratios = _convert_to_array (log_likelihood, chain_index)
169
166
return psis_loo (new_log_ratios, args... ; kwargs... )
170
167
end
171
168
172
169
173
- function _generate_loo_table (pointwise:: AbstractArray )
170
+ function _generate_loo_table (pointwise:: AbstractArray{<:Real} )
174
171
175
172
data_size = size (pointwise, :data )
176
173
# create table with the right labels
177
174
table = KeyedArray (
178
175
similar (NamedDims. unname (pointwise), 3 , 4 );
179
- criterion= [:cv_est , :naive_est , :overfit ],
176
+ criterion= [:cv_est , :naive_est , :p_eff ],
180
177
statistic= [:total , :se_total , :mean , :se_mean ],
181
178
)
182
179
183
180
# calculate the sample expectation for the total score
184
- to_sum = pointwise ([:cv_est , :naive_est , :overfit ])
181
+ to_sum = pointwise ([:cv_est , :naive_est , :p_eff ])
185
182
@tullio averages[crit] := to_sum[data, crit] / data_size
186
183
averages = reshape (averages, 3 )
187
184
table (:, :mean ) .= averages
@@ -197,6 +194,11 @@ function _generate_loo_table(pointwise::AbstractArray)
197
194
# calculate the sample expectation for the standard error in averages
198
195
table (:, :se_total ) .= se_mean * data_size
199
196
197
+ if table (:p_eff , :total ) ≤ 0
198
+ @warn " The calculated effective number of parameters is negative, which should " *
199
+ " not be possible. PSIS has failed to approximate the target distribution."
200
+ end
201
+
200
202
return table
201
203
end
202
204
@@ -211,4 +213,4 @@ function _calc_mcse(weights, log_likelihood, pointwise_loo, r_eff)
211
213
# (google "log-normal method of moments" for a proof)
212
214
# apply MCMC correlation correction:
213
215
return @turbo @. sqrt (pointwise_var / r_eff)
214
- end
216
+ end
0 commit comments