@@ -101,7 +101,6 @@ function DiffEqBase.solve(
     p = prob.p isa AbstractArray ? prob.p : Float32[]
     A = haskey(kwargs, :A) ? prob.A : nothing
     u_domain = prob.x0_sample
-    data = Iterators.repeated((), maxiters)
 
     # hidden layer
     opt = pdealg.opt
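Note: the `data = Iterators.repeated((), maxiters)` iterator removed above existed only to make `Flux.train!` take `maxiters` optimisation steps, one per empty tuple. A minimal sketch of the retired pattern, with `loss`, `ps`, and `opt` standing in for the objects defined in this function:

    # One training step per element of `data`; the empty tuples carry no
    # minibatch, so the zero-argument loss is simply evaluated maxiters times.
    data = Iterators.repeated((), maxiters)
    Flux.train!(loss, ps, data, opt)

Once the hunks below switch to an explicit `for _ in 1:maxiters` loop, the dummy iterator has no remaining caller.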
@@ -182,15 +181,18 @@ function DiffEqBase.solve(
 
     iters = eltype(x0)[]
     losses = eltype(x0)[]
-    cb = function ()
+    verbose && println("DeepBSDE")
+    for _ in 1:maxiters
+        gs = Flux.gradient(ps) do
+            loss_n_sde()
+        end
+        Flux.Optimise.update!(opt, ps, gs)
         save_everystep && push!(iters, u0(x0)[1])
         l = loss_n_sde()
         push!(losses, l)
         verbose && println("Current loss is: $l")
-        return l < pabstol && Flux.stop()
+        l < pabstol && break
     end
-    verbose && println("DeepBSDE")
-    Flux.train!(loss_n_sde, ps, data, opt; cb = cb)
 
     if !limits
         # Returning iters or simply u0(x0) and the trained neural network approximation u0
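The hunk above replaces the callback-driven `Flux.train!` with an explicit loop: take the gradient of the loss with respect to the implicit parameter set, apply one optimiser step, record diagnostics, and `break` once the loss drops below `pabstol` (previously signalled via `Flux.stop()` from the callback). A self-contained sketch of the same pattern under the implicit-parameter Flux API; the model, loss, and tolerance here are placeholders, not the solver's:

    using Flux

    model = Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 1))
    loss() = sum(abs2, model(ones(Float32, 2)))   # stand-in objective
    ps = Flux.params(model)                       # implicit parameter set
    opt = Flux.Optimise.Adam(0.01)
    for _ in 1:500                                # plays the role of maxiters
        gs = Flux.gradient(ps) do                 # gradient w.r.t. ps
            loss()
        end
        Flux.Optimise.update!(opt, ps, gs)        # one optimiser step
        loss() < 1.0f-6 && break                  # early stop, replacing Flux.stop()
    end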
@@ -264,13 +266,16 @@ function DiffEqBase.solve(
     loss_() = sum(sol_high()) / trajectories_upper
 
     ps = Flux.params(u0, σᵀ∇u...)
-    cb = function ()
+    opt_upper = Flux.Optimise.Adam(0.01)
+    for _ in 1:maxiters_limits
+        gs = Flux.gradient(ps) do
+            loss_()
+        end
+        Flux.Optimise.update!(opt_upper, ps, gs)
         l = loss_()
-        true && println("Current loss is: $l")
-        return l < 1.0e-6 && Flux.stop()
+        println("Current loss is: $l")
+        l < 1.0e-6 && break
     end
-    dataS = Iterators.repeated((), maxiters_upper)
-    Flux.train!(loss_, ps, dataS, Flux.Optimise.Adam(0.01); cb = cb)
     u_high = loss_()
 
     verbose && println("Lower limit")
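A small design point in this hunk: the upper-bound refinement constructs its own `opt_upper = Flux.Optimise.Adam(0.01)` instead of reusing the optimiser from the main training loop, so its moment estimates start fresh; previously the `Adam(0.01)` was created inline in the `Flux.train!` call and discarded afterwards.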