|
1 |
| -using MCMCDiagnosticTools |
| 1 | +import MCMCDiagnosticTools |
2 | 2 |
|
3 | 3 | export relative_eff, psis_ess, sup_ess
|
4 | 4 |
|
5 | 5 | """
|
6 | 6 | relative_eff(
|
7 |
| - sample::AbstractArray{<:Real, 3}; |
8 |
| - method=MCMCDiagnosticTools.FFTESSMethod() |
| 7 | + sample::AbstractArray{<:Real, 3}; |
| 8 | + source::Union{AbstractString, Symbol} = "default", |
| 9 | + maxlag::Int = typemax(Int), |
| 10 | + kwargs..., |
9 | 11 | )
|
10 | 12 |
|
11 |
| -Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided |
| 13 | +Calculate the relative efficiency of an MCMC chain, i.e., the effective sample size divided |
12 | 14 | by the nominal sample size.
|
13 | 15 |
|
| 16 | +If `lowercase(String(source))` is `"default"` or `"mcmc"`, the relative effective sample size is computed with `MCMCDiagnosticTools.ess`, using keyword arguments `kind = :basic`, `maxlag = maxlag`, and the remaining keyword arguments `kwargs...`. |
| 17 | +Otherwise a vector of ones for each chain is returned. |
| 18 | +
|
14 | 19 | # Arguments
|
15 | 20 |
|
16 |
| - - `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values. |
| 21 | + - `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values of the shape `(parameters, draws, chains)`. |
17 | 22 | """
|
18 | 23 | function relative_eff(
|
19 | 24 | sample::AbstractArray{<:Real,3};
|
20 |
| - source::Union{AbstractString, Symbol}="default", maxlag=size(sample, 2), kwargs... |
| 25 | + source::Union{AbstractString, Symbol}="default", |
| 26 | + maxlag=typemax(Int), |
| 27 | + kwargs..., |
21 | 28 | )
|
22 |
| - if lowercase(String(source)) ∉ ["mcmc", "default"] |
23 |
| - return ones(size(sample, 1)) |
| 29 | + if lowercase(String(source)) ∉ ("mcmc", "default") |
| 30 | + # Avoid type instability by computing the return type of `ess` |
| 31 | + T = promote_type(eltype(sample), typeof(zero(eltype(sample)) / 1)) |
| 32 | + res = similar(sample, T, (axes(sample, 3),)) |
| 33 | + return fill!(res, 1) |
24 | 34 | end
|
25 |
| - |
26 |
| - dims = size(sample) |
27 |
| - post_sample_size = dims[2] * dims[3] |
28 |
| - ess_sample = permutedims(sample, [2, 1, 3]) |
29 |
| - ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=maxlag, kwargs...) |
30 |
| - return r_eff = ess / post_sample_size |
| 35 | + ess_sample = PermutedDimsArray(sample, (2, 3, 1)) |
| 36 | + return MCMCDiagnosticTools.ess(ess_sample; maxlag, kwargs..., kind=:basic, relative=true) |
31 | 37 | end
|
32 | 38 |
|
33 | 39 |
|
|
0 commit comments