Skip to content

Commit 142f753

Browse files
torfjeldegithub-actions[bot]sunxd3
authored
Allowing using NamedTuple as initial_params (#632)
* Initial work on `NamedTuple` as `initial_params` * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add some tests * Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix type error for inits with `nothing` * Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * use better variable names * remove init with scalar * move `update_values!!` out of `TestUtils` * fix error --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Xianda Sun <[email protected]>
1 parent fa11c95 commit 142f753

File tree

5 files changed

+116
-88
lines changed

5 files changed

+116
-88
lines changed

src/sampler.jl

+38-22
Original file line numberDiff line numberDiff line change
@@ -142,38 +142,54 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
142142
"""
143143
initialsampler(spl::Sampler) = SampleFromPrior()
144144

145-
function initialize_parameters!!(
146-
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
145+
function set_values!!(
146+
varinfo::AbstractVarInfo,
147+
initial_params::AbstractVector{<:Union{Real,Missing}},
148+
spl::AbstractSampler,
147149
)
148-
@debug "Using passed-in initial variable values" initial_params
149-
150-
# Flatten parameters.
151-
init_theta = mapreduce(vcat, initial_params) do x
152-
vec([x;])
153-
end
154-
155-
# Get all values.
156-
linked = islinked(vi, spl)
157-
if linked
158-
vi = invlink!!(vi, spl, model)
159-
end
160-
theta = vi[spl]
161-
length(theta) == length(init_theta) || throw(
150+
flattened_param_vals = varinfo[spl]
151+
length(flattened_param_vals) == length(initial_params) || throw(
162152
DimensionMismatch(
163-
"Provided initial value size ($(length(init_theta))) doesn't match the model size ($(length(theta)))",
153+
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))",
164154
),
165155
)
166156

167157
# Update values that are provided.
168-
for i in eachindex(init_theta)
169-
x = init_theta[i]
158+
for i in eachindex(initial_params)
159+
x = initial_params[i]
170160
if x !== missing
171-
theta[i] = x
161+
flattened_param_vals[i] = x
172162
end
173163
end
174164

175-
# Update in `vi`.
176-
vi = setindex!!(vi, theta, spl)
165+
# Update in `varinfo`.
166+
return setindex!!(varinfo, flattened_param_vals, spl)
167+
end
168+
169+
function set_values!!(
170+
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
171+
)
172+
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
173+
return update_values!!(
174+
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
175+
)
176+
end
177+
178+
function initialize_parameters!!(
179+
vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model
180+
)
181+
@debug "Using passed-in initial variable values" initial_params
182+
183+
# `link` the varinfo if needed.
184+
linked = islinked(vi, spl)
185+
if linked
186+
vi = invlink!!(vi, spl, model)
187+
end
188+
189+
# Set the values in `vi`.
190+
vi = set_values!!(vi, initial_params, spl)
191+
192+
# `invlink` if needed.
177193
if linked
178194
vi = link!!(vi, spl, model)
179195
end

src/test_utils.jl

+1-13
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,7 @@ using Bijectors: Bijectors
1111
using Accessors: Accessors
1212

1313
# For backwards compat.
14-
using DynamicPPL: varname_leaves
15-
16-
"""
17-
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
18-
19-
Return instance similar to `vi` but with `vns` set to values from `vals`.
20-
"""
21-
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
22-
for vn in vns
23-
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
24-
end
25-
return vi
26-
end
14+
using DynamicPPL: varname_leaves, update_values!!
2715

2816
"""
2917
test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)

src/utils.jl

+12
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,18 @@ function nested_getindex(values::AbstractDict, vn::VarName)
796796
return child(value)
797797
end
798798

799+
"""
800+
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
801+
802+
Return instance similar to `vi` but with `vns` set to values from `vals`.
803+
"""
804+
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
805+
for vn in vns
806+
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
807+
end
808+
return vi
809+
end
810+
799811
"""
800812
float_type_with_fallback(x)
801813

src/varinfo.jl

+6
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,12 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[]
892892
return expr
893893
end
894894

895+
# FIXME(torfjelde): Don't use `_getvns`.
896+
Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl)
897+
function Base.keys(vi::TypedVarInfo, spl::AbstractSampler)
898+
return mapreduce(values, vcat, _getvns(vi, spl))
899+
end
900+
895901
"""
896902
setgid!(vi::VarInfo, gid::Selector, vn::VarName)
897903

