Skip to content

Commit 78e58da

Browse files
author
Alexander Ororbia
committed
revised leaky-noise-cell, wrote its unit test, test-passed
1 parent c1a21ce commit 78e58da

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

ngclearn/components/neurons/graded/leakyNoiseCell.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
from ngclearn.utils.model_utils import create_function
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4
88

9-
def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec): ## raw dynamics ODE
10-
dz_dt = -z + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * (sigma_rec) ^ 2) * eps
9+
def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale): ## raw dynamics ODE
10+
dz_dt = -(z * leak_scale) + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * jnp.square(sigma_rec)) * eps
1111
return dz_dt * (1. / tau_x)
1212

13+
def _dfz(t, z, params): ## raw dynamics ODE wrapper
14+
j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale = params
15+
return _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale)
16+
1317
class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
1418
"""
1519
A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state.
@@ -55,13 +59,14 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
5559
# Define Functions
5660
def __init__(
5761
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1.,
58-
shape=None, **kwargs
62+
leak_scale=1., shape=None, **kwargs
5963
):
6064
super().__init__(name, **kwargs)
6165

6266

6367
self.tau_x = tau_x
6468
self.sigma_rec = sigma_rec ## a "resistance" scaling factor
69+
self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
6570

6671
## integration properties
6772
self.integrationType = integration_type
@@ -87,22 +92,23 @@ def __init__(
8792
self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output
8893

8994
@compilable
90-
def advance_state(self, t, dt): #dt, fx, tau_x, sigma_rec, intgFlag, key, j_input, j_recurrent, x):
91-
key, skey = random.split(self.key.get(), 2)
95+
def advance_state(self, t, dt):
9296
### run a step of integration over neuronal dynamics
93-
eps = random.normal(skey[0], shape=self.x.get().shape) ## sample of unit distributional noise
94-
#x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
97+
key, skey = random.split(self.key.get(), 2)
98+
eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise
9599

100+
#x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
96101
_step_fns = {
97102
0: step_euler,
98103
1: step_rk2,
99104
2: step_rk4,
100105
}
101-
_step_fn = _step_fns.get(self.intgFlag, step_euler)
102-
params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec)
103-
_, x = _step_fn(0., self.x.get(), _dfz_fn, dt, params) ## update state activation dynamics
106+
_step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
107+
params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale)
108+
_, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
104109
r = self.fx(x) ## calculate rectified / post-activation function value(s)
105110

111+
## set compartments to next state values in accordance with dynamics
106112
self.key.set(key)
107113
self.x.set(x)
108114
self.r.set(r)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
import numpy as np
5+
np.random.seed(42)
6+
from ngclearn.components.neurons.graded.leakyNoiseCell import LeakyNoiseCell
7+
from numpy.testing import assert_array_equal
8+
9+
from ngclearn import Context, MethodProcess
10+
11+
12+
def test_LeakyNoiseCell1():
13+
name = "leaky_noise_ctx"
14+
dkey = random.PRNGKey(42)
15+
dkey, *subkeys = random.split(dkey, 100)
16+
dt = 1. # ms
17+
with Context(name) as ctx:
18+
a = LeakyNoiseCell(
19+
name="a", n_units=1, tau_x=50., act_fx="identity", integration_type="euler", batch_size=1, sigma_rec=0.,
20+
leak_scale=0.
21+
)
22+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
23+
reset_process = (MethodProcess("reset_proc") >> a.reset)
24+
25+
def clamp(x):
26+
a.j_input.set(x)
27+
28+
## input spike train
29+
x_seq = jnp.ones((1, 10))
30+
## desired output/epsp pulses
31+
y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
32+
33+
outs = []
34+
reset_process.run()
35+
for ts in range(x_seq.shape[1]):
36+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
37+
clamp(x_t)
38+
advance_process.run(t=ts * 1., dt=dt)
39+
outs.append(a.x.get())
40+
outs = jnp.concatenate(outs, axis=1)
41+
# print(outs)
42+
# print(y_seq)
43+
## output should approximately equal input
44+
# assert_array_equal(outs, y_seq, tol=1e-3)
45+
np.testing.assert_allclose(outs, y_seq, atol=1e-3)
46+
47+
#test_LeakyNoiseCell1()

0 commit comments

Comments
 (0)