Skip to content

Commit 2d3c6d4

Browse files
committed
Switch from JuliaFormatter to Runic.jl for code formatting
- Update CI workflow to use fredrikekre/runic-action@v1 - Remove .JuliaFormatter.toml configuration - Format all source files with Runic.jl 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent bc016d9 commit 2d3c6d4

20 files changed

+542
-352
lines changed

.JuliaFormatter.toml

Lines changed: 0 additions & 5 deletions
This file was deleted.

.github/workflows/FormatCheck.yml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
name: Format suggestions
1+
name: format-check
22

3-
on: [pull_request]
3+
on:
4+
push:
5+
branches:
6+
- 'master'
7+
- 'main'
8+
- 'release-'
9+
tags: '*'
10+
pull_request:
411

512
jobs:
6-
code-style:
13+
runic:
714
runs-on: ubuntu-latest
815
steps:
9-
- uses: julia-actions/julia-format@v4
16+
- uses: actions/checkout@v4
17+
- uses: fredrikekre/runic-action@v1
18+
with:
19+
version: '1'

docs/make.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,19 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
99

1010
include("pages.jl")
1111

12-
makedocs(; sitename = "DiffEqFlux.jl",
12+
makedocs(;
13+
sitename = "DiffEqFlux.jl",
1314
authors = "Chris Rackauckas et al.",
1415
clean = true,
1516
doctest = false,
1617
# linkcheck = true,
1718
warnonly = [:docs_block, :missing_docs, :linkcheck],
1819
modules = [DiffEqFlux],
19-
format = Documenter.HTML(; assets = ["assets/favicon.ico"],
20-
canonical = "https://docs.sciml.ai/DiffEqFlux/stable/"),
21-
pages = pages)
20+
format = Documenter.HTML(;
21+
assets = ["assets/favicon.ico"],
22+
canonical = "https://docs.sciml.ai/DiffEqFlux/stable/"
23+
),
24+
pages = pages
25+
)
2226

2327
deploydocs(; repo = "github.com/SciML/DiffEqFlux.jl.git", push_preview = true)

ext/DiffEqFluxDataInterpolationsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using DiffEqFlux: DiffEqFlux
55

66
@views function DiffEqFlux.collocate_data(
77
data::AbstractMatrix{T}, tpoints::AbstractVector{T},
8-
tpoints_sample::AbstractVector{T}, interp, args...) where {T}
8+
tpoints_sample::AbstractVector{T}, interp, args...
9+
) where {T}
910
u = zeros(T, size(data, 1), length(tpoints_sample))
1011
du = zeros(T, size(data, 1), length(tpoints_sample))
1112
for d1 in axes(data, 1)

src/DiffEqFlux.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ using LuxLib: batched_matmul
1111
using Random: Random, AbstractRNG, randn!
1212
using Reexport: @reexport
1313
using SciMLBase: SciMLBase, DAEProblem, DDEFunction, DDEProblem, EnsembleProblem,
14-
ODEFunction, ODEProblem, ODESolution, SDEFunction, SDEProblem, remake,
15-
solve
14+
ODEFunction, ODEProblem, ODESolution, SDEFunction, SDEProblem, remake,
15+
solve
1616
using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJP,
17-
ForwardDiffOverAdjoint, ForwardDiffSensitivity, ForwardLSS,
18-
ForwardSensitivity, GaussAdjoint, InterpolatingAdjoint, NILSAS,
19-
NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP,
20-
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
21-
ZygoteVJP
17+
ForwardDiffOverAdjoint, ForwardDiffSensitivity, ForwardLSS,
18+
ForwardSensitivity, GaussAdjoint, InterpolatingAdjoint, NILSAS,
19+
NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP,
20+
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
21+
ZygoteVJP
2222
using Setfield: @set!
2323
using Static: True, False
2424

@@ -37,22 +37,22 @@ include("collocation.jl")
3737
include("multiple_shooting.jl")
3838

3939
export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer,
40-
NeuralODEMM
40+
NeuralODEMM
4141
export FFJORD, FFJORDDistribution
4242
export DimMover
4343