test/sampler.jl

+59-53
Original file line numberDiff line numberDiff line change
@@ -84,23 +84,25 @@
8484
model = coinflip()
8585
sampler = Sampler(alg)
8686
lptrue = logpdf(Binomial(25, 0.2), 10)
87-
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
88-
@test chain[1].metadata.p.vals == [0.2]
89-
@test getlogp(chain[1]) == lptrue
90-
91-
# parallel sampling
92-
chains = sample(
93-
model,
94-
sampler,
95-
MCMCThreads(),
96-
1,
97-
10;
98-
initial_params=fill(0.2, 10),
99-
progress=false,
100-
)
101-
for c in chains
102-
@test c[1].metadata.p.vals == [0.2]
103-
@test getlogp(c[1]) == lptrue
87+
let inits = (; p=0.2)
88+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
89+
@test chain[1].metadata.p.vals == [0.2]
90+
@test getlogp(chain[1]) == lptrue
91+
92+
# parallel sampling
93+
chains = sample(
94+
model,
95+
sampler,
96+
MCMCThreads(),
97+
1,
98+
10;
99+
initial_params=fill(inits, 10),
100+
progress=false,
101+
)
102+
for c in chains
103+
@test c[1].metadata.p.vals == [0.2]
104+
@test getlogp(c[1]) == lptrue
105+
end
104106
end
105107

106108
# model with two variables: initialization s = 4, m = -1
@@ -110,45 +112,49 @@
110112
end
111113
model = twovars()
112114
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
113-
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
114-
@test chain[1].metadata.s.vals == [4]
115-
@test chain[1].metadata.m.vals == [-1]
116-
@test getlogp(chain[1]) == lptrue
117-
118-
# parallel sampling
119-
chains = sample(
120-
model,
121-
sampler,
122-
MCMCThreads(),
123-
1,
124-
10;
125-
initial_params=fill([4, -1], 10),
126-
progress=false,
127-
)
128-
for c in chains
129-
@test c[1].metadata.s.vals == [4]
130-
@test c[1].metadata.m.vals == [-1]
131-
@test getlogp(c[1]) == lptrue
115+
for inits in ([4, -1], (; s=4, m=-1))
116+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
117+
@test chain[1].metadata.s.vals == [4]
118+
@test chain[1].metadata.m.vals == [-1]
119+
@test getlogp(chain[1]) == lptrue
120+
121+
# parallel sampling
122+
chains = sample(
123+
model,
124+
sampler,
125+
MCMCThreads(),
126+
1,
127+
10;
128+
initial_params=fill(inits, 10),
129+
progress=false,
130+
)
131+
for c in chains
132+
@test c[1].metadata.s.vals == [4]
133+
@test c[1].metadata.m.vals == [-1]
134+
@test getlogp(c[1]) == lptrue
135+
end
132136
end
133137

134138
# set only m = -1
135-
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
136-
@test !ismissing(chain[1].metadata.s.vals[1])
137-
@test chain[1].metadata.m.vals == [-1]
138-
139-
# parallel sampling
140-
chains = sample(
141-
model,
142-
sampler,
143-
MCMCThreads(),
144-
1,
145-
10;
146-
initial_params=fill([missing, -1], 10),
147-
progress=false,
148-
)
149-
for c in chains
150-
@test !ismissing(c[1].metadata.s.vals[1])
151-
@test c[1].metadata.m.vals == [-1]
139+
for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1))
140+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
141+
@test !ismissing(chain[1].metadata.s.vals[1])
142+
@test chain[1].metadata.m.vals == [-1]
143+
144+
# parallel sampling
145+
chains = sample(
146+
model,
147+
sampler,
148+
MCMCThreads(),
149+
1,
150+
10;
151+
initial_params=fill(inits, 10),
152+
progress=false,
153+
)
154+
for c in chains
155+
@test !ismissing(c[1].metadata.s.vals[1])
156+
@test c[1].metadata.m.vals == [-1]
157+
end
152158
end
153159

154160
# specify `initial_params=nothing`

0 commit comments

Comments
 (0)