Skip to content

Commit ea94694

Browse files
author
Closed-Limelike-Curves
committed
Test for approx equality with R
1 parent 99c94fb commit ea94694

File tree

6 files changed

+12
-28
lines changed

6 files changed

+12
-28
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Carlos Parada <[email protected]>"]
44
version = "0.1.1"
55

66
[deps]
7+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
910
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

src/GPD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function gpdfit(
5555
# build pointwise estimates of ξ and θ at each grid point
5656
θ_hats = similar(sample, m)
5757
ξ_hats = similar(sample, m)
58-
@turbo @. θ_hats = 1 / sample[len] + (1 - sqrt((m+1) / $(1:m))) / prior / quartile
58+
@turbo @. θ_hats = 1 / sample[len] + (1 - sqrt(m / ($(1:m)-.5))) / prior / quartile
5959
@tullio threads=false ξ_hats[x] := log1p(-θ_hats[x] * sample[y]) |> _ / len
6060
log_like = similar(ξ_hats)
6161
# Calculate profile log-likelihood at each estimate:

src/ImportanceSampling.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function psis(
6464
weights::AbstractArray{F} = similar(log_ratios)
6565
# Shift ratios by maximum to prevent overflow
6666
# then shift by log of posterior sample size to avoid loss of precision for small values
67-
@tturbo @. weights = exp(log_ratios + $log(post_sample_size) - $maximum(log_ratios; dims=2))
67+
@tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
6868

6969
rel_eff = generate_rel_eff(weights, dims, rel_eff, source)
7070
check_input_validity_psis(reshape(log_ratios, dims), rel_eff)
@@ -85,10 +85,6 @@ function psis(
8585
@tturbo @. weights = log(weights)
8686
end
8787

88-
if dims[3] == 1
89-
weights = dropdims(weights; dims=3) # Reshape as array
90-
end
91-
9288
return Psis(
9389
weights,
9490
ξ,
@@ -138,13 +134,12 @@ Do PSIS on a single vector, smoothing its tail values.
138134
139135
# Arguments
140136
141-
- `is_ratios::AbstractVector{AbstractFloat}`: A vector of (not necessarily
142-
normalized) importance sampling ratios.
137+
- `is_ratios::AbstractVector{AbstractFloat}`: A vector of importance sampling ratios,
138+
scaled to have a maximum of 1.
143139
144140
# Returns
145141
146-
- `T<:AbstractFloat`: ξ, the shape parameter for the GPD; larger numbers indicate
147-
thicker tails.
142+
- `T<:AbstractFloat`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
148143
149144
# Extended help
150145
@@ -169,8 +164,8 @@ function do_psis_i!(
169164
cutoff = sorted_ratios[tail_start-1]
170165
ξ = psis_smooth_tail!(tail, cutoff)
171166

172-
# truncate at max of raw wts; because is_ratios ∝ len / maximum, max(raw weights) = len
173-
clamp!(sorted_ratios, 0, len)
167+
# truncate at max of raw wts (which is 1)
168+
clamp!(sorted_ratios, 0, 1)
174169
# unsort the ratios to their original position:
175170
is_ratios .= @views sorted_ratios[invperm(ordering)]
176171

@@ -202,7 +197,7 @@ function psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Abstrac
202197
# save time not sorting since tail is already sorted
203198
ξ, σ = GPD.gpdfit(tail)
204199
if ξ Inf
205-
@turbo @. tail = GPD.gpd_quantile($(1:len) / (len + 1), ξ, σ) + cutoff
200+
@turbo @. tail = GPD.gpd_quantile(($(1:len) - .5) / len, ξ, σ) + cutoff
206201
end
207202
return ξ
208203
end
@@ -229,7 +224,7 @@ observations).
229224
"""
230225
struct Psis{
231226
F<:AbstractFloat,
232-
AF<:AbstractArray{F},
227+
AF<:AbstractArray{F,3},
233228
VF<:AbstractVector{F},
234229
I<:Integer,
235230
VI<:AbstractVector{I},

test/Manifest.toml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,6 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
6060
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
6161
version = "0.1.1"
6262

63-
[[BenchmarkTools]]
64-
deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"]
65-
git-tree-sha1 = "01ca3823217f474243cc2c8e6e1d1f45956fe872"
66-
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
67-
version = "1.0.0"
68-
6963
[[BufferedStreams]]
7064
deps = ["Compat", "Test"]
7165
git-tree-sha1 = "5d55b9486590fdda5905c275bb21ce1f0754020f"
@@ -409,12 +403,6 @@ version = "1.1.0"
409403
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
410404
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
411405

412-
[[PkgBenchmark]]
413-
deps = ["BenchmarkTools", "Dates", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Pkg", "Printf", "TerminalLoggers", "UUIDs"]
414-
git-tree-sha1 = "e4a10b7cdb7ec836850e43a4cee196f4e7b02756"
415-
uuid = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
416-
version = "0.2.12"
417-
418406
[[PooledArrays]]
419407
deps = ["DataAPI", "Future"]
420408
git-tree-sha1 = "cde4ce9d6f33219465b55162811d8de8139c0414"

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[deps]
22
Libz = "2ec943e9-cfe8-584d-b93d-64dcb6d567b7"
33
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
4-
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
54
RData = "df47a6cb-8c03-5eed-afd8-b6050d6c41da"
65
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ logPsis = psis(logLikelihoodArray; lw=true)
2424
@test sqrt(mean((with_rel_eff.weights ./ rWeights .- 1).^2)) .001
2525
# RMSE less than .2% when using InferenceDiagnostics' ESS
2626
@test sqrt(mean((juliaPsis.weights ./ rWeights .- 1).^2)) .002
27-
@test count(juliaPsis.weights .≉ matrixPsis.weights) 1
27+
@test count(with_rel_eff.weights .≈ rWeights) 10
28+
@test count(juliaPsis.weights .≉ matrixPsis.weights) 10
2829
@test sqrt(mean((logPsis.weights .- log.(rWeights)).^2)) .001
2930
end

0 commit comments

Comments
 (0)