4444
export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel, TriweightKernel,
45-
TricubeKernel, GaussianKernel, CosineKernel, LogisticKernel, SigmoidKernel,
46-
SilvermanKernel
45+
TricubeKernel, GaussianKernel, CosineKernel, LogisticKernel, SigmoidKernel,
46+
SilvermanKernel
4747
export collocate_data
4848

4949
export multiple_shoot
5050

5151
# Reexporting only certain functions from SciMLSensitivity
5252
export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint,
53-
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, ForwardSensitivity,
54-
ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, ForwardLSS,
55-
AdjointLSS, NILSS, NILSAS
53+
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, ForwardSensitivity,
54+
ForwardDiffSensitivity, ForwardDiffOverAdjoint, SteadyStateAdjoint, ForwardLSS,
55+
AdjointLSS, NILSS, NILSAS
5656
export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP
5757

5858
# Precompilation workload - must be at the end

src/collocation.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function collocate_data(data, tpoints, kernel = TriangularKernel(), bandwidth =
7676
e2 = [_zero; _one; _zero]
7777
n = length(tpoints)
7878
bandwidth = bandwidth === nothing ?
79-
(n^(-1 / 5)) * (n^(-3 / 35)) * ((log(n))^(-1 / 16)) : bandwidth
79+
(n^(-1 / 5)) * (n^(-3 / 35)) * ((log(n))^(-1 / 16)) : bandwidth
8080

8181
Wd = similar(data, n, size(data, 1))
8282
WT1 = similar(data, n, 2)
@@ -101,8 +101,10 @@ function collocate_data(data, tpoints, kernel = TriangularKernel(), bandwidth =
101101
return estimated_derivative, estimated_solution
102102
end
103103
104-
@views function collocate_data(data::AbstractVector, tpoints::AbstractVector,
105-
tpoints_sample::AbstractVector, interp, args...)
104+
@views function collocate_data(
105+
data::AbstractVector, tpoints::AbstractVector,
106+
tpoints_sample::AbstractVector, interp, args...
107+
)
106108
du, u = collocate_data(reshape(data, 1, :), tpoints, tpoints_sample, interp, args...)
107109
return du[1, :], u[1, :]
108110
end

src/ffjord.jl

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,15 @@ preprint arXiv:1810.01367 (2018).
5959
end
6060

6161
function LuxCore.initialstates(rng::AbstractRNG, n::FFJORD)
62-
return (; model = LuxCore.initialstates(rng, n.model),
63-
regularize = false, monte_carlo = true)
62+
return (;
63+
model = LuxCore.initialstates(rng, n.model),
64+
regularize = false, monte_carlo = true,
65+
)
6466
end
6567

6668
function FFJORD(
67-
model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...)
69+
model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...
70+
)
6871
!(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model))
6972
return FFJORD(model, basedist, ad, input_dims, tspan, args, kwargs)
7073
end
@@ -75,8 +78,10 @@ end
7578

7679
@inline __norm_batched(x) = sqrt.(sum(abs2, x; dims = 1:(ndims(x) - 1)))
7780

78-
function __ffjord(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing,
79-
regularize::Bool = false, monte_carlo::Bool = true) where {T, N}
81+
function __ffjord(
82+
model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing,
83+
regularize::Bool = false, monte_carlo::Bool = true
84+
) where {T, N}
8085
L = size(u, N - 1)
8186
z = selectdim(u, N - 1, 1:(L - ifelse(regularize, 3, 1)))
8287
@set! model.ps = p
@@ -91,17 +96,23 @@ function __ffjord(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothi
9196
trace_jac = dropdims(
9297
sum(
9398
batched_matmul(
94-
reshape(e, 1, :, size(e, N)), reshape(Je, :, 1, size(Je, N)));
95-
dims = (1, 2));
96-
dims = (1, 2))
99+
reshape(e, 1, :, size(e, N)), reshape(Je, :, 1, size(Je, N))
100+
);
101+
dims = (1, 2)
102+
);
103+
dims = (1, 2)
104+
)
97105
elseif ad isa AutoZygote
98106
eJ = Lux.vector_jacobian_product(model, AutoZygote(), z, e)
99107
trace_jac = dropdims(
100108
sum(
101109
batched_matmul(
102-
reshape(eJ, 1, :, size(eJ, N)), reshape(e, :, 1, size(e, N)));
103-
dims = (1, 2));
104-
dims = (1, 2))
110+
reshape(eJ, 1, :, size(eJ, N)), reshape(e, :, 1, size(e, N))
111+
);
112+
dims = (1, 2)
113+
);
114+
dims = (1, 2)
115+
)
105116
else
106117
error("`ad` must be `nothing` or `AutoForwardDiff` or `AutoZygote`.")
107118
end
@@ -136,11 +147,14 @@ function __forward_ffjord(n::FFJORD, x::AbstractArray{T, N}, ps, st) where {T, N
136147
ffjord(u, p, t) = __ffjord(model, u, p, n.ad, regularize, monte_carlo)
137148

138149
_z = ChainRulesCore.@ignore_derivatives fill!(
139-
similar(x, S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T))
150+
similar(x, S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T)
151+
)
140152

