Skip to content

Commit 70df0b5

Browse files
authored
Add support for Pluto Notebooks in documentation and add ignore_derivatives example (#2527)
1 parent aecb975 commit 70df0b5

File tree

4 files changed

+234
-9
lines changed

4 files changed

+234
-9
lines changed

docs/Project.toml

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,15 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
44
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
55
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
66
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
7+
PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
89

10+
[sources]
11+
Enzyme = {path = ".."}
12+
EnzymeCore = {path = "../lib/EnzymeCore"}
13+
EnzymeTestUtils = {path = "../lib/EnzymeTestUtils"}
14+
915
[compat]
1016
Documenter = "1"
1117
EnzymeTestUtils = "0.2"
1218
Literate = "2"
13-
14-
[sources.Enzyme]
15-
path = ".."
16-
17-
[sources.EnzymeCore]
18-
path = "../lib/EnzymeCore"
19-
20-
[sources.EnzymeTestUtils]
21-
path = "../lib/EnzymeTestUtils"

docs/make.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,31 @@ end
3030

3131
examples = [title => joinpath("generated", string(name, ".md")) for (title, name) in examples]
3232

33+
# Generate Pluto notebooks
34+
35+
using PlutoStaticHTML
36+
const NOTEBOOK_DIR = joinpath(@__DIR__, "src", "notebooks")
37+
38+
"""
39+
build()
40+
41+
Run all Pluto notebooks (".jl" files) in `NOTEBOOK_DIR`.
42+
"""
43+
function build()
44+
println("Building notebooks in $NOTEBOOK_DIR")
45+
oopts = OutputOptions(; append_build_context = false)
46+
output_format = documenter_output
47+
bopts = BuildOptions(NOTEBOOK_DIR; output_format)
48+
build_notebooks(bopts, oopts)
49+
return nothing
50+
end
51+
52+
# Build the notebooks; defaults to true.
53+
if get(ENV, "BUILD_DOCS_NOTEBOOKS", "true") == "true"
54+
build()
55+
end
56+
57+
3358
makedocs(;
3459
modules=[Enzyme, EnzymeCore, EnzymeTestUtils],
3560
authors="William Moses <wmoses@mit.edu>, Valentin Churavy <vchuravy@mit.edu>",
@@ -44,10 +69,15 @@ makedocs(;
4469
attributes=Dict(Symbol("data-domain") => "enzyme.mit.edu", :defer => "")
4570
)
4671
],
72+
mathengine = MathJax3(),
73+
size_threshold = 10_000_000
4774
),
4875
pages = [
4976
"Home" => "index.md",
5077
"Examples" => examples,
78+
"Notebooks" => [
79+
"Ignore derivatives" => "notebooks/ignore_derivatives.md",
80+
],
5181
"FAQ" => "faq.md",
5282
"API reference" => "api.md",
5383
"Advanced" => [

docs/src/notebooks/Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[deps]
2+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
5+
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
6+
PlutoLinks = "0ff47ea0-7a50-410d-8455-4348d5de0420"
7+
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
8+
9+
[sources]
10+
Enzyme = {path = "../../.."}
11+
EnzymeCore = {path = "../../../lib/EnzymeCore"}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
### A Pluto.jl notebook ###
2+
# v0.20.17
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ b72e9218-81ba-11f0-1eba-5bd949c7ade4
8+
begin
9+
import Pkg
10+
# careful: this is _not_ a reproducible environment
11+
# activate the local environment
12+
Pkg.activate(".")
13+
Pkg.instantiate()
14+
using PlutoUI, PlutoLinks
15+
end
16+
17+
# ╔═╡ 9f5c0822-a19a-4c63-95e7-d2f066a7440f
18+
@revise using Enzyme
19+
20+
# ╔═╡ a4453d23-6e31-451f-b2cd-97346accac82
21+
@revise using EnzymeCore
22+
23+
# ╔═╡ bd0352c3-1b3c-42f5-ab93-7ca4cb67b9ad
24+
begin
25+
using CairoMakie
26+
set_theme!(
27+
theme_latexfonts();
28+
fontsize = 16,
29+
Lines = (linewidth = 2,),
30+
markersize = 16
31+
)
32+
end
33+
34+
# ╔═╡ df72e42f-7eec-476f-8ce5-72b09f620005
35+
md"""
36+
# Reproducing "Stabilizing backpropagation through time to learn complex physics"
37+
38+
Fig 1 from https://openreview.net/pdf?id=bozbTTWcaw
39+
"""
40+
41+
# ╔═╡ fabba18a-b8d8-479d-babd-c18279273fb5
42+
begin
43+
N(xᵢ, θ) = θ[1] * xᵢ^2 + θ[2] * xᵢ
44+
S(xᵢ, cᵢ) = xᵢ + cᵢ
45+
end
46+
47+
# ╔═╡ 5baa757c-c611-4d8b-ac37-4f97e585613e
48+
function simulate(N, S, x₀, y, θ, n)
49+
xᵢ = x₀
50+
for i in 1:n
51+
cᵢ = N(xᵢ, θ)
52+
xᵢ = S(xᵢ, cᵢ)
53+
end
54+
return L = 1 / 2 * (xᵢ - y)^2
55+
end
56+
57+
# ╔═╡ adf9ae2c-92b6-4826-bb01-12e46f365610
58+
begin
59+
x₀ = -0.3
60+
y = 2.0
61+
n = 4
62+
end
63+
64+
# ╔═╡ bdc31f7f-fa2d-45d5-bc7a-843340ad5426
65+
begin
66+
θ₁ = -4:0.01:4
67+
θ₂ = -4:0.01:4
68+
θ_space = collect(Iterators.product(θ₁, θ₂))
69+
end;
70+
71+
# ╔═╡ 3101e04b-7cbb-4a30-851d-c6183a65c8ae
72+
L_space = simulate.(N, S, x₀, y, θ_space, n);
73+
74+
# ╔═╡ 0c0ebb20-e794-4545-b94c-e026cb7fa3e2
75+
let
76+
fig, ax, hm = heatmap(
77+
θ₁, θ₂, L_space,
78+
colorscale = log10,
79+
colormap = Makie.Reverse(:Blues),
80+
colorrange = (10^-5, 10^5)
81+
)
82+
Colorbar(fig[:, end + 1], hm)
83+
84+
fig
85+
end
86+
87+
# ╔═╡ 45ee18f4-d6d3-40f4-bbc0-04cbd3b7b840
88+
function ∇simulate(N, S, x₀, y, θ, n)
89+
= MixedDuplicated(θ, Ref(Enzyme.make_zero(θ)))
90+
Enzyme.autodiff(Enzyme.Reverse, simulate, Const(N), Const(S), Const(x₀), Const(y), dθ, Const(n))
91+
return.dval[]
92+
end
93+
94+
# ╔═╡ ae6a671d-1559-4bff-af6e-78d2b54db020
95+
function plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
96+
θ_space = collect(Iterators.product(θ₁, θ₂))
97+
gradient_field = ∇simulate.(N, S, x₀, y, θ_space, n)
98+
99+
fig, ax, hm = heatmap(
100+
θ₁, θ₂, map(x -> sqrt(x[1]^2 + x[2]^2), gradient_field),
101+
colorscale = log10,
102+
colormap = Makie.Reverse(:Blues),
103+
colorrange = (10^-3, 10^3)
104+
)
105+
Colorbar(fig[:, end + 1], hm)
106+
107+
streamplot!(
108+
ax, (θ) -> -∇simulate(N, S, x₀, y, θ, n), θ₁, θ₂,
109+
alpha = 0.5,
110+
colorscale = log10, color = p -> :red,
111+
arrow_size = 10
112+
)
113+
return fig
114+
end
115+
116+
# ╔═╡ be852753-126d-42fa-a55c-c907f5dce99d
117+
plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
118+
119+
# ╔═╡ 0b6d0456-1f94-479f-b690-89ad7fc61e44
120+
begin
121+
@noinline function ignore_derivatives(x::T) where {T}
122+
return Core.inferencebarrier(x)::T
123+
end
124+
125+
function EnzymeRules.forward(
126+
config,
127+
::Const{typeof(ignore_derivatives)},
128+
A, x::Duplicated
129+
)
130+
return Enzyme.make_zero(x.val)
131+
end
132+
133+
function EnzymeRules.augmented_primal(
134+
config,
135+
::Const{typeof(ignore_derivatives)},
136+
FA, x
137+
)
138+
primal = EnzymeRules.needs_primal(config) ? x.val : nothing
139+
if x isa Active
140+
shadow = nothing
141+
else
142+
shadow = Enzyme.make_zero(x.val)
143+
end
144+
145+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
146+
end
147+
function EnzymeRules.reverse(
148+
config,
149+
::Const{typeof(ignore_derivatives)},
150+
dret::Active, tape, x::Active
151+
)
152+
return (Enzyme.make_zero(x.val),)
153+
end
154+
155+
function EnzymeRules.reverse(
156+
config,
157+
::Const{typeof(ignore_derivatives)},
158+
::Type{<:Duplicated}, tape, x::Duplicated
159+
)
160+
return (nothing,)
161+
end
162+
end
163+
164+
# ╔═╡ 873e7792-99a1-4472-92c2-6fc32e2889fa
165+
N_stop(xᵢ, θ) = θ[1] * ignore_derivatives(xᵢ^2) + θ[2] * ignore_derivatives(xᵢ)
166+
167+
# ╔═╡ d71a22cc-c1f3-4425-8a6f-442a0bc4f215
168+
plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n)
169+
170+
# ╔═╡ Cell order:
171+
# ╠═b72e9218-81ba-11f0-1eba-5bd949c7ade4
172+
# ╠═9f5c0822-a19a-4c63-95e7-d2f066a7440f
173+
# ╠═a4453d23-6e31-451f-b2cd-97346accac82
174+
# ╠═bd0352c3-1b3c-42f5-ab93-7ca4cb67b9ad
175+
# ╟─df72e42f-7eec-476f-8ce5-72b09f620005
176+
# ╠═fabba18a-b8d8-479d-babd-c18279273fb5
177+
# ╠═5baa757c-c611-4d8b-ac37-4f97e585613e
178+
# ╠═adf9ae2c-92b6-4826-bb01-12e46f365610
179+
# ╠═bdc31f7f-fa2d-45d5-bc7a-843340ad5426
180+
# ╠═3101e04b-7cbb-4a30-851d-c6183a65c8ae
181+
# ╠═0c0ebb20-e794-4545-b94c-e026cb7fa3e2
182+
# ╠═45ee18f4-d6d3-40f4-bbc0-04cbd3b7b840
183+
# ╠═ae6a671d-1559-4bff-af6e-78d2b54db020
184+
# ╠═be852753-126d-42fa-a55c-c907f5dce99d
185+
# ╠═0b6d0456-1f94-479f-b690-89ad7fc61e44
186+
# ╠═873e7792-99a1-4472-92c2-6fc32e2889fa
187+
# ╠═d71a22cc-c1f3-4425-8a6f-442a0bc4f215

0 commit comments

Comments
 (0)