@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)
46
46
47
47
vjp = try
48
48
f = unwrapped_f (prob. f)
49
+ tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob. tspan[1 ]]
49
50
if p === nothing || p isa SciMLBase. NullParameters
50
- ReverseDiff. GradientTape ((copy (u0), [prob . tspan[ 1 ]] )) do u, t
51
+ ReverseDiff. GradientTape ((copy (u0), tspan_ )) do u, t
51
52
du1 = similar (u, size (u))
52
53
du1 .= 0
53
54
f (du1, u, p, first (t))
54
55
return vec (du1)
55
56
end
56
57
else
57
- ReverseDiff. GradientTape ((copy (u0), p, [prob . tspan[ 1 ]] )) do u, p, t
58
+ ReverseDiff. GradientTape ((copy (u0), p, tspan_ )) do u, p, t
58
59
du1 = similar (u, size (u))
59
60
du1 .= 0
60
61
f (du1, u, repack (p), first (t))
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
299
300
tunables, repack = Functors. functor (p)
300
301
end
301
302
303
+ u0 = state_values (prob) === nothing ? Float64[] : u0
302
304
default_sensealg = automatic_sensealg_choice (prob, u0, tunables, verbose, repack)
303
305
DiffEqBase. _concrete_solve_adjoint (prob, alg, default_sensealg, u0, p,
304
306
originator:: SciMLBase.ADOriginator , args... ; verbose,
@@ -371,6 +373,7 @@ function DiffEqBase._concrete_solve_adjoint(
371
373
args... ; save_start = true , save_end = true ,
372
374
saveat = eltype (prob. tspan)[],
373
375
save_idxs = nothing ,
376
+ initializealg_default = SciMLBase. OverrideInit (; abstol = 1e-6 , reltol = 1e-3 ),
374
377
kwargs... )
375
378
if ! (sensealg isa GaussAdjoint) &&
376
379
! (p isa Union{Nothing, SciMLBase. NullParameters, AbstractArray}) ||
@@ -412,16 +415,61 @@ function DiffEqBase._concrete_solve_adjoint(
412
415
Base. diff_names (Base. _nt_names (values (kwargs)),
413
416
(:callback_adj , :callback ))}(values (kwargs))
414
417
isq = sensealg isa QuadratureAdjoint
418
+ kwargs_init = kwargs_adj[Base. diff_names (Base. _nt_names (kwargs_adj), (:initializealg ,))]
419
+
420
+ if haskey (kwargs, :initializealg ) || haskey (prob. kwargs, :initializealg )
421
+ initializealg = haskey (kwargs, :initializealg ) ? kwargs[:initializealg ] : prob. kwargs[:initializealg ]
422
+ else
423
+ initializealg = DefaultInit ()
424
+ end
425
+
426
+ default_inits = Union{OverrideInit, Nothing, DefaultInit}
427
+ igs, new_u0, new_p, new_initializealg = if (SciMLBase. has_initialization_data (_prob. f) && initializealg isa default_inits)
428
+ local new_u0
429
+ local new_p
430
+ initializeprob = prob. f. initialization_data. initializeprob
431
+ iu0 = state_values (initializeprob)
432
+ isAD = if iu0 === nothing
433
+ AutoForwardDiff
434
+ elseif has_autodiff (alg)
435
+ OrdinaryDiffEqCore. alg_autodiff (alg) isa AutoForwardDiff
436
+ else
437
+ true
438
+ end
439
+ nlsolve_alg = default_nlsolve (nothing , Val (isinplace (_prob)), iu0, initializeprob, isAD)
440
+ initializealg = initializealg isa Union{Nothing, DefaultInit} ? initializealg_default : initializealg
441
+
442
+ iy, back = Zygote. pullback (tunables) do tunables
443
+ new_prob = remake (_prob, p = repack (tunables))
444
+ new_u0, new_p, _ = SciMLBase. get_initial_values (new_prob, new_prob, new_prob. f, initializealg, Val (isinplace (new_prob));
445
+ sensealg = SteadyStateAdjoint (autojacvec = sensealg. autojacvec),
446
+ nlsolve_alg,
447
+ kwargs_init... )
448
+ new_tunables, _, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), new_p)
449
+ if SciMLBase. initialization_status (_prob) == SciMLBase. OVERDETERMINED
450
+ sum (new_tunables)
451
+ else
452
+ sum (new_u0) + sum (new_tunables)
453
+ end
454
+ end
455
+ igs = back (one (iy))[1 ] .- one (eltype (tunables))
456
+
457
+ igs, new_u0, new_p, SciMLBase. NoInit ()
458
+ else
459
+ nothing , u0, p, initializealg
460
+ end
461
+ _prob = remake (_prob, u0 = new_u0, p = new_p)
462
+
415
463
if sensealg isa BacksolveAdjoint
416
- sol = solve (_prob, alg, args... ; save_noise = true ,
464
+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true ,
417
465
save_start = save_start, save_end = save_end,
418
466
saveat = saveat, kwargs_fwd... )
419
467
elseif ischeckpointing (sensealg)
420
- sol = solve (_prob, alg, args... ; save_noise = true ,
468
+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true ,
421
469
save_start = true , save_end = true ,
422
470
saveat = saveat, kwargs_fwd... )
423
471
else
424
- sol = solve (_prob, alg, args... ; save_noise = true , save_start = true ,
472
+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true , save_start = true ,
425
473
save_end = true , kwargs_fwd... )
426
474
end
427
475
@@ -491,6 +539,7 @@ function DiffEqBase._concrete_solve_adjoint(
491
539
_save_idxs = save_idxs === nothing ? Colon () : save_idxs
492
540
493
541
function adjoint_sensitivity_backpass (Δ)
542
+ Δ = Δ isa AbstractThunk ? unthunk (Δ) : Δ
494
543
function df_iip (_out, u, p, t, i)
495
544
outtype = _out isa SubArray ?
496
545
ArrayInterface. parameterless_type (_out. parent) :
@@ -628,20 +677,22 @@ function DiffEqBase._concrete_solve_adjoint(
628
677
dgdu_discrete = df_iip,
629
678
sensealg = sensealg,
630
679
callback = cb2,
631
- kwargs_adj ... )
680
+ kwargs_init ... )
632
681
else
633
682
du0, dp = adjoint_sensitivities (sol, alg, args... ; t = ts,
634
683
dgdu_discrete = df_oop,
635
684
sensealg = sensealg,
636
685
callback = cb2,
637
- kwargs_adj ... )
686
+ kwargs_init ... )
638
687
end
639
688
640
689
du0 = reshape (du0, size (u0))
641
690
642
691
dp = p === nothing || p === DiffEqBase. NullParameters () ? nothing :
643
692
dp isa AbstractArray ? reshape (dp' , size (tunables)) : dp
644
693
694
+ dp = Zygote. accum (dp, igs)
695
+
645
696
_, repack_adjoint = if p === nothing || p === DiffEqBase. NullParameters () ||
646
697
! isscimlstructure (p)
647
698
nothing , x -> (x,)
@@ -1679,6 +1730,7 @@ function DiffEqBase._concrete_solve_adjoint(
1679
1730
u0, p, originator:: SciMLBase.ADOriginator ,
1680
1731
args... ; save_idxs = nothing , kwargs... )
1681
1732
_prob = remake (prob, u0 = u0, p = p)
1733
+
1682
1734
sol = solve (_prob, alg, args... ; kwargs... )
1683
1735
_save_idxs = save_idxs === nothing ? Colon () : save_idxs
1684
1736
@@ -1688,26 +1740,74 @@ function DiffEqBase._concrete_solve_adjoint(
1688
1740
out = SciMLBase. sensitivity_solution (sol, sol[_save_idxs])
1689
1741
end
1690
1742
1743
+ _, repack_adjoint = if isscimlstructure (p)
1744
+ Zygote. pullback (p) do p
1745
+ t, _, _ = canonicalize (Tunable (), p)
1746
+ t
1747
+ end
1748
+ elseif isfunctor (p)
1749
+ ps, re = Functors. functor (p)
1750
+ ps, x -> (re (x),)
1751
+ else
1752
+ nothing , x -> (x,)
1753
+ end
1754
+
1691
1755
function steadystatebackpass (Δ)
1756
+ Δ = Δ isa AbstractThunk ? unthunk (Δ) : Δ
1692
1757
# Δ = dg/dx or diffcache.dg_val
1693
1758
# del g/del p = 0
1694
1759
function df (_out, u, p, t, i)
1695
1760
if _save_idxs isa Number
1696
1761
_out[_save_idxs] = Δ[_save_idxs]
1697
1762
elseif Δ isa Number
1698
1763
@. _out[_save_idxs] = Δ
1699
- else
1764
+ elseif Δ isa AbstractArray{ <: AbstractArray } || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
1700
1765
@. _out[_save_idxs] = Δ[_save_idxs]
1766
+ elseif isnothing (_out)
1767
+ _out
1768
+ else
1769
+ @. _out[_save_idxs] = Δ. u[_save_idxs]
1770
+ end
1771
+ end
1772
+ dp = adjoint_sensitivities (sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit ())
1773
+
1774
+ dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
1775
+ # if Δ isa AbstractArray, the gradients correspond to `u`
1776
+ # this is something that needs changing in the future, but
1777
+ # this is the applicable till the movement to structuaral
1778
+ # tangents is completed
1779
+ dp, Δtunables = if isscimlstructure (dp)
1780
+ dp, _, _ = canonicalize (Tunable (), dp)
1781
+ dp, nothing
1782
+ elseif isfunctor (dp)
1783
+ dp, _ = Functors. functor (dp)
1784
+ dp, nothing
1785
+ else
1786
+ dp, nothing
1787
+ end
1788
+ else
1789
+ dp, Δtunables = if isscimlstructure (p)
1790
+ Δp = setproperties (dp, to_nt (Δ. prob. p))
1791
+ Δtunables, _, _ = canonicalize (Tunable (), Δp)
1792
+ dp, _, _ = canonicalize (Tunable (), dp)
1793
+ dp, Δtunables
1794
+ elseif isfunctor (p)
1795
+ dp, _ = Functors. functor (dp)
1796
+ Δtunables, _ = Functors. functor (Δ. prob. p)
1797
+ dp, Δtunables
1798
+ else
1799
+ dp, Δ. prob. p
1701
1800
end
1702
1801
end
1703
- dp = adjoint_sensitivities (sol, alg; sensealg = sensealg, dgdu = df)
1802
+
1803
+ dp = Zygote. accum (dp, (isnothing (Δtunables) || isempty (Δtunables)) ? nothing : Δtunables)
1704
1804
1705
1805
if originator isa SciMLBase. TrackerOriginator ||
1706
1806
originator isa SciMLBase. ReverseDiffOriginator
1707
- (NoTangent (), NoTangent (), NoTangent (), dp , NoTangent (),
1807
+ (NoTangent (), NoTangent (), NoTangent (), repack_adjoint (dp)[ 1 ] , NoTangent (),
1708
1808
ntuple (_ -> NoTangent (), length (args))... )
1709
1809
else
1710
- (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dp , NoTangent (),
1810
+ (NoTangent (), NoTangent (), NoTangent (), NoTangent (), repack_adjoint (dp)[ 1 ] , NoTangent (),
1711
1811
ntuple (_ -> NoTangent (), length (args))... )
1712
1812
end
1713
1813
end
0 commit comments