Skip to content

Commit d20c022

Browse files
committed
Optimisation rework
1 parent 0477b65 commit d20c022

12 files changed

Lines changed: 986 additions & 701 deletions

File tree

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1515
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1616
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1717
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
18+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1819
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1920
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
2021
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -25,7 +26,6 @@ Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2526
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2627
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2728
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
28-
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
2929
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
3030
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
3131
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -58,20 +58,20 @@ BangBang = "0.4.2"
5858
Bijectors = "0.14, 0.15"
5959
Compat = "4.15.0"
6060
DataStructures = "0.18, 0.19"
61+
DifferentiationInterface = "0.7"
6162
Distributions = "0.25.77"
6263
DistributionsAD = "0.6"
6364
DocStringExtensions = "0.8, 0.9"
6465
DynamicHMC = "3.4"
65-
DynamicPPL = "0.39.1"
66+
DynamicPPL = "0.39.8"
6667
EllipticalSliceSampling = "0.5, 1, 2"
6768
ForwardDiff = "0.10.3, 1"
6869
Libtask = "0.9.3"
6970
LinearAlgebra = "1"
7071
LogDensityProblems = "2"
7172
MCMCChains = "5, 6, 7"
72-
NamedArrays = "0.9, 0.10"
7373
Optimization = "3, 4, 5"
74-
OptimizationOptimJL = "0.1, 0.2, 0.3, 0.4"
74+
OptimizationOptimJL = "0.1 - 0.4"
7575
OrderedCollections = "1"
7676
Printf = "1"
7777
Random = "1"

src/Turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using AdvancedVI: AdvancedVI
1111
using DynamicPPL: DynamicPPL
1212
import DynamicPPL: NoDist, NamedDist
1313
using LogDensityProblems: LogDensityProblems
14-
using NamedArrays: NamedArrays
1514
using Accessors: Accessors
1615
using StatsAPI: StatsAPI
1716
using StatsBase: StatsBase
@@ -45,6 +44,7 @@ end
4544
# Random probability measures.
4645
include("stdlib/distributions.jl")
4746
include("stdlib/RandomMeasures.jl")
47+
include("init_strategy.jl")
4848
include("mcmc/Inference.jl") # inference algorithms
4949
using .Inference
5050
include("variational/Variational.jl")

src/init_strategy.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using AbstractPPL: VarName
2+
using DynamicPPL: DynamicPPL
3+
4+
# This function is shared by both MCMC and optimisation, so has to exist outside of both.
5+
"""
6+
_convert_initial_params(initial_params)
7+
8+
Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or
9+
throw a useful error message.
10+
"""
11+
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
12+
function _convert_initial_params(nt::NamedTuple)
13+
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
14+
return DynamicPPL.InitFromParams(nt)
15+
end
16+
function _convert_initial_params(d::AbstractDict{<:VarName})
17+
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
18+
return DynamicPPL.InitFromParams(d)
19+
end
20+
function _convert_initial_params(::AbstractVector{<:Real})
21+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
22+
throw(ArgumentError(errmsg))
23+
end
24+
function _convert_initial_params(@nospecialize(_::Any))
25+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
26+
throw(ArgumentError(errmsg))
27+
end

src/mcmc/abstractmcmc.jl

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,6 @@ parameters for sampling are chosen if not specified by the user. By default, thi
1919
"""
2020
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()
2121

22-
"""
23-
_convert_initial_params(initial_params)
24-
25-
Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or
26-
throw a useful error message.
27-
"""
28-
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
29-
function _convert_initial_params(nt::NamedTuple)
30-
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
31-
return DynamicPPL.InitFromParams(nt)
32-
end
33-
function _convert_initial_params(d::AbstractDict{<:VarName})
34-
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
35-
return DynamicPPL.InitFromParams(d)
36-
end
37-
function _convert_initial_params(::AbstractVector{<:Real})
38-
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
39-
throw(ArgumentError(errmsg))
40-
end
41-
function _convert_initial_params(@nospecialize(_::Any))
42-
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
43-
throw(ArgumentError(errmsg))
44-
end
45-
4622
"""
4723
default_varinfo(rng, model, sampler)
4824
@@ -89,7 +65,7 @@ function AbstractMCMC.sample(
8965
model,
9066
spl,
9167
N;
92-
initial_params=_convert_initial_params(initial_params),
68+
initial_params=Turing._convert_initial_params(initial_params),
9369
chain_type,
9470
kwargs...,
9571
)
@@ -134,7 +110,7 @@ function AbstractMCMC.sample(
134110
n_chains;
135111
chain_type,
136112
check_model=false, # no need to check again
137-
initial_params=map(_convert_initial_params, initial_params),
113+
initial_params=map(Turing._convert_initial_params, initial_params),
138114
kwargs...,
139115
)
140116
end

src/mcmc/emcee.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function Turing.Inference.init_strategy(spl::Emcee)
3939
return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl))
4040
end
4141
# We also have to explicitly allow this or else it will error...
42-
function Turing.Inference._convert_initial_params(
42+
function Turing._convert_initial_params(
4343
x::AbstractVector{<:DynamicPPL.AbstractInitStrategy}
4444
)
4545
return x

src/mcmc/repeat_sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function AbstractMCMC.sample(
115115
model,
116116
sampler,
117117
N;
118-
initial_params=_convert_initial_params(initial_params),
118+
initial_params=Turing._convert_initial_params(initial_params),
119119
chain_type=chain_type,
120120
progress=progress,
121121
kwargs...,
@@ -143,7 +143,7 @@ function AbstractMCMC.sample(
143143
ensemble,
144144
N,
145145
n_chains;
146-
initial_params=map(_convert_initial_params, initial_params),
146+
initial_params=map(Turing._convert_initial_params, initial_params),
147147
chain_type=chain_type,
148148
progress=progress,
149149
kwargs...,

0 commit comments

Comments
 (0)