Skip to content

Commit 8a29779

Browse files
Merge pull request #124 from ChrisRackauckas-Claude/main
Fix deprecation warnings
2 parents f806803 + 054dcbc commit 8a29779

File tree

5 files changed

+31
-25
lines changed

5 files changed

+31
-25
lines changed

src/DeepBSDE.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ function DiffEqBase.solve(
101101
p = prob.p isa AbstractArray ? prob.p : Float32[]
102102
A = haskey(kwargs, :A) ? prob.A : nothing
103103
u_domain = prob.x0_sample
104-
data = Iterators.repeated((), maxiters)
105104

106105
#hidden layer
107106
opt = pdealg.opt
@@ -182,15 +181,18 @@ function DiffEqBase.solve(
182181

183182
iters = eltype(x0)[]
184183
losses = eltype(x0)[]
185-
cb = function ()
184+
verbose && println("DeepBSDE")
185+
for _ in 1:maxiters
186+
gs = Flux.gradient(ps) do
187+
loss_n_sde()
188+
end
189+
Flux.Optimise.update!(opt, ps, gs)
186190
save_everystep && push!(iters, u0(x0)[1])
187191
l = loss_n_sde()
188192
push!(losses, l)
189193
verbose && println("Current loss is: $l")
190-
return l < pabstol && Flux.stop()
194+
l < pabstol && break
191195
end
192-
verbose && println("DeepBSDE")
193-
Flux.train!(loss_n_sde, ps, data, opt; cb = cb)
194196

195197
if !limits
196198
# Returning iters or simply u0(x0) and the trained neural network approximation u0
@@ -264,13 +266,16 @@ function DiffEqBase.solve(
264266
loss_() = sum(sol_high()) / trajectories_upper
265267

266268
ps = Flux.params(u0, σᵀ∇u...)
267-
cb = function ()
269+
opt_upper = Flux.Optimise.Adam(0.01)
270+
for _ in 1:maxiters_limits
271+
gs = Flux.gradient(ps) do
272+
loss_()
273+
end
274+
Flux.Optimise.update!(opt_upper, ps, gs)
268275
l = loss_()
269-
true && println("Current loss is: $l")
270-
return l < 1.0e-6 && Flux.stop()
276+
println("Current loss is: $l")
277+
l < 1.0e-6 && break
271278
end
272-
dataS = Iterators.repeated((), maxiters_upper)
273-
Flux.train!(loss_, ps, dataS, Flux.Optimise.Adam(0.01); cb = cb)
274279
u_high = loss_()
275280

276281
verbose && println("Lower limit")

src/DeepBSDE_Han.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ function DiffEqBase.solve(
3030
d = length(X0)
3131
g, f, μ, σ, p = prob.g, prob.f, prob.μ, prob.σ, prob.p
3232

33-
data = Iterators.repeated((), maxiters)
34-
3533
#hidden layer
3634
opt = alg.opt
3735
u0 = alg.u0
@@ -60,16 +58,18 @@ function DiffEqBase.solve(
6058
iters = eltype(X0)[]
6159
losses = eltype(X0)[]
6260

63-
callback = function ()
61+
for _ in 1:maxiters
62+
gs = Flux.gradient(ps) do
63+
loss()
64+
end
65+
Flux.Optimise.update!(opt, ps, gs)
6466
save_everystep && push!(iters, u0(X0)[1])
6567
l = loss()
6668
push!(losses, l)
6769
verbose && println("Current loss is: $l")
68-
return l < abstol && Flux.stop()
70+
l < abstol && break
6971
end
7072

71-
Flux.train!(loss, ps, data, opt; cb = callback)
72-
7373
if limits == false
7474
if save_everystep
7575
sol = PIDESolution(X0, ts, losses, iters, u0)
@@ -110,13 +110,16 @@ function DiffEqBase.solve(
110110
loss_() = sum(sol_high()) / trajectories_upper
111111

112112
ps = Flux.params(u0, σᵀ∇u...)
113-
callback = function ()
113+
opt_limits = Flux.Optimise.Adam(0.01)
114+
for _ in 1:maxiters_limits
115+
gs = Flux.gradient(ps) do
116+
loss_()
117+
end
118+
Flux.Optimise.update!(opt_limits, ps, gs)
114119
l = loss_()
115120
verbose && println("Current loss is: $l")
116-
return l < abstol && Flux.stop()
121+
l < abstol && break
117122
end
118-
dataS = Iterators.repeated((), maxiters_limits)
119-
Flux.train!(loss_, ps, dataS, ADAM(0.01); cb = callback)
120123
u_high = loss_()
121124

122125
verbose && println("Lower limit")

src/DeepSplitting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
_copy(t::Tuple) = t
22
_copy(t) = t
33
function _copy(opt::O) where {O <: Flux.Optimise.AbstractOptimiser}
4-
return O([_copy(getfield(opt, f)) for f in fieldnames(typeof(opt))]...)
4+
return O([_copy(getfield(opt, fn)) for fn in fieldnames(typeof(opt))]...)
55
end
66

77
"""

src/NNKolmogorov.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm
1313
chain::C
1414
opt::O
1515
end
16-
NNKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNKolmogorov(chain, opt)
16+
NNKolmogorov(chain; opt = Flux.Adam(0.1)) = NNKolmogorov(chain, opt)
1717

1818
"""
1919
$(TYPEDSIGNATURES)
@@ -69,7 +69,6 @@ function DiffEqBase.solve(
6969
#hidden layer
7070
chain = pdealg.chain
7171
opt = pdealg.opt
72-
ps = Flux.params(chain)
7372
xi = mapreduce(x -> rand(x, 1, trajectories), vcat, xs)
7473
#Finding Solution to the SDE having initial condition xi. Y = Phi(S(X , T))
7574
sdeproblem = SDEProblem(

src/NNParamKolmogorov.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct NNParamKolmogorov{C, O} <: HighDimPDEAlgorithm
1414
opt::O
1515
end
1616

17-
NNParamKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNParamKolmogorov(chain, opt)
17+
NNParamKolmogorov(chain; opt = Flux.Adam(0.1)) = NNParamKolmogorov(chain, opt)
1818

1919
"""
2020
$(TYPEDSIGNATURES)
@@ -75,7 +75,6 @@ function DiffEqBase.solve(
7575
dps = merge(p_defaults, dps)
7676

7777
chain = pdealg.chain
78-
ps = Flux.params(chain)
7978
opt = pdealg.opt
8079

8180
xi = mapreduce(x -> rand(x, 1, trajectories), vcat, xs)

0 commit comments

Comments (0)