141153
prob = ODEProblem{false}(ffjord, cat(x, _z; dims = Val(N - 1)), n.tspan, ps)
142-
sol = solve(prob, n.args...; sensealg, n.kwargs...,
143-
save_everystep = false, save_start = false, save_end = true)
154+
sol = solve(
155+
prob, n.args...; sensealg, n.kwargs...,
156+
save_everystep = false, save_start = false, save_end = true
157+
)
144158
pred = __get_pred(sol)
145159
L = size(pred, N - 1)
146160

@@ -185,11 +199,14 @@ function __backward_ffjord(::Type{T1}, n::FFJORD, n_samples::Int, ps, st, rng) w
185199
ffjord(u, p, t) = __ffjord(model, u, p, n.ad, regularize, monte_carlo)
186200

187201
_z = ChainRulesCore.@ignore_derivatives fill!(
188-
similar(x, S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T))
202+
similar(x, S[1:(N - 2)]..., ifelse(regularize, 3, 1), S[N]), zero(T)
203+
)
189204

190205
prob = ODEProblem{false}(ffjord, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps)
191-
sol = solve(prob, n.args...; sensealg, n.kwargs...,
192-
save_everystep = false, save_start = false, save_end = true)
206+
sol = solve(
207+
prob, n.args...; sensealg, n.kwargs...,
208+
save_everystep = false, save_start = false, save_end = true
209+
)
193210
pred = __get_pred(sol)
194211
L = size(pred, N - 1)
195212

@@ -222,7 +239,8 @@ function Distributions._logpdf(d::FFJORDDistribution, x::AbstractArray)
222239
return first(first(__forward_ffjord(d.model, x, d.ps, d.st)))
223240
end
224241
function Distributions._rand!(
225-
rng::AbstractRNG, d::FFJORDDistribution, x::AbstractArray{<:Real})
242+
rng::AbstractRNG, d::FFJORDDistribution, x::AbstractArray{<:Real}
243+
)
226244
copyto!(x, __backward_ffjord(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng))
227245
return x
228246
end

src/multiple_shooting.jl

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ Arguments:
3333
The parameter 'continuity_term' should be a relatively big number to enforce a large penalty
3434
whenever the last point of any group doesn't coincide with the first point of next group.
3535
"""
36-
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
36+
function multiple_shoot(
37+
p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
3738
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
38-
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
39+
group_size::Integer; continuity_term::Real = 100, kwargs...
40+
) where {F, C}
3941
datasize = size(ode_data, ndims(ode_data))
4042
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 1)
4143

@@ -47,12 +49,17 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
4749
ranges = group_ranges(datasize, group_size)
4850

4951
# Multiple shooting predictions
50-
sols = [solve(
51-
remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
52-
u0 = ode_data[griddims..., first(rg)]),
52+
sols = [
53+
solve(
54+
remake(
55+
prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
56+
u0 = ode_data[griddims..., first(rg)]
57+
),
5358
solver;
5459
saveat = tsteps[rg],
55-
kwargs...) for rg in ranges]
60+
kwargs...
61+
) for rg in ranges
62+
]
5663
group_predictions = Array.(sols)
5764

5865
# Abort and return infinite loss if one of the integrations failed
@@ -76,10 +83,14 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
7683
return loss, group_predictions
7784
end
7885

79-
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
80-
solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; kwargs...) where {F}
81-
return multiple_shoot(p, ode_data, tsteps, prob, loss_function,
82-
_default_continuity_loss, solver, group_size; kwargs...)
86+
function multiple_shoot(
87+
p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
88+
solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; kwargs...
89+
) where {F}
90+
return multiple_shoot(
91+
p, ode_data, tsteps, prob, loss_function,
92+
_default_continuity_loss, solver, group_size; kwargs...
93+
)
8394
end
8495

8596
"""
@@ -117,20 +128,22 @@ Arguments:
117128
The parameter 'continuity_term' should be a relatively big number to enforce a large penalty
118129
whenever the last point of any group doesn't coincide with the first point of next group.
119130
"""
120-
function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
131+
function multiple_shoot(
132+
p, ode_data, tsteps, ensembleprob::EnsembleProblem,
121133
ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
122134
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
123-
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
135+
group_size::Integer; continuity_term::Real = 100, kwargs...
136+
) where {F, C}
124137
ntraj = size(ode_data, ndims(ode_data))
125-
datasize = size(ode_data, ndims(ode_data)-1)
138+
datasize = size(ode_data, ndims(ode_data) - 1)
126139
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 2)
127140
prob = ensembleprob.prob
128141

129142
if group_size < 2 || group_size > datasize
130143
throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
131144
end
132145

133-
@assert ndims(ode_data)>=3 "ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
146+
@assert ndims(ode_data) >= 3 "ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
134147
@assert datasize == length(tsteps)
135148
@assert ntraj == kwargs[:trajectories]
136149

@@ -142,14 +155,16 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
142155
rg -> begin
143156
newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], tsteps[last(rg)]))
144157
function prob_func(prob, i, repeat)
145-
remake(prob; u0 = ode_data[griddims..., first(rg), i])
158+
return remake(prob; u0 = ode_data[griddims..., first(rg), i])
146159
end
147160
newensembleprob = EnsembleProblem(
148161
newprob, prob_func, ensembleprob.output_func, ensembleprob.reduction,
149-
ensembleprob.u_init, ensembleprob.safetycopy)
162+
ensembleprob.u_init, ensembleprob.safetycopy
163+
)
150164
solve(newensembleprob, solver, ensemblealg; saveat = tsteps[rg], kwargs...)
151165
end,
152-
ranges)
166+
ranges
167+
)
153168
group_predictions = Array.(sols)
154169

155170
# Abort and return infinite loss if one of the integrations did not converge?
@@ -176,12 +191,16 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
176191
return loss, group_predictions
177192
end
178193

179-
function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
194+
function multiple_shoot(
195+
p, ode_data, tsteps, ensembleprob::EnsembleProblem,
180196
ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
181197
solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer;
182-
continuity_term::Real = 100, kwargs...) where {F}
183-
return multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
184-
_default_continuity_loss, solver, group_size; continuity_term, kwargs...)
198+
continuity_term::Real = 100, kwargs...
199+
) where {F}
200+
return multiple_shoot(
201+
p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
202+
_default_continuity_loss, solver, group_size; continuity_term, kwargs...
203+
)
185204
end
186205

187206
"""
@@ -207,8 +226,12 @@ julia> group_ranges(10, 5)
207226
```
208227
"""
209228
function group_ranges(datasize::Integer, groupsize::Integer)
210-
2 groupsize datasize || throw(DomainError(groupsize,
211-
"datasize must be positive and groupsize must to be within [2, datasize]"))
229+
2 groupsize datasize || throw(
230+
DomainError(
231+
groupsize,
232+
"datasize must be positive and groupsize must to be within [2, datasize]"
233+
)
234+
)
212235
return [i:min(datasize, i + groupsize - 1) for i in 1:(groupsize - 1):(datasize - 1)]
213236
end
214237

0 commit comments

Comments
 (0)