@@ -66,40 +66,25 @@ def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn
         self.vf.register_parameter('dummy_parameter', dummy_parameter)
         self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
 
-        # instantiates an underlying autograd.Function that overrides the backward pass with the intended version
-        # sensitivity algorithm
-        if self.sensalg == 'adjoint':  # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
-                                                             solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                             problem_type='standard').apply
-        elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
-                                                                    solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                                    problem_type='standard').apply
-
-
-    def _prep_odeint(self):
+    def _autograd_func(self):
         "create autograd functions for backward pass"
         self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
         if self.sensalg == 'adjoint':  # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
+            return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
                                            self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
                                            problem_type='standard').apply
         elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
+            return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
                                                   self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                  problem_type='standard').apply
-
+                                                  problem_type='standard').apply
 
     def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
         "Returns Tuple(`t_eval`, `solution`)"
-        self._prep_odeint()
         if self.sensalg == 'autograd':
             return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
                           save_at=save_at, args=args)
-
         else:
-            return self.autograd_function(self.vf_params, x, t_span, save_at, args)
+            return self._autograd_func()(self.vf_params, x, t_span, save_at, args)
 
     def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
         "For safety redirects to intended method `odeint`"
@@ -128,39 +113,29 @@ def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd'
         self.parallel_solver = solver
         self.fine_steps, self.maxiter = fine_steps, maxiter
 
+    def _autograd_func(self):
+        "create autograd functions for backward pass"
+        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
         if self.sensalg == 'adjoint':  # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
-                                                             solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                             'multiple_shooting', fine_steps, maxiter).apply
+            return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
+                                           self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
+                                           'multiple_shooting', self.fine_steps, self.maxiter).apply
         elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
-                                                                    solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
-                                                                    'multiple_shooting', fine_steps, maxiter).apply
-
+            return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
+                                                  self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
+                                                  'multiple_shooting', self.fine_steps, self.maxiter).apply
+
     def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
         "Returns Tuple(`t_eval`, `solution`)"
-        self._prep_odeint()
         if self.sensalg == 'autograd':
             return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
         else:
-            return self.autograd_function(self.vf_params, x, t_span, B0)
+            return self._autograd_func()(self.vf_params, x, t_span, B0)
 
     def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
         "For safety redirects to intended method `odeint`"
         return self.odeint(x, t_span, B0)
 
-    def _prep_odeint(self):
-        "create autograd functions for backward pass"
-        self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
-        if self.sensalg == 'adjoint':  # alias .apply as direct call to preserve consistency of call signature
-            self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
-                                                             self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                             'multiple_shooting', self.fine_steps, self.maxiter).apply
-        elif self.sensalg == 'interpolated_adjoint':
-            self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
-                                                                    self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
-                                                                    'multiple_shooting', self.fine_steps, self.maxiter).apply
-
 
 class SDEProblem(nn.Module):
     def __init__(self):
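Same refactor applied to the multiple-shooting problem: the gather logic duplicated between `__init__` and the trailing `_prep_odeint` collapses into one `_autograd_func` that reads solver settings from `self` instead of closing over `__init__` locals. A sketch of the resulting call path; the class name `MultipleShootingProblem`, the import path, and the `'mszero'` solver key are assumptions not shown in this diff:

```python
# Hypothetical usage sketch; MultipleShootingProblem, its import path
# and the 'mszero' solver key are assumptions, not part of this diff.
import torch
import torch.nn as nn
from torchdyn.core import MultipleShootingProblem  # assumed import path

vf = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
prob = MultipleShootingProblem(vf, solver='mszero', sensitivity='adjoint')

x = torch.randn(8, 2)
t_span = torch.linspace(0., 1., 10)

# sensalg != 'autograd' now takes the rebuilt-function path:
# self._autograd_func()(self.vf_params, x, t_span, B0),
# while sensitivity='autograd' still calls odeint_mshooting directly.
t_eval, sol = prob.odeint(x, t_span, B0=None)
```

As in the first hunk, the autograd function is re-gathered per solve, so `self.vf_params` is always consistent with the current `self.vf.parameters()`.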