Skip to content

Commit ab65e83

Browse files
feat: add doc example for implementing the interface
1 parent 061f24d commit ab65e83

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

Diff for: docs/Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
4+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
35
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
6+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
47

58
[compat]
69
Documenter = "1"
10+
OrdinaryDiffEqTsit5 = "1.1.0"
11+
SciMLSensitivity = "7.69"
712
SciMLStructures = "1"
13+
Zygote = "0.6.72"

Diff for: docs/make.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
66
pages = [
77
"Home" => "index.md",
88
"interface.md",
9+
"example.md",
910
"api.md"
1011
]
1112

Diff for: docs/src/example.md

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# An example implementation of the interface
2+
3+
In this tutorial we will implement the SciMLStructures.jl interface for a parameter
4+
object. This is useful when differentiating through ODE solves using SciMLSensitivity.jl
5+
and only part of the parameters are differentiable.
6+
7+
```@example
8+
using OrdinaryDiffEqTsit5
9+
using LinearAlgebra
10+
11+
mutable struct SubproblemParameters{P, Q, R}
12+
p::P # tunable
13+
q::Q
14+
r::R
15+
end
16+
17+
mutable struct Parameters{P, C}
18+
subparams::P
19+
coeffs::C # tunable matrix
20+
end
21+
22+
# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
23+
# and `du[length(subparams)+1:end] .= coeffs * u`
24+
function rhs!(du, u, p::Parameters, t)
25+
for (i, subpars) in enumerate(p.subparams)
26+
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+
end
28+
N = length(p.subparams)
29+
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
30+
return nothing
31+
end
32+
33+
u = sin.(0.1:0.1:1.0)
34+
subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35+
p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10]))
36+
tspan = (0.0, 1.0)
37+
38+
prob = ODEProblem(rhs!, u, tspan, p)
39+
solve(prob, Tsit5())
40+
```
41+
42+
The ODE solves fine. Now let's try to differentiate with respect to the tunable parameters.
43+
44+
```@example
45+
using Zygote
46+
using SciMLSensitivity
47+
48+
# 5 subparams[i].p, 50 elements in coeffs
49+
Zygote.gradient(0.1ones(55)) do tunables
50+
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51+
coeffs = reshape(tunables[6:end], size(p.coeffs))
52+
newp = Parameters(subpars, coeffs)
53+
newprob = remake(prob; p = newp)
54+
sol = solve(prob, Tsit5())
55+
return sum(sol.u[end])
56+
end
57+
```
58+
59+
SciMLSensitivity does not know how to handle the parameter object, because it does not
60+
implement the SciMLStructures interface.
61+
62+
```@example
63+
import SciMLStructures as SS
64+
65+
# Mark the struct as a SciMLStructure
66+
SS.isscimlstructure(::Parameters) = true
67+
# It is mutable
68+
SS.ismutablescimlstructure(::Parameters) = true
69+
70+
# Only contains `Tunable` portion
71+
# We could also add a `Constants` portion to contain the values that are
72+
# not tunable. The implementation would be similar to this one.
73+
SS.hasportion(::SS.Tunable, ::Parameters) = true
74+
75+
function SS.canonicalize(::SS.Tunable, p::Parameters)
76+
# concatenate all tunable values into a single vector
77+
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
78+
79+
# repack takes a new vector of the same length as `buffer`, and constructs
80+
# a new `Parameters` object using the values from the new vector for tunables
81+
# and retaining old values for other parameters. This is exactly what replace does,
82+
# so we can use that instead.
83+
repack = let p = p
84+
function repack(newbuffer)
85+
SS.replace(SS.Tunable(), p, newbuffer)
86+
end
87+
end
88+
# the canonicalized vector, the repack function, and a boolean indicating
89+
# whether the buffer aliases values in the parameter object (here, it doesn't)
90+
return buffer, repack, false
91+
end
92+
93+
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
94+
N = length(p.subparams) + length(p.coeffs)
95+
@assert length(newbuffer) == N
96+
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
97+
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
98+
return Parameters(subparams, coeffs)
99+
end
100+
101+
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
102+
N = length(p.subparams) + length(p.coeffs)
103+
@assert length(newbuffer) == N
104+
for (subpar, val) in zip(p.subparams, newbuffer)
105+
subpar.p = val
106+
end
107+
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
108+
return p
109+
end
110+
```
111+
112+
Now, we should be able to differentiate through the ODE solve.
113+
114+
```@example
115+
Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables
116+
newp = SS.replace(SS.Tunable(), p, tunables)
117+
newprob = remake(prob; p = newp)
118+
sol = solve(newprob, Tsit5())
119+
return sum(sol.u[end])
120+
end
121+
```

0 commit comments

Comments
 (0)