Skip to content

Commit 677b7b2

Browse files
devmotionsethaxen
andauthored
Support MCMCDiagnosticTools 0.3 and MCMCChains 6 (#79)
* Support MCMCDiagnosticTools 0.3 and MCMCChains 6 * Apply suggestions from code review Co-authored-by: Seth Axen <[email protected]> * Remove Turing test dependency * Update Project.toml --------- Co-authored-by: Seth Axen <[email protected]>
1 parent 25f28fe commit 677b7b2

8 files changed

+48049
-39
lines changed

Project.toml

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParetoSmooth"
22
uuid = "a68b5a21-f429-434e-8bfa-46b447300aac"
33
authors = ["Carlos Parada <[email protected]>"]
4-
version = "0.7.5"
4+
version = "0.7.6"
55

66
[deps]
77
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
@@ -20,8 +20,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2020
AxisKeys = "0.1.18, 0.2"
2121
DynamicPPL = "0.21, 0.22"
2222
LogExpFunctions = "0.3"
23-
MCMCChains = "5"
24-
MCMCDiagnosticTools = "0.1.0"
23+
MCMCChains = "6"
24+
MCMCDiagnosticTools = "0.3.2"
2525
NamedDims = "0.2.35, 1"
2626
PrettyTables = "2.1,2.2"
2727
Requires = "1.1.3"
@@ -42,10 +42,9 @@ RData = "df47a6cb-8c03-5eed-afd8-b6050d6c41da"
4242
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
4343
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4444
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
45-
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
4645

4746
[targets]
48-
test = ["CSV", "DataFrames", "Distributions", "DynamicPPL", "MCMCChains", "RData", "StatsFuns", "Test", "TestSetExtensions", "Turing"]
47+
test = ["CSV", "DataFrames", "Distributions", "DynamicPPL", "MCMCChains", "RData", "StatsFuns", "Test", "TestSetExtensions"]
4948

5049
[weakdeps]
5150
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"

src/ESS.jl

+20-14
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,39 @@
1-
using MCMCDiagnosticTools
1+
import MCMCDiagnosticTools
22

33
export relative_eff, psis_ess, sup_ess
44

55
"""
66
relative_eff(
7-
sample::AbstractArray{<:Real, 3};
8-
method=MCMCDiagnosticTools.FFTESSMethod()
7+
sample::AbstractArray{<:Real, 3};
8+
source::Union{AbstractString, Symbol} = "default",
9+
maxlag::Int = typemax(Int),
10+
kwargs...,
911
)
1012
11-
Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided
13+
Calculate the relative efficiency of an MCMC chain, i.e., the effective sample size divided
1214
by the nominal sample size.
1315
16+
If `lowercase(String(source))` is `"default"` or `"mcmc"`, the relative effective sample size is computed with `MCMCDiagnosticTools.ess`, using keyword arguments `kind = :basic`, `maxlag = maxlag`, and the remaining keyword arguments `kwargs...`.
17+
Otherwise a vector of ones for each chain is returned.
18+
1419
# Arguments
1520
16-
- `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values.
21+
- `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values of the shape `(parameters, draws, chains)`.
1722
"""
1823
function relative_eff(
1924
sample::AbstractArray{<:Real,3};
20-
source::Union{AbstractString, Symbol}="default", maxlag=size(sample, 2), kwargs...
25+
source::Union{AbstractString, Symbol}="default",
26+
maxlag=typemax(Int),
27+
kwargs...,
2128
)
22-
if lowercase(String(source)) ["mcmc", "default"]
23-
return ones(size(sample, 1))
29+
if lowercase(String(source)) ("mcmc", "default")
30+
# Avoid type instability by computing the return type of `ess`
31+
T = promote_type(eltype(sample), typeof(zero(eltype(sample)) / 1))
32+
res = similar(sample, T, (axes(sample, 3),))
33+
return fill!(res, 1)
2434
end
25-
26-
dims = size(sample)
27-
post_sample_size = dims[2] * dims[3]
28-
ess_sample = permutedims(sample, [2, 1, 3])
29-
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=maxlag, kwargs...)
30-
return r_eff = ess / post_sample_size
35+
ess_sample = PermutedDimsArray(sample, (2, 3, 1))
36+
return MCMCDiagnosticTools.ess(ess_sample; maxlag, kwargs..., kind=:basic, relative=true)
3137
end
3238

3339

test/data/samples_m5_1t.csv

+12,001
Large diffs are not rendered by default.

test/data/samples_m5_2t.csv

+12,001
Large diffs are not rendered by default.

test/data/samples_m5_3t.csv

+12,001
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)