-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathproposal.jl
More file actions
106 lines (90 loc) · 2.32 KB
/
Copy pathproposal.jl
File metadata and controls
106 lines (90 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
    Proposal{P}

Abstract supertype of all proposals.

The type parameter `P` is the type of the object wrapped by a concrete proposal.
"""
abstract type Proposal{P} end

"""
    StaticProposal{P} <: Proposal{P}

Proposal that wraps a `proposal` object of type `P`.

New states proposed from a `StaticProposal` are drawn from the wrapped object and do not
depend on the current state.
"""
struct StaticProposal{P} <: Proposal{P}
    proposal::P
end

"""
    RandomWalkProposal{P} <: Proposal{P}

Proposal that wraps a `proposal` object of type `P`.

When a current state is supplied, new states proposed from a `RandomWalkProposal` are the
current state plus a draw from the wrapped object.
"""
struct RandomWalkProposal{P} <: Proposal{P}
    proposal::P
end
# Random draws
"""
    rand(p::Proposal, args...)

Draw from the proposal `p` using Julia's default random number generator.

`p` and any additional `args` are forwarded to the matching `rand(rng, p, args...)` method.
"""
function Base.rand(p::Proposal, args...)
    # `Random.default_rng()` (available since Julia 1.3) is the documented way to obtain the
    # default RNG; prefer it over referencing the `Random.GLOBAL_RNG` constant directly.
    return rand(Random.default_rng(), p, args...)
end
"""
    rand(rng::Random.AbstractRNG, p::Proposal{<:Distribution})

Draw from the distribution wrapped by `p`, using `rng` as the source of randomness.
"""
function Base.rand(rng::Random.AbstractRNG, p::Proposal{<:Distribution})
    dist = p.proposal
    return rand(rng, dist)
end
"""
    rand(rng::Random.AbstractRNG, p::Proposal{<:AbstractArray})

Draw once from every element of the array wrapped by `p`, using `rng`.

The draws are collected with `map` over the wrapped array.
"""
function Base.rand(rng::Random.AbstractRNG, p::Proposal{<:AbstractArray})
    return map(p.proposal) do component
        rand(rng, component)
    end
end
# Densities
"""
    logpdf(p::Proposal{<:Distribution}, v)

Evaluate the log density of the distribution wrapped by `p` at `v`.
"""
function Distributions.logpdf(p::Proposal{<:Distribution}, v)
    dist = p.proposal
    return logpdf(dist, v)
end
"""
    logpdf(p::Proposal{<:AbstractArray}, v)

Sum the log densities obtained by pairing each element of the array wrapped by `p` with
the corresponding element of `v`.

The pairing uses `zip`, which stops at the end of the shorter of its arguments.
"""
function Distributions.logpdf(p::Proposal{<:AbstractArray}, v)
    # The components and values are zipped into a single iterator because `mapreduce`
    # with multiple iterators requires Julia 1.2 or later.
    paired = zip(p.proposal, v)
    return mapreduce(+, paired) do (component, value)
        logpdf(component, value)
    end
end
###############
# Random Walk #
###############
"""
    propose(rng::Random.AbstractRNG, p::RandomWalkProposal, m::DensityModel)

Propose a state from the random-walk proposal `p` when no current state is given.

Without a current state there is nothing to add a step to, so the object wrapped by `p`
is re-wrapped in a `StaticProposal` and the proposal is delegated to that.
"""
function propose(rng::Random.AbstractRNG, p::RandomWalkProposal, m::DensityModel)
    return propose(rng, StaticProposal(p.proposal), m)
end

# Disambiguation: for a `RandomWalkProposal{<:Function}` the method above and the generic
# `propose(rng, ::Proposal{<:Function}, ::DensityModel)` method both apply and neither is
# more specific, so the call would fail with an ambiguity `MethodError`. Resolve it the
# same way the generic function-proposal method does: call the proposal to obtain a
# concrete proposal, then propose from that.
function propose(
    rng::Random.AbstractRNG,
    p::RandomWalkProposal{<:Function},
    m::DensityModel,
)
    return propose(rng, p(), m)
end
"""
    propose(rng, proposal::RandomWalkProposal, model::DensityModel, t)

Propose a new state by adding a draw from `proposal` (generated with `rng`) to the current
state `t`.

`model` is not used by this method.
"""
function propose(
    rng::Random.AbstractRNG,
    proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
    model::DensityModel,
    t,
)
    increment = rand(rng, proposal)
    return t + increment
end
"""
    q(proposal::RandomWalkProposal, t, t_cond)

Evaluate the log density of `proposal` at the difference `t - t_cond`.
"""
function q(
    proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
    t,
    t_cond,
)
    displacement = t - t_cond
    return logpdf(proposal, displacement)
end
##########
# Static #
##########
"""
    propose(rng, proposal::StaticProposal, model::DensityModel, t=nothing)

Propose a new state by drawing from `proposal` with `rng`.

Neither `model` nor the optional current state `t` is used by this method.
"""
function propose(
    rng::Random.AbstractRNG,
    proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
    model::DensityModel,
    t=nothing,
)
    draw = rand(rng, proposal)
    return draw
end
"""
    q(proposal::StaticProposal, t, t_cond)

Evaluate the log density of `proposal` at `t`.

`t_cond` is accepted for a uniform call signature but is not used by this method.
"""
function q(
    proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
    t,
    t_cond,
)
    log_density = logpdf(proposal, t)
    return log_density
end
############
# Function #
############
# Make proposals that wrap a function callable: calling the proposal calls the wrapped
# function (with no argument or with `t`) and wraps the result in a proposal of the same
# kind. The methods are written out per proposal struct because defining call methods on
# the abstract `Proposal` type requires Julia 1.3 or later.
(p::StaticProposal{<:Function})() = StaticProposal(p.proposal())
(p::StaticProposal{<:Function})(t) = StaticProposal(p.proposal(t))
(p::RandomWalkProposal{<:Function})() = RandomWalkProposal(p.proposal())
(p::RandomWalkProposal{<:Function})(t) = RandomWalkProposal(p.proposal(t))
"""
    propose(rng, proposal::Proposal{<:Function}, model::DensityModel)

Propose a state from a function-wrapping proposal when no current state is given.

`proposal` is called without arguments and the proposal it returns is used to propose.
"""
function propose(
    rng::Random.AbstractRNG,
    proposal::Proposal{<:Function},
    model::DensityModel,
)
    unconditional = proposal()
    return propose(rng, unconditional, model)
end
"""
    propose(rng, proposal::Proposal{<:Function}, model::DensityModel, t)

Propose a state from a function-wrapping proposal given the current state `t`.

`proposal` is called with `t`, and the proposal it returns is used to propose. The
current state is not passed on to that inner `propose` call.
"""
function propose(
    rng::Random.AbstractRNG,
    proposal::Proposal{<:Function},
    model::DensityModel,
    t,
)
    conditional = proposal(t)
    return propose(rng, conditional, model)
end
"""
    q(proposal::Proposal{<:Function}, t, t_cond)

Evaluate `q` for a function-wrapping proposal.

`proposal` is called with `t_cond`, and `q` is then evaluated on the returned proposal
with the same `t` and `t_cond`.
"""
function q(
    proposal::Proposal{<:Function},
    t,
    t_cond,
)
    conditional = proposal(t_cond)
    return q(conditional, t, t_cond)
end