66from ngclearn .utils .model_utils import create_function
77from ngclearn .utils .diffeq .ode_utils import get_integrator_code , step_euler , step_rk2 , step_rk4
88
9- def _dfz_fn (z , j_input , j_recurrent , eps , tau_x , sigma_rec ): ## raw dynamics ODE
10- dz_dt = - z + (j_recurrent + j_input ) + jnp .sqrt (2. * tau_x * (sigma_rec ) ^ 2 ) * eps
9+ def _dfz_fn (z , j_input , j_recurrent , eps , tau_x , sigma_rec , leak_scale ): ## raw dynamics ODE
10+ dz_dt = - ( z * leak_scale ) + (j_recurrent + j_input ) + jnp .sqrt (2. * tau_x * jnp . square (sigma_rec )) * eps
1111 return dz_dt * (1. / tau_x )
1212
13+ def _dfz (t , z , params ): ## raw dynamics ODE wrapper
14+ j_input , j_recurrent , eps , tau_x , sigma_rec , leak_scale = params
15+ return _dfz_fn (z , j_input , j_recurrent , eps , tau_x , sigma_rec , leak_scale )
16+
1317class LeakyNoiseCell (JaxComponent ): ## Real-valued, leaky noise cell
1418 """
1519 A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state.
@@ -55,13 +59,14 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
5559 # Define Functions
5660 def __init__ (
5761 self , name , n_units , tau_x , act_fx = "relu" , integration_type = "euler" , batch_size = 1 , sigma_rec = 1. ,
58- shape = None , ** kwargs
62+ leak_scale = 1. , shape = None , ** kwargs
5963 ):
6064 super ().__init__ (name , ** kwargs )
6165
6266
6367 self .tau_x = tau_x
6468 self .sigma_rec = sigma_rec ## a "resistance" scaling factor
69+ self .leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
6570
6671 ## integration properties
6772 self .integrationType = integration_type
@@ -87,22 +92,23 @@ def __init__(
8792 self .r = Compartment (restVals , display_name = "Rectified Rate Activity" ) # rectified output
8893
8994 @compilable
90- def advance_state (self , t , dt ): #dt, fx, tau_x, sigma_rec, intgFlag, key, j_input, j_recurrent, x):
91- key , skey = random .split (self .key .get (), 2 )
95+ def advance_state (self , t , dt ):
9296 ### run a step of integration over neuronal dynamics
93- eps = random .normal ( skey [ 0 ], shape = self .x .get (). shape ) ## sample of unit distributional noise
94- #x = _run_cell(dt, self.j_input.get(), self.j_recurrent .get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
97+ key , skey = random .split ( self .key .get (), 2 )
98+ eps = random . normal ( skey , shape = self .x .get (). shape ) ## sample of unit distributional noise
9599
100+ #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
96101 _step_fns = {
97102 0 : step_euler ,
98103 1 : step_rk2 ,
99104 2 : step_rk4 ,
100105 }
101- _step_fn = _step_fns .get (self .intgFlag , step_euler )
102- params = (self .j_input .get (), self .j_recurrent .get (), eps , self .tau_x , self .sigma_rec )
103- _ , x = _step_fn (0. , self .x .get (), _dfz_fn , dt , params ) ## update state activation dynamics
106+ _step_fn = _step_fns [ self . intgFlag ] #_step_fns .get(self.intgFlag, step_euler)
107+ params = (self .j_input .get (), self .j_recurrent .get (), eps , self .tau_x , self .sigma_rec , self . leak_scale )
108+ _ , x = _step_fn (0. , self .x .get (), _dfz , dt , params ) ## update state activation dynamics
104109 r = self .fx (x ) ## calculate rectified / post-activation function value(s)
105110
111+ ## set compartments to next state values in accordance with dynamics
106112 self .key .set (key )
107113 self .x .set (x )
108114 self .r .set (r )
0 commit comments