
Commit 6a78ee5

Merge branch 'master' of https://github.com/DiffEqML/torchdyn

2 parents: 7014032 + c6639c7

2 files changed: +20 −47 lines


torchdyn/core/problems.py

Lines changed: 16 additions & 41 deletions
@@ -66,40 +66,25 @@ def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn
         self.vf.register_parameter('dummy_parameter', dummy_parameter)
         self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
 
-        # instantiates an underlying autograd.Function that overrides the backward pass with the intended version
-        # sensitivity algorithm
-        if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
-                                                             solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                             problem_type='standard').apply
-        elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
-                                                                    solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                                    problem_type='standard').apply
-
-
-    def _prep_odeint(self):
+    def _autograd_func(self):
         "create autograd functions for backward pass"
         self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
         if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
+            return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
                                            self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
                                            problem_type='standard').apply
         elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
+            return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
                                                   self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                  problem_type='standard').apply
-
+                                                  problem_type='standard').apply
 
     def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
         "Returns Tuple(`t_eval`, `solution`)"
-        self._prep_odeint()
         if self.sensalg == 'autograd':
             return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
                           save_at=save_at, args=args)
-
         else:
-            return self.autograd_function(self.vf_params, x, t_span, save_at, args)
+            return self._autograd_func()(self.vf_params, x, t_span, save_at, args)
 
     def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
         "For safety redirects to intended method `odeint`"
@@ -128,39 +113,29 @@ def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd'
         self.parallel_solver = solver
         self.fine_steps, self.maxiter = fine_steps, maxiter
 
+    def _autograd_func(self):
+        "create autograd functions for backward pass"
+        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
         if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
-                                                             solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                             'multiple_shooting', fine_steps, maxiter).apply
+            return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
+                                           self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
+                                           'multiple_shooting', self.fine_steps, self.maxiter).apply
         elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
-                                                                    solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                                    'multiple_shooting', fine_steps, maxiter).apply
-
+            return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
+                                                  self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
+                                                  'multiple_shooting', self.fine_steps, self.maxiter).apply
+
     def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
         "Returns Tuple(`t_eval`, `solution`)"
-        self._prep_odeint()
         if self.sensalg == 'autograd':
             return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
         else:
-            return self.autograd_function(self.vf_params, x, t_span, B0)
+            return self._autograd_func()(self.vf_params, x, t_span, B0)
 
     def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
         "For safety redirects to intended method `odeint`"
         return self.odeint(x, t_span, B0)
 
-    def _prep_odeint(self):
-        "create autograd functions for backward pass"
-        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
-        if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
-                                                             self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                             'multiple_shooting', self.fine_steps, self.maxiter).apply
-        elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
-                                                                    self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                                    'multiple_shooting', self.fine_steps, self.maxiter).apply
-
 
 class SDEProblem(nn.Module):
     def __init__(self):
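
The same refactor is applied to `MultipleShootingProblem`. Folding the construction into `_autograd_func` also retires the removed `__init__` block, which still referenced the constructor locals (`solver`, `solver_adjoint`, `fine_steps`, ...) instead of the stored `self.*` attributes. A hedged usage sketch; the multiple-shooting solver name, the shooting hyper-parameters and the import path below are illustrative assumptions:

    import torch
    import torch.nn as nn
    from torchdyn.core import MultipleShootingProblem   # import path assumed

    class VF(nn.Module):
        "Toy vector field mapping (t, x) -> dx/dt."
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))

        def forward(self, t, x):
            return self.net(x)

    # 'mszero', maxiter and fine_steps are placeholders; check the torchdyn
    # docs for the solvers actually registered for multiple shooting.
    prob = MultipleShootingProblem(VF(), solver='mszero', sensitivity='adjoint',
                                   maxiter=4, fine_steps=4)
    x0 = torch.randn(8, 2, requires_grad=True)
    t_span = torch.linspace(0., 1., 20)

    # B0 (initial guess for the shooting variables) is optional; the adjoint
    # autograd.Function is assembled per call via _autograd_func().
    t_eval, sol = prob.odeint(x0, t_span, B0=None)
    sol[-1].sum().backward()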

torchdyn/numerics/sensitivity.py

Lines changed: 4 additions & 6 deletions
@@ -32,8 +32,7 @@ def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator
 def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
                             atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
     "Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`"
-    global _ODEProblemFuncAdjoint
-    class _ODEProblemFuncAdjoint(Function):
+    class _ODEProblemFunc(Function):
         @staticmethod
         def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
             t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
@@ -98,15 +97,14 @@ def adjoint_dynamics(t, A):
             λ_tspan = torch.stack([dLdt[0], dLdt[-1]])
             return (μ, λ, λ_tspan, None, None, None)
 
-    return _ODEProblemFuncAdjoint
+    return _ODEProblemFunc
 
 
 #TODO: introduce `t_span` grad as above
 def _gather_odefunc_interp_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
                                    atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
     "Prepares definition of autograd.Function for interpolated adjoint sensitivity analysis of the above `ODEProblem`"
-    global _ODEProblemFuncInterpAdjoint
-    class _ODEProblemFuncInterpAdjoint(Function):
+    class _ODEProblemFunc(Function):
         @staticmethod
         def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
             t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
@@ -160,4 +158,4 @@ def adjoint_dynamics(t, A):
         λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape)
         return (μ, λ, None, None, None)
 
-    return _ODEProblemFuncInterpAdjoint
+    return _ODEProblemFunc
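
In `sensitivity.py`, the two factories no longer alias the generated class to a module-level name via `global`; each defines a local `_ODEProblemFunc(Function)` that closes over the factory arguments and is returned directly, so successive calls (one per `odeint` invocation after the `problems.py` change) produce independent Functions instead of overwriting a shared global. A stripped-down sketch of that closure-over-a-Function-factory pattern, with toy names and none of torchdyn's adjoint math:

    import torch
    from torch.autograd import Function

    def make_scaled_grad(scale):
        "Factory returning an autograd.Function subclass that closes over `scale`."
        class _ScaledGrad(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.scale = scale          # closed-over factory argument
                return x.clone()

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.scale  # rescale the incoming gradient

        return _ScaledGrad

    # two independent Functions coexist; nothing is written to a module-level name
    f2 = make_scaled_grad(2.0).apply
    f3 = make_scaled_grad(3.0).apply
    x = torch.ones(3, requires_grad=True)
    (f2(x) + f3(x)).sum().backward()
    print(x.grad)   # tensor([5., 5., 5.])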
