|
| 1 | +### A Pluto.jl notebook ### |
| 2 | +# v0.20.17 |
| 3 | + |
| 4 | +using Markdown |
| 5 | +using InteractiveUtils |
| 6 | + |
| 7 | +# ╔═╡ b72e9218-81ba-11f0-1eba-5bd949c7ade4 |
| 8 | +begin |
| 9 | + import Pkg |
| 10 | + # careful: this is _not_ a reproducible environment |
| 11 | + # activate the local environment |
| 12 | + Pkg.activate(".") |
| 13 | + Pkg.instantiate() |
| 14 | + using PlutoUI, PlutoLinks |
| 15 | +end |
| 16 | + |
| 17 | +# ╔═╡ 9f5c0822-a19a-4c63-95e7-d2f066a7440f |
| 18 | +@revise using Enzyme |
| 19 | + |
| 20 | +# ╔═╡ a4453d23-6e31-451f-b2cd-97346accac82 |
| 21 | +@revise using EnzymeCore |
| 22 | + |
| 23 | +# ╔═╡ bd0352c3-1b3c-42f5-ab93-7ca4cb67b9ad |
| 24 | +begin |
| 25 | + using CairoMakie |
| 26 | + set_theme!( |
| 27 | + theme_latexfonts(); |
| 28 | + fontsize = 16, |
| 29 | + Lines = (linewidth = 2,), |
| 30 | + markersize = 16 |
| 31 | + ) |
| 32 | +end |
| 33 | + |
| 34 | +# ╔═╡ df72e42f-7eec-476f-8ce5-72b09f620005 |
| 35 | +md""" |
| 36 | +# Reproducing "Stabilizing backpropagation through time to learn complex physics" |
| 37 | +
|
| 38 | +Fig 1 from https://openreview.net/pdf?id=bozbTTWcaw |
| 39 | +""" |
| 40 | + |
| 41 | +# ╔═╡ fabba18a-b8d8-479d-babd-c18279273fb5 |
| 42 | +begin |
| 43 | + N(xᵢ, θ) = θ[1] * xᵢ^2 + θ[2] * xᵢ |
| 44 | + S(xᵢ, cᵢ) = xᵢ + cᵢ |
| 45 | +end |
| 46 | + |
| 47 | +# ╔═╡ 5baa757c-c611-4d8b-ac37-4f97e585613e |
| 48 | +function simulate(N, S, x₀, y, θ, n) |
| 49 | + xᵢ = x₀ |
| 50 | + for i in 1:n |
| 51 | + cᵢ = N(xᵢ, θ) |
| 52 | + xᵢ = S(xᵢ, cᵢ) |
| 53 | + end |
| 54 | + return L = 1 / 2 * (xᵢ - y)^2 |
| 55 | +end |
| 56 | + |
| 57 | +# ╔═╡ adf9ae2c-92b6-4826-bb01-12e46f365610 |
| 58 | +begin |
| 59 | + x₀ = -0.3 |
| 60 | + y = 2.0 |
| 61 | + n = 4 |
| 62 | +end |
| 63 | + |
| 64 | +# ╔═╡ bdc31f7f-fa2d-45d5-bc7a-843340ad5426 |
| 65 | +begin |
| 66 | + θ₁ = -4:0.01:4 |
| 67 | + θ₂ = -4:0.01:4 |
| 68 | + θ_space = collect(Iterators.product(θ₁, θ₂)) |
| 69 | +end; |
| 70 | + |
| 71 | +# ╔═╡ 3101e04b-7cbb-4a30-851d-c6183a65c8ae |
| 72 | +L_space = simulate.(N, S, x₀, y, θ_space, n); |
| 73 | + |
| 74 | +# ╔═╡ 0c0ebb20-e794-4545-b94c-e026cb7fa3e2 |
| 75 | +let |
| 76 | + fig, ax, hm = heatmap( |
| 77 | + θ₁, θ₂, L_space, |
| 78 | + colorscale = log10, |
| 79 | + colormap = Makie.Reverse(:Blues), |
| 80 | + colorrange = (10^-5, 10^5) |
| 81 | + ) |
| 82 | + Colorbar(fig[:, end + 1], hm) |
| 83 | + |
| 84 | + fig |
| 85 | +end |
| 86 | + |
| 87 | +# ╔═╡ 45ee18f4-d6d3-40f4-bbc0-04cbd3b7b840 |
| 88 | +function ∇simulate(N, S, x₀, y, θ, n) |
| 89 | + dθ = MixedDuplicated(θ, Ref(Enzyme.make_zero(θ))) |
| 90 | + Enzyme.autodiff(Enzyme.Reverse, simulate, Const(N), Const(S), Const(x₀), Const(y), dθ, Const(n)) |
| 91 | + return dθ.dval[] |
| 92 | +end |
| 93 | + |
| 94 | +# ╔═╡ ae6a671d-1559-4bff-af6e-78d2b54db020 |
| 95 | +function plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n) |
| 96 | + θ_space = collect(Iterators.product(θ₁, θ₂)) |
| 97 | + gradient_field = ∇simulate.(N, S, x₀, y, θ_space, n) |
| 98 | + |
| 99 | + fig, ax, hm = heatmap( |
| 100 | + θ₁, θ₂, map(x -> sqrt(x[1]^2 + x[2]^2), gradient_field), |
| 101 | + colorscale = log10, |
| 102 | + colormap = Makie.Reverse(:Blues), |
| 103 | + colorrange = (10^-3, 10^3) |
| 104 | + ) |
| 105 | + Colorbar(fig[:, end + 1], hm) |
| 106 | + |
| 107 | + streamplot!( |
| 108 | + ax, (θ) -> -∇simulate(N, S, x₀, y, θ, n), θ₁, θ₂, |
| 109 | + alpha = 0.5, |
| 110 | + colorscale = log10, color = p -> :red, |
| 111 | + arrow_size = 10 |
| 112 | + ) |
| 113 | + return fig |
| 114 | +end |
| 115 | + |
| 116 | +# ╔═╡ be852753-126d-42fa-a55c-c907f5dce99d |
| 117 | +plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n) |
| 118 | + |
| 119 | +# ╔═╡ 0b6d0456-1f94-479f-b690-89ad7fc61e44 |
| 120 | +begin |
| 121 | + @noinline function ignore_derivatives(x::T) where {T} |
| 122 | + return Core.inferencebarrier(x)::T |
| 123 | + end |
| 124 | + |
| 125 | + function EnzymeRules.forward( |
| 126 | + config, |
| 127 | + ::Const{typeof(ignore_derivatives)}, |
| 128 | + A, x::Duplicated |
| 129 | + ) |
| 130 | + return Enzyme.make_zero(x.val) |
| 131 | + end |
| 132 | + |
| 133 | + function EnzymeRules.augmented_primal( |
| 134 | + config, |
| 135 | + ::Const{typeof(ignore_derivatives)}, |
| 136 | + FA, x |
| 137 | + ) |
| 138 | + primal = EnzymeRules.needs_primal(config) ? x.val : nothing |
| 139 | + if x isa Active |
| 140 | + shadow = nothing |
| 141 | + else |
| 142 | + shadow = Enzyme.make_zero(x.val) |
| 143 | + end |
| 144 | + |
| 145 | + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) |
| 146 | + end |
| 147 | + function EnzymeRules.reverse( |
| 148 | + config, |
| 149 | + ::Const{typeof(ignore_derivatives)}, |
| 150 | + dret::Active, tape, x::Active |
| 151 | + ) |
| 152 | + return (Enzyme.make_zero(x.val),) |
| 153 | + end |
| 154 | + |
| 155 | + function EnzymeRules.reverse( |
| 156 | + config, |
| 157 | + ::Const{typeof(ignore_derivatives)}, |
| 158 | + ::Type{<:Duplicated}, tape, x::Duplicated |
| 159 | + ) |
| 160 | + return (nothing,) |
| 161 | + end |
| 162 | +end |
| 163 | + |
| 164 | +# ╔═╡ 873e7792-99a1-4472-92c2-6fc32e2889fa |
| 165 | +N_stop(xᵢ, θ) = θ[1] * ignore_derivatives(xᵢ^2) + θ[2] * ignore_derivatives(xᵢ) |
| 166 | + |
| 167 | +# ╔═╡ d71a22cc-c1f3-4425-8a6f-442a0bc4f215 |
| 168 | +plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n) |
| 169 | + |
| 170 | +# ╔═╡ Cell order: |
| 171 | +# ╠═b72e9218-81ba-11f0-1eba-5bd949c7ade4 |
| 172 | +# ╠═9f5c0822-a19a-4c63-95e7-d2f066a7440f |
| 173 | +# ╠═a4453d23-6e31-451f-b2cd-97346accac82 |
| 174 | +# ╠═bd0352c3-1b3c-42f5-ab93-7ca4cb67b9ad |
| 175 | +# ╟─df72e42f-7eec-476f-8ce5-72b09f620005 |
| 176 | +# ╠═fabba18a-b8d8-479d-babd-c18279273fb5 |
| 177 | +# ╠═5baa757c-c611-4d8b-ac37-4f97e585613e |
| 178 | +# ╠═adf9ae2c-92b6-4826-bb01-12e46f365610 |
| 179 | +# ╠═bdc31f7f-fa2d-45d5-bc7a-843340ad5426 |
| 180 | +# ╠═3101e04b-7cbb-4a30-851d-c6183a65c8ae |
| 181 | +# ╠═0c0ebb20-e794-4545-b94c-e026cb7fa3e2 |
| 182 | +# ╠═45ee18f4-d6d3-40f4-bbc0-04cbd3b7b840 |
| 183 | +# ╠═ae6a671d-1559-4bff-af6e-78d2b54db020 |
| 184 | +# ╠═be852753-126d-42fa-a55c-c907f5dce99d |
| 185 | +# ╠═0b6d0456-1f94-479f-b690-89ad7fc61e44 |
| 186 | +# ╠═873e7792-99a1-4472-92c2-6fc32e2889fa |
| 187 | +# ╠═d71a22cc-c1f3-4425-8a6f-442a0bc4f215 |
0 commit comments