@@ -33,9 +33,11 @@ Arguments:
3333 The parameter 'continuity_term' should be a relatively big number to enforce a large penalty
3434 whenever the last point of any group doesn't coincide with the first point of next group.
3535"""
36- function multiple_shoot(p, ode_data, tsteps, prob:: ODEProblem , loss_function:: F ,
36+ function multiple_shoot(
37+ p, ode_data, tsteps, prob:: ODEProblem , loss_function:: F ,
3738 continuity_loss:: C , solver:: SciMLBase.AbstractODEAlgorithm ,
38- group_size:: Integer ; continuity_term:: Real = 100 , kwargs... ) where {F, C}
39+ group_size:: Integer ; continuity_term:: Real = 100 , kwargs...
40+ ) where {F, C}
3941 datasize = size(ode_data, ndims(ode_data))
4042 griddims = ntuple(_ -> Colon(), ndims(ode_data) - 1 )
4143
@@ -47,12 +49,17 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
4749 ranges = group_ranges(datasize, group_size)
4850
4951 # Multiple shooting predictions
50- sols = [solve(
51- remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
52- u0 = ode_data[griddims... , first(rg)]),
52+ sols = [
53+ solve(
54+ remake(
55+ prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
56+ u0 = ode_data[griddims... , first(rg)]
57+ ),
5358 solver;
5459 saveat = tsteps[rg],
55- kwargs... ) for rg in ranges]
60+ kwargs...
61+ ) for rg in ranges
62+ ]
5663 group_predictions = Array.(sols)
5764
5865 # Abort and return infinite loss if one of the integrations failed
@@ -76,10 +83,14 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
7683 return loss, group_predictions
7784end
7885
79- function multiple_shoot(p, ode_data, tsteps, prob:: ODEProblem , loss_function:: F ,
80- solver:: SciMLBase.AbstractODEAlgorithm , group_size:: Integer ; kwargs... ) where {F}
81- return multiple_shoot(p, ode_data, tsteps, prob, loss_function,
82- _default_continuity_loss, solver, group_size; kwargs... )
86+ function multiple_shoot(
87+ p, ode_data, tsteps, prob:: ODEProblem , loss_function:: F ,
88+ solver:: SciMLBase.AbstractODEAlgorithm , group_size:: Integer ; kwargs...
89+ ) where {F}
90+ return multiple_shoot(
91+ p, ode_data, tsteps, prob, loss_function,
92+ _default_continuity_loss, solver, group_size; kwargs...
93+ )
8394end
8495
8596"""
@@ -117,20 +128,22 @@ Arguments:
117128 The parameter 'continuity_term' should be a relatively big number to enforce a large penalty
118129 whenever the last point of any group doesn't coincide with the first point of next group.
119130"""
120- function multiple_shoot(p, ode_data, tsteps, ensembleprob:: EnsembleProblem ,
131+ function multiple_shoot(
132+ p, ode_data, tsteps, ensembleprob:: EnsembleProblem ,
121133 ensemblealg:: SciMLBase.BasicEnsembleAlgorithm , loss_function:: F ,
122134 continuity_loss:: C , solver:: SciMLBase.AbstractODEAlgorithm ,
123- group_size:: Integer ; continuity_term:: Real = 100 , kwargs... ) where {F, C}
135+ group_size:: Integer ; continuity_term:: Real = 100 , kwargs...
136+ ) where {F, C}
124137 ntraj = size(ode_data, ndims(ode_data))
125- datasize = size(ode_data, ndims(ode_data)- 1 )
138+ datasize = size(ode_data, ndims(ode_data) - 1 )
126139 griddims = ntuple(_ -> Colon(), ndims(ode_data) - 2 )
127140 prob = ensembleprob. prob
128141
129142 if group_size < 2 || group_size > datasize
130143 throw(DomainError(group_size, " group_size can't be < 2 or > number of data points" ))
131144 end
132145
133- @assert ndims(ode_data)>= 3 " ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
146+ @assert ndims(ode_data) >= 3 " ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
134147 @assert datasize == length(tsteps)
135148 @assert ntraj == kwargs[:trajectories]
136149
@@ -142,14 +155,16 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
142155 rg -> begin
143156 newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], tsteps[last(rg)]))
144157 function prob_func(prob, i, repeat)
145- remake(prob; u0 = ode_data[griddims... , first(rg), i])
158+ return remake(prob; u0 = ode_data[griddims... , first(rg), i])
146159 end
147160 newensembleprob = EnsembleProblem(
148161 newprob, prob_func, ensembleprob. output_func, ensembleprob. reduction,
149- ensembleprob. u_init, ensembleprob. safetycopy)
162+ ensembleprob. u_init, ensembleprob. safetycopy
163+ )
150164 solve(newensembleprob, solver, ensemblealg; saveat = tsteps[rg], kwargs... )
151165 end ,
152- ranges)
166+ ranges
167+ )
153168 group_predictions = Array.(sols)
154169
155170 # Abort and return infinite loss if one of the integrations did not converge?
@@ -176,12 +191,16 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
176191 return loss, group_predictions
177192end
178193
179- function multiple_shoot(p, ode_data, tsteps, ensembleprob:: EnsembleProblem ,
194+ function multiple_shoot(
195+ p, ode_data, tsteps, ensembleprob:: EnsembleProblem ,
180196 ensemblealg:: SciMLBase.BasicEnsembleAlgorithm , loss_function:: F ,
181197 solver:: SciMLBase.AbstractODEAlgorithm , group_size:: Integer ;
182- continuity_term:: Real = 100 , kwargs... ) where {F}
183- return multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
184- _default_continuity_loss, solver, group_size; continuity_term, kwargs... )
198+ continuity_term:: Real = 100 , kwargs...
199+ ) where {F}
200+ return multiple_shoot(
201+ p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
202+ _default_continuity_loss, solver, group_size; continuity_term, kwargs...
203+ )
185204end
186205
187206"""
@@ -207,8 +226,12 @@ julia> group_ranges(10, 5)
207226```
208227"""
209228function group_ranges(datasize:: Integer , groupsize:: Integer )
210- 2 ≤ groupsize ≤ datasize || throw(DomainError(groupsize,
211- " datasize must be positive and groupsize must to be within [2, datasize]" ))
229+ 2 ≤ groupsize ≤ datasize || throw(
230+ DomainError(
231+ groupsize,
232+ " datasize must be positive and groupsize must to be within [2, datasize]"
233+ )
234+ )
212235 return [i: min(datasize, i + groupsize - 1 ) for i in 1 : (groupsize - 1 ): (datasize - 1 )]
213236end
214237
0 commit comments