From 08334d1257ee395c63c25153f43be42803beaceb Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 3 Apr 2025 14:49:28 +0530 Subject: [PATCH 1/2] feat: accumulate gradients from pullback --- src/concrete_solve.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 5319490d3..0dc7d49de 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -415,9 +415,8 @@ function DiffEqBase._concrete_solve_adjoint( (:callback_adj, :callback))}(values(kwargs)) isq = sensealg isa QuadratureAdjoint - igs, new_u0, new_p = if _prob.f.initialization_data !== nothing - local new_u0 - local new_p + igs, new_u0, new_p, inittype = if _prob.f.initialization_data !== nothing + local new_u0, new_p iy, back = Zygote.pullback(tunables) do tunables new_prob = remake(_prob, p = repack(tunables)) new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true); @@ -425,30 +424,26 @@ function DiffEqBase._concrete_solve_adjoint( reltol = 1e-6, sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec)) new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p) - if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED - sum(new_tunables) - else - sum(new_u0) + sum(new_tunables) - end + sum(new_u0) + sum(new_tunables) end igs = back(one(iy))[1] .- one(eltype(tunables)) - igs, new_u0, new_p + back(one(iy))[1], new_u0, new_p, SciMLBase.NoInit() else - nothing, u0, p + nothing, u0, p, haskey(kwargs, :initializealg) ? kwargs[:initializealg] : SciMLBase.CheckInit() end _prob = remake(_prob, u0 = new_u0, p = new_p) if sensealg isa BacksolveAdjoint - sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, + sol = solve(_prob, alg, args...; initializealg = inittype, save_noise = true, save_start = save_start, save_end = save_end, saveat = saveat, kwargs_fwd...) elseif ischeckpointing(sensealg) - sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, + sol = solve(_prob, alg, args...; initializealg = inittype, save_noise = true, save_start = true, save_end = true, saveat = saveat, kwargs_fwd...) else - sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, save_start = true, + sol = solve(_prob, alg, args...; initializealg = inittype, save_noise = true, save_start = true, save_end = true, kwargs_fwd...) end @@ -669,6 +664,7 @@ function DiffEqBase._concrete_solve_adjoint( dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : dp isa AbstractArray ? reshape(dp', size(tunables)) : dp + dp = Zygote.accum(dp, Δ.prob.p.tunable) dp = Zygote.accum(dp, igs) _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || @@ -1733,6 +1729,8 @@ function DiffEqBase._concrete_solve_adjoint( @. _out[_save_idxs] = Δ elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray @. _out[_save_idxs] = Δ[_save_idxs] + elseif isnothing(_out) + _out else @. _out[_save_idxs] = Δ.u[_save_idxs] end From cec46e5374dbbe13b1757f52b2781b4459aa0828 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 3 Apr 2025 16:56:31 +0530 Subject: [PATCH 2/2] chore: rm literal_getproperty dispatch --- src/adjoint_common.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index c10d842eb..2b6d1528e 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -749,16 +749,16 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) return out, ts end -if !hasmethod(Zygote.adjoint, - Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), - SciMLBase.AbstractTimeseriesSolution, Val{:u}}) - Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) - end - sol.u, solu_adjoint - end -end +# if !hasmethod(Zygote.adjoint, +# Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), +# SciMLBase.AbstractTimeseriesSolution, Val{:u}}) +# Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, +# ::Val{:u}) +# function solu_adjoint(Δ) +# zerou = zero(sol.prob.u0) +# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) +# (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) +# end +# sol.u, solu_adjoint +# end +# end