@@ -65,10 +65,10 @@ def a1(m):
6565 return _Linearization (init , step )
6666
6767 def ode_taylor_1st (
68- self , ode_order , damp , jvp_samples : int , jvp_samples_seed : int
68+ self , ode_order , damp , jvp_probes : int , jvp_probes_seed : int
6969 ) -> _Linearization :
70- del jvp_samples
71- del jvp_samples_seed
70+ del jvp_probes
71+ del jvp_probes_seed
7272
7373 def init ():
7474 return None
@@ -253,13 +253,13 @@ def __init__(self, unravel):
253253 self .unravel = unravel
254254
255255 def ode_taylor_1st (
256- self , ode_order , damp : float , jvp_samples : int , jvp_samples_seed : int
256+ self , ode_order , damp : float , jvp_probes : int , jvp_probes_seed : int
257257 ):
258258 if ode_order > 1 :
259259 raise ValueError
260260
261261 def init ():
262- return random .prng_key (seed = jvp_samples_seed )
262+ return random .prng_key (seed = jvp_probes_seed )
263263
264264 def step (fun , rv , key ):
265265 mean = rv .mean
@@ -282,7 +282,7 @@ def select_0(s):
282282 # Estimate the trace using Hutchinson's estimator
283283 # J_trace, jacobian_state = jacobian(Jvp, m0, jacobian_state)
284284 key , subkey = random .split (key , num = 2 )
285- sample_shape = (jvp_samples , * m0 .shape )
285+ sample_shape = (jvp_probes , * m0 .shape )
286286 v = random .rademacher (subkey , shape = sample_shape , dtype = m0 .dtype )
287287 J_trace = functools .vmap (lambda s : linalg .vector_dot (s , Jvp (s )))(v )
288288 J_trace = J_trace .mean (axis = 0 )
@@ -368,9 +368,57 @@ def a1(s):
368368 return _Linearization (init , step )
369369
370370 def ode_taylor_1st (
371- self , ode_order , damp : float , jvp_samples : int , jvp_samples_seed : int
371+ self , ode_order , damp : float , jvp_probes : int , jvp_probes_seed : int
372372 ):
373- raise NotImplementedError
373+ if ode_order > 1 :
374+ raise ValueError
375+
376+ def init ():
377+ return random .prng_key (seed = jvp_probes_seed )
378+
379+ def step (fun , rv , key ):
380+ mean = rv .mean
381+
382+ def a1 (s ):
383+ return s [[ode_order ], ...]
384+
385+ linop = functools .vmap (functools .jacrev (a1 ))(mean )
386+
387+ def vf_flat (u ):
388+ return tree_util .ravel_pytree (fun (unravel (u )))[0 ]
389+
390+ def select_0 (s ):
391+ return tree_util .ravel_pytree (self .unravel (s )[0 ])
392+
393+ # Evaluate the linearisation
394+ m0 , unravel = select_0 (rv .mean )
395+ fx , Jvp = functools .linearize (vf_flat , m0 )
396+
397+ key , subkey = random .split (key , num = 2 )
398+ sample_shape = (jvp_probes , * m0 .shape )
399+ v = random .rademacher (subkey , shape = sample_shape , dtype = m0 .dtype )
400+ J_diag = functools .vmap (lambda s : s * Jvp (s ))(v )
401+ J_diag = J_diag .mean (axis = 0 )
402+ E1 = functools .jacrev (lambda s : s [0 ])(rv .mean [0 ])
403+ linop = linop - J_diag [:, None , None ] * E1 [None , None , :]
404+
405+ fx = rv .mean [:, 1 ] - fx
406+ fx = fx [..., None ]
407+ diff = functools .vmap (lambda a , b : a @ b )(linop , rv .mean )
408+ fx = fx - diff
409+
410+ d , * _ = linop .shape
411+ cov_lower = damp * np .ones ((d , 1 , 1 ))
412+ bias = _normal .Normal (fx , cov_lower )
413+
414+ to_latent = np .ones ((linop .shape [2 ],))
415+ to_observed = np .ones ((linop .shape [1 ],))
416+ cond = _conditional .LatentCond (
417+ linop , bias , to_latent = to_latent , to_observed = to_observed
418+ )
419+ return cond , key
420+
421+ return _Linearization (init , step )
374422
375423 def ode_statistical_0th (self , cubature_fun , damp : float ):
376424 raise NotImplementedError