From f4459453cf6532df242b34ee8a6be0c07d68c9cc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 17 Oct 2024 15:28:01 +0530 Subject: [PATCH 1/4] docs: add doc example for implementing the interface --- docs/Project.toml | 6 +++ docs/make.jl | 1 + docs/src/example.md | 121 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 docs/src/example.md diff --git a/docs/Project.toml b/docs/Project.toml index df7860c..05969e0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,7 +1,13 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "1" +OrdinaryDiffEqTsit5 = "1.1.0" +SciMLSensitivity = "7.69" SciMLStructures = "1" +Zygote = "0.6.72" diff --git a/docs/make.jl b/docs/make.jl index 94c6d25..de48378 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,6 +6,7 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true) pages = [ "Home" => "index.md", "interface.md", + "example.md", "api.md" ] diff --git a/docs/src/example.md b/docs/src/example.md new file mode 100644 index 0000000..096be0d --- /dev/null +++ b/docs/src/example.md @@ -0,0 +1,121 @@ +# An example implementation of the interface + +In this tutorial we will implement the SciMLStructures.jl interface for a parameter +object. This is useful when differentiating through ODE solves using SciMLSensitivity.jl +and only part of the parameters are differentiable. 
+ +```@example +using OrdinaryDiffEqTsit5 +using LinearAlgebra + +mutable struct SubproblemParameters{P, Q, R} + p::P # tunable + q::Q + r::R +end + +mutable struct Parameters{P, C} + subparams::P + coeffs::C # tunable matrix +end + +# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams) +# and `du[length(subparams)+1:end] .= coeffs * u` +function rhs!(du, u, p::Parameters, t) + for (i, subpars) in enumerate(p.subparams) + du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t + end + N = length(p.subparams) + mul!(view(du, (N+1):(length(du))), p.coeffs, u) + return nothing +end + +u = sin.(0.1:0.1:1.0) +subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5] +p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10])) +tspan = (0.0, 1.0) + +prob = ODEProblem(rhs!, u, tspan, p) +solve(prob, Tsit5()) +``` + +The ODE solves fine. Now let's try to differentiate with respect to the tunable parameters. + +```@example +using Zygote +using SciMLSensitivity + +# 5 subparams[i].p, 50 elements in coeffs +Zygote.gradient(0.1ones(55)) do tunables + subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape(tunables[6:end], size(p.coeffs)) + newp = Parameters(subpars, coeffs) + newprob = remake(prob; p = newp) + sol = solve(newprob, Tsit5()) + return sum(sol.u[end]) +end +``` + +SciMLSensitivity does not know how to handle the parameter object, because it does not +implement the SciMLStructures interface. + +```@example +import SciMLStructures as SS + +# Mark the struct as a SciMLStructure +SS.isscimlstructure(::Parameters) = true +# It is mutable +SS.ismutablescimlstructure(::Parameters) = true + +# Only contains `Tunable` portion +# We could also add a `Constants` portion to contain the values that are +# not tunable. The implementation would be similar to this one. 
+SS.hasportion(::SS.Tunable, ::Parameters) = true + +function SS.canonicalize(::SS.Tunable, p::Parameters) + # concatenate all tunable values into a single vector + buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) + + # repack takes a new vector of the same length as `buffer`, and constructs + # a new `Parameters` object using the values from the new vector for tunables + # and retaining old values for other parameters. This is exactly what replace does, + # so we can use that instead. + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Tunable(), p, newbuffer) + end + end + # the canonicalized vector, the repack function, and a boolean indicating + # whether the buffer aliases values in the parameter object (here, it doesn't) + return buffer, repack, false +end + +function SS.replace(::SS.Tunable, p::Parameters, newbuffer) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs)) + return Parameters(subparams, coeffs) +end + +function SS.replace!(::SS.Tunable, p::Parameters, newbuffer) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + for (subpar, val) in zip(p.subparams, newbuffer) + subpar.p = val + end + copyto!(p.coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer))) + return p +end +``` + +Now, we should be able to differentiate through the ODE solve. 
+ +```@example +Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables + newp = SS.replace(SS.Tunable(), p, tunables) + newprob = remake(prob; p = newp) + sol = solve(newprob, Tsit5()) + return sum(sol.u[end]) +end +``` From 0310dd44ea720c0ce80260b0d518f09f90c4aa68 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 14:54:38 +0530 Subject: [PATCH 2/4] docs: add documentation for `Constants` portion --- docs/src/example.md | 57 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/docs/src/example.md b/docs/src/example.md index 096be0d..dcb743c 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -57,7 +57,8 @@ end ``` SciMLSensitivity does not know how to handle the parameter object, because it does not -implement the SciMLStructures interface. +implement the SciMLStructures interface. The bare minimum necessary for SciMLSensitivity +is the `Tunable` portion. ```@example import SciMLStructures as SS @@ -119,3 +120,57 @@ Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunable return sum(sol.u[end]) end ``` + +We can also implement a `Constants` portion to store the rest of the values: + +```@example +SS.hasportion(::SS.Constants, ::Parameters) = true + +function SS.canonicalize(::SS.Constants, p::Parameters) + buffer = mapreduce(vcat, p.subparams) do subpar + [subpar.q, subpar.r] + end + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Constants(), p, newbuffer) + end + end + + return buffer, repack, false +end + +function SS.replace(::SS.Constants, p::Parameters, newbuffer) + subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)] + return Parameters(subpars, p.coeffs) +end + +function SS.replace!(::SS.Constants, p::Parameters, newbuffer) + for i in eachindex(p.subparams) + p.subparams[i].q = newbuffer[2i-1] + p.subparams[i].r = newbuffer[2i] + end + return p 
+end + +buf, repack, alias = SS.canonicalize(SS.Constants(), p) +buf +``` + +```@example +repack(ones(length(buf))) +``` + +```@example +SS.replace(SS.Constants(), p, ones(length(buf))) +``` + +```@example +SS.replace!(SS.Constants(), p, ones(length(buf))) +p +``` + +In general, all values belonging to a portion should be concatenated into an array of the +appropriate length in `canonicalize`. If a higher dimensional array is part of the portion, +it should be flattened. If a portion contains values of multiple types, a non-concrete +array should be used to store the values. `replace` and `replace!` should assume the array +they receive has the same ordering as the one returned from `canonicalize`. From c4b193b0fdf0fab41e3a4501f9db949460951fea Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 18 Nov 2024 13:14:53 -0100 Subject: [PATCH 3/4] Update example.md --- docs/src/example.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/src/example.md b/docs/src/example.md index dcb743c..0692605 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -4,7 +4,7 @@ In this tutorial we will implement the SciMLStructures.jl interface for a parame object. This is useful when differentiating through ODE solves using SciMLSensitivity.jl and only part of the parameters are differentiable. -```@example +```@example basic_tutorial using OrdinaryDiffEqTsit5 using LinearAlgebra @@ -41,7 +41,7 @@ solve(prob, Tsit5()) The ODE solves fine. Now let's try to differentiate with respect to the tunable parameters. -```@example +```@example basic_tutorial using Zygote using SciMLSensitivity @@ -60,7 +60,7 @@ SciMLSensitivity does not know how to handle the parameter object, because it do implement the SciMLStructures interface. The bare minimum necessary for SciMLSensitivity is the `Tunable` portion. 
-```@example +```@example basic_tutorial import SciMLStructures as SS # Mark the struct as a SciMLStructure @@ -112,7 +112,7 @@ end Now, we should be able to differentiate through the ODE solve. -```@example +```@example basic_tutorial Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables newp = SS.replace(SS.Tunable(), p, tunables) newprob = remake(prob; p = newp) @@ -123,7 +123,7 @@ end We can also implement a `Constants` portion to store the rest of the values: -```@example +```@example basic_tutorial SS.hasportion(::SS.Constants, ::Parameters) = true function SS.canonicalize(::SS.Constants, p::Parameters) @@ -156,15 +156,15 @@ buf, repack, alias = SS.canonicalize(SS.Constants(), p) buf ``` -```@example +```@example basic_tutorial repack(ones(length(buf))) ``` -```@example +```@example basic_tutorial SS.replace(SS.Constants(), p, ones(length(buf))) ``` -```@example +```@example basic_tutorial SS.replace!(SS.Constants(), p, ones(length(buf))) p ``` From e69ac19e4e5bd19c39fc0b2625b61ceb4f7274f7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 18 Nov 2024 16:07:27 -0100 Subject: [PATCH 4/4] Update example.md --- docs/src/example.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/src/example.md b/docs/src/example.md index 0692605..29caa2a 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -46,7 +46,7 @@ using Zygote using SciMLSensitivity # 5 subparams[i].p, 50 elements in coeffs -Zygote.gradient(0.1ones(55)) do tunables +function simulate_with_tunables(tunables) subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] coeffs = reshape(tunables[6:end], size(p.coeffs)) newp = Parameters(subpars, coeffs) @@ -113,12 +113,7 @@ end Now, we should be able to differentiate through the ODE solve. 
```@example basic_tutorial -Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables - newp = SS.replace(SS.Tunable(), p, tunables) - newprob = remake(prob; p = newp) - sol = solve(newprob, Tsit5()) - return sum(sol.u[end]) -end +Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) ``` We can also implement a `Constants` portion to store the rest of the values: