Skip to content

Commit fee165b

Browse files
author
Closed-Limelike-Curves
committed
Fix FFTW compat
1 parent e5499fa commit fee165b

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

src/ESS.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ESS
22

3+
using FFTW
34
using MCMCChains
45
using LoopVectorization
56
using Tullio
@@ -16,7 +17,7 @@ function relative_eff(sample::AbstractArray{T,3}) where {T<:AbstractFloat}
1617
post_sample_size = dims[2] * dims[3]
1718
# Only need ESS, not rhat
1819
ess_sample = inv.(permutedims(sample, [2, 1, 3]))
19-
ess, = MCMCChains.ess_rhat(ess_sample)
20+
ess, = MCMCChains.ess_rhat(ess_sample; method=FFTESSMethod())
2021
rel_eff = ess / post_sample_size
2122
return rel_eff
2223
end

test/runtests.jl

+12-7
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,24 @@ let ogWeights = RData.load("test/weightMatrix.RData")["weightMatrix"]
1212
end
1313
rel_eff = RData.load("test/Rel_Eff.RData")["rel_eff"]
1414
rPsis = RData.load("test/Psis_Object.RData")["psisObject"]
15-
with_rel_eff = psis(logLikelihoodArray, rel_eff)
16-
juliaPsis = psis(logLikelihoodArray)
17-
logLikelihoodMatrix = reshape(logLikelihoodArray, 32, 1000)
18-
chainIndex = vcat(fill(1, 500), fill(2, 500))
19-
matrixPsis = psis(logLikelihoodMatrix; chain_index=chainIndex)
20-
logPsis = psis(logLikelihoodArray; lw=true)
2115

2216
@testset "ParetoSmooth.jl" begin
17+
18+
# All of these should run
19+
20+
with_rel_eff = psis(logLikelihoodArray, rel_eff)
21+
juliaPsis = psis(logLikelihoodArray)
22+
logLikelihoodMatrix = reshape(logLikelihoodArray, 32, 1000)
23+
chainIndex = vcat(fill(1, 500), fill(2, 500))
24+
matrixPsis = psis(logLikelihoodMatrix; chain_index=chainIndex)
25+
logPsis = psis(logLikelihoodArray; lw=true)
26+
27+
2328
# RMSE from R version is less than .1%
2429
@test sqrt(mean((with_rel_eff.weights ./ rWeights .- 1).^2)) .001
2530
# RMSE less than .2% when using InferenceDiagnostics' ESS
2631
@test sqrt(mean((juliaPsis.weights ./ rWeights .- 1).^2)) .002
27-
@test count(with_rel_eff.weights . rWeights) 10
32+
@test count(with_rel_eff.weights . rWeights) 10
2833
@test count(juliaPsis.weights .≉ matrixPsis.weights) 10
2934
@test sqrt(mean((logPsis.weights .- log.(rWeights)).^2)) .001
3035
end

0 commit comments

Comments
 (0)