Skip to content

Commit f422bbc

Browse files
Support Optimisers.jl optimizers in DeepSplitting
- Add Optimisers.jl as a dependency
- Add `_copy` and `_get_eta` overloads for `Optimisers.AbstractRule`
- Add a constructor for Optimisers.jl optimizers
- Update docs to use `Flux.Optimise.Adam` explicitly
1 parent b8472d6 commit f422bbc

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
12+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -31,6 +32,7 @@ DocStringExtensions = "0.9.3"
3132
Flux = "0.14.16, 0.15, 0.16"
3233
Functors = "0.4.11, 0.5"
3334
LinearAlgebra = "1.10"
35+
Optimisers = "0.3, 0.4"
3436
Random = "1.10"
3537
Reexport = "1.2.2"
3638
SafeTestsets = "0.1"

docs/src/getting_started.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ nn = Flux.Chain(Dense(d, hls, tanh),
9999
Dense(hls, hls, tanh),
100100
Dense(hls, 1)) # neural network used by the scheme
101101
102-
opt = ADAM(1e-2)
102+
opt = Flux.Optimise.Adam(1e-2)
103103
104104
## Definition of the algorithm
105105
alg = DeepSplitting(nn,

src/DeepSplitting.jl

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,18 @@
1+
using Optimisers
2+
13
_copy(t::Tuple) = t
24
_copy(t) = t
35
function _copy(opt::O) where {O <: Flux.Optimise.AbstractOptimiser}
46
return O([_copy(getfield(opt, fn)) for fn in fieldnames(typeof(opt))]...)
57
end
8+
# Support for new-style Optimisers.jl optimizers
9+
function _copy(opt::O) where {O <: Optimisers.AbstractRule}
10+
return O([_copy(getfield(opt, fn)) for fn in fieldnames(typeof(opt))]...)
11+
end
12+
13+
# Helper to get learning rate from either optimizer type
14+
_get_eta(opt::Flux.Optimise.AbstractOptimiser) = opt.eta
15+
_get_eta(opt::Optimisers.AbstractRule) = opt.eta
616

717
"""
818
DeepSplitting(nn, K=1, opt = Flux.Optimise.Adam(0.01), λs = nothing, mc_sample = NoSampling())
@@ -38,6 +48,7 @@ struct DeepSplitting{NN, F, O, L, MCS} <: HighDimPDEAlgorithm
3848
mc_sample!::MCS # Monte Carlo sample
3949
end
4050

51+
# Constructor for old-style Flux.Optimise optimizers
4152
function DeepSplitting(
4253
nn;
4354
K = 1,
@@ -48,7 +59,22 @@ function DeepSplitting(
4859
O <: Flux.Optimise.AbstractOptimiser,
4960
L <: Union{Nothing, Vector{N}} where {N <: Number},
5061
}
51-
isnothing(λs) ? λs = [opt.eta] : nothing
62+
isnothing(λs) ? λs = [_get_eta(opt)] : nothing
63+
return DeepSplitting(nn, K, opt, λs, mc_sample)
64+
end
65+
66+
# Constructor for new-style Optimisers.jl optimizers
67+
function DeepSplitting(
68+
nn;
69+
K = 1,
70+
opt::O,
71+
λs::L = nothing,
72+
mc_sample = NoSampling()
73+
) where {
74+
O <: Optimisers.AbstractRule,
75+
L <: Union{Nothing, Vector{N}} where {N <: Number},
76+
}
77+
isnothing(λs) ? λs = [_get_eta(opt)] : nothing
5278
return DeepSplitting(nn, K, opt, λs, mc_sample)
5379
end
5480

0 commit comments

Comments (0)