-
Notifications
You must be signed in to change notification settings - Fork 154
Plateau Neuron - Fixed Point Model #781
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0263677
d70b6e7
3f649d1
6647107
5db90d7
bab1fd1
5d8b10c
1961780
4be4918
70433d4
4b26480
fb217ba
d1c66f0
c469bd8
e8a8215
6ccbf74
cea6296
6e3744d
0c6bf24
d2aed48
6dd7e3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import numpy as np | ||
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol | ||
from lava.magma.core.model.py.ports import PyInPort, PyOutPort | ||
from lava.magma.core.model.py.type import LavaPyType | ||
from lava.magma.core.resources import CPU | ||
from lava.magma.core.decorator import implements, requires, tag | ||
from lava.magma.core.model.py.model import PyLoihiProcessModel | ||
from lava.proc.plateau.process import Plateau | ||
|
||
|
||
@implements(proc=Plateau, protocol=LoihiProtocol) | ||
@requires(CPU) | ||
@tag("fixed_pt") | ||
class PyPlateauModelFixed(PyLoihiProcessModel): | ||
""" Implementation of Plateau neuron process in fixed point precision. | ||
|
||
Precisions of state variables | ||
|
||
- dv_dend : unsigned 12-bit integer (0 to 4095) | ||
- dv_soma : unsigned 12-bit integer (0 to 4095) | ||
- vth_dend : unsigned 17-bit integer (0 to 131071) | ||
- vth_soma : unsigned 17-bit integer (0 to 131071) | ||
kds300 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
- up_dur : unsigned 8-bit integer (0 to 255) | ||
""" | ||
|
||
a_dend_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) | ||
a_soma_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) | ||
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24) | ||
v_dend: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) | ||
v_soma: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) | ||
dv_dend: int = LavaPyType(int, np.uint16, precision=12) | ||
dv_soma: int = LavaPyType(int, np.uint16, precision=12) | ||
vth_dend: int = LavaPyType(int, np.int32, precision=17) | ||
vth_soma: int = LavaPyType(int, np.int32, precision=17) | ||
up_dur: int = LavaPyType(int, np.uint16, precision=8) | ||
up_state: int = LavaPyType(np.ndarray, np.uint16, precision=8) | ||
|
||
def __init__(self, proc_params): | ||
super(PyPlateauModelFixed, self).__init__(proc_params) | ||
self._validate_inputs(proc_params) | ||
self.uv_bitwidth = 24 | ||
self.max_uv_val = 2 ** (self.uv_bitwidth - 1) | ||
self.decay_shift = 12 | ||
self.decay_unity = 2 ** self.decay_shift - 1 | ||
self.vth_shift = 6 | ||
self.act_shift = 6 | ||
self.isthrscaled = False | ||
self.effective_vth_dend = None | ||
self.effective_vth_soma = None | ||
self.s_out_buff = None | ||
|
||
def _validate_var(self, var, var_type, min_val, max_val, var_name): | ||
if not isinstance(var, var_type): | ||
raise ValueError(f"'{var_name}' must have type {var_type}") | ||
if var < min_val or var > max_val: | ||
raise ValueError( | ||
f"'{var_name}' must be in range [{min_val}, {max_val}]" | ||
) | ||
|
||
def _validate_inputs(self, proc_params): | ||
self._validate_var(proc_params['dv_dend'], int, 0, 4095, 'dv_dend') | ||
self._validate_var(proc_params['dv_soma'], int, 0, 4095, 'dv_soma') | ||
self._validate_var(proc_params['vth_dend'], int, 0, 131071, 'vth_dend') | ||
self._validate_var(proc_params['vth_soma'], int, 0, 131071, 'vth_soma') | ||
self._validate_var(proc_params['up_dur'], int, 0, 255, 'up_dur') | ||
|
||
def scale_threshold(self): | ||
self.effective_vth_dend = np.left_shift(self.vth_dend, self.vth_shift) | ||
self.effective_vth_soma = np.left_shift(self.vth_soma, self.vth_shift) | ||
self.isthrscaled = True | ||
|
||
def subthr_dynamics( | ||
self, | ||
activation_dend_in: np.ndarray, | ||
activation_soma_in: np.ndarray | ||
): | ||
"""Run the sub-threshold dynamics for both the dendrite and soma of the | ||
neuron. Both use 'leaky integration'. | ||
""" | ||
for v, dv, a_in in [ | ||
(self.v_dend, self.dv_dend, activation_dend_in), | ||
(self.v_soma, self.dv_soma, activation_soma_in), | ||
]: | ||
decayed_volt = np.int64(v) * (self.decay_unity - dv) | ||
decayed_volt = np.sign(decayed_volt) * np.right_shift( | ||
np.abs(decayed_volt), 12 | ||
) | ||
decayed_volt = np.int32(decayed_volt) | ||
updated_volt = decayed_volt + np.left_shift(a_in, self.act_shift) | ||
|
||
neg_voltage_limit = -np.int32(self.max_uv_val) + 1 | ||
pos_voltage_limit = np.int32(self.max_uv_val) - 1 | ||
|
||
v[:] = np.clip( | ||
updated_volt, neg_voltage_limit, pos_voltage_limit | ||
) | ||
|
||
def update_up_state(self): | ||
"""Decrements the up state (if necessary) and checks v_dend to see if | ||
up state needs to be (re)set. If up state is (re)set, then v_dend is | ||
reset to 0. | ||
""" | ||
self.up_state[self.up_state > 0] -= 1 | ||
self.up_state[self.v_dend > self.effective_vth_dend] = self.up_dur | ||
self.v_dend[self.v_dend > self.effective_vth_dend] = 0 | ||
|
||
def soma_spike_and_reset(self): | ||
"""Check the spiking conditions for the plateau soma. Checks if: | ||
v_soma > v_th_soma | ||
up_state > 0 | ||
|
||
For any neurons n that satisfy both conditions, sets: | ||
s_out_buff[n] = True | ||
v_soma = 0 | ||
""" | ||
s_out_buff = np.logical_and( | ||
self.v_soma > self.effective_vth_soma, | ||
self.up_state > 0 | ||
) | ||
self.v_soma[s_out_buff] = 0 | ||
|
||
return s_out_buff | ||
|
||
def run_spk(self): | ||
"""The run function that performs the actual computation during | ||
execution orchestrated by a PyLoihiProcessModel using the | ||
LoihiProtocol. | ||
""" | ||
|
||
# Receive synaptic input | ||
a_dend_in_data = self.a_dend_in.recv() | ||
a_soma_in_data = self.a_soma_in.recv() | ||
|
||
# Check threshold scaling | ||
if not self.isthrscaled: | ||
self.scale_threshold() | ||
|
||
self.subthr_dynamics(a_dend_in_data, a_soma_in_data) | ||
|
||
self.update_up_state() | ||
|
||
self.s_out_buff = self.soma_spike_and_reset() | ||
|
||
self.s_out.send(self.s_out_buff) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import typing as ty | ||
from lava.magma.core.process.process import AbstractProcess | ||
from lava.magma.core.process.variable import Var | ||
from lava.magma.core.process.ports.ports import InPort, OutPort | ||
|
||
|
||
class Plateau(AbstractProcess): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a minor point. Since this is a LIF neuron, did you consider to let it inherit from the AbstractLIF process, and adding a 'LIF' in the class name? Not sure if it makes sense in this specific example, though. Up to you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I originally didn't do this since I thought of it as two combined LIF neurons instead of a modified LIF neuron. I'll look over the AbstractLIF process and see if it makes sense to inherit from that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think the class should inherit from AbstrictLIF, since it does not have current or bias vars. |
||
"""Plateau Neuron Process. | ||
|
||
Couples two modified LIF dynamics. The neuron posesses two voltages, | ||
v_dend and v_soma. Both follow sub-threshold LIF dynamics. When v_dend | ||
crosses v_th_dend, it resets and sets the up_state to the value up_dur. | ||
The supra-threshold behavior of v_soma depends on up_state: | ||
if up_state == 0: | ||
v_soma follows sub-threshold dynamics | ||
if up_state > 0: | ||
v_soma resets and the neuron sends out a spike | ||
|
||
Parameters | ||
---------- | ||
shape : tuple(int) | ||
Number and topology of Plateau neurons. | ||
dv_dend : int | ||
Inverse of the decay time-constant for the dendrite voltage. | ||
dv_soma : int | ||
Inverse of the decay time-constant for the soma voltage. | ||
vth_dend : int | ||
Dendrite threshold voltage, exceeding which, the neuron will enter the | ||
UP state. | ||
vth_soma : int | ||
Soma threshold voltage, exceeding which, the neuron will spike if it is | ||
also in the UP state. | ||
up_dur : int | ||
The duration, in timesteps, of the UP state. | ||
""" | ||
def __init__( | ||
self, | ||
shape: ty.Tuple[int, ...], | ||
dv_dend: int, | ||
dv_soma: int, | ||
vth_dend: int, | ||
vth_soma: int, | ||
up_dur: int, | ||
name: ty.Optional[str] = None, | ||
): | ||
super().__init__( | ||
shape=shape, | ||
dv_dend=dv_dend, | ||
dv_soma=dv_soma, | ||
name=name, | ||
up_dur=up_dur, | ||
vth_dend=vth_dend, | ||
vth_soma=vth_soma | ||
) | ||
self.a_dend_in = InPort(shape=shape) | ||
self.a_soma_in = InPort(shape=shape) | ||
self.s_out = OutPort(shape=shape) | ||
self.v_dend = Var(shape=shape, init=0) | ||
self.v_soma = Var(shape=shape, init=0) | ||
self.dv_dend = Var(shape=(1,), init=dv_dend) | ||
self.dv_soma = Var(shape=(1,), init=dv_soma) | ||
self.vth_dend = Var(shape=(1,), init=vth_dend) | ||
self.vth_soma = Var(shape=(1,), init=vth_soma) | ||
self.up_dur = Var(shape=(1,), init=up_dur) | ||
self.up_state = Var(shape=shape, init=0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
kds300 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import unittest | ||
import numpy as np | ||
from lava.proc.plateau.process import Plateau | ||
from lava.proc.dense.process import Dense | ||
from lava.proc.io.source import RingBuffer as Source | ||
from lava.magma.core.run_configs import Loihi2SimCfg | ||
from lava.magma.core.run_conditions import RunSteps | ||
from lava.tests.lava.proc.lif.test_models import VecRecvProcess | ||
|
||
|
||
def create_spike_source(spike_list, n_indices, n_timesteps): | ||
"""Use list of spikes [(idx, timestep), ...] to create a RingBuffer source | ||
with data shape (n_indices, n_timesteps) and spikes at all specified points | ||
in the spike_list. | ||
""" | ||
data = np.zeros(shape=(n_indices, n_timesteps)) | ||
for idx, timestep in spike_list: | ||
data[idx, timestep - 1] = 1 | ||
return Source(data=data) | ||
|
||
|
||
class TestPlateauProcessModelsFixed(unittest.TestCase): | ||
"""Tests for the fixed point Plateau process models.""" | ||
def test_fixed_max_decay(self): | ||
""" | ||
Tests fixed point Plateau with max voltage decays. | ||
""" | ||
shape = (3,) | ||
num_steps = 20 | ||
spikes_in_dend = [(0, 5), (1, 5), (2, 5)] | ||
spikes_in_soma = [(0, 3), (1, 10), (2, 17)] | ||
sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps) | ||
sg_soma = create_spike_source(spikes_in_soma, shape[0], num_steps) | ||
dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape))) | ||
dense_soma = Dense(weights=2 * np.diag(np.ones(shape=shape))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=4095, | ||
dv_soma=4095, | ||
vth_soma=1, | ||
vth_dend=1, | ||
up_dur=10 | ||
) | ||
vr = VecRecvProcess(shape=(num_steps, shape[0])) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
sg_soma.s_out.connect(dense_soma.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
dense_soma.a_out.connect(plat.a_soma_in) | ||
plat.s_out.connect(vr.s_in) | ||
# run model | ||
plat.run(RunSteps(num_steps), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_spk_data = vr.spk_data.get() | ||
plat.stop() | ||
# Gold standard for the test | ||
expected_spk_data = np.zeros((num_steps, shape[0])) | ||
# Neuron 2 should spike when receiving soma input | ||
expected_spk_data[10, 1] = 1 | ||
self.assertTrue(np.all(expected_spk_data == test_spk_data)) | ||
|
||
def test_up_dur(self): | ||
""" | ||
Tests that the UP state lasts for the time specified by the model. | ||
Checks that up_state decreases by one each time step after activation. | ||
""" | ||
shape = (1,) | ||
num_steps = 10 | ||
spikes_in_dend = [(0, 3)] | ||
sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps) | ||
dense_dend = Dense(weights=2 * (np.diag(np.ones(shape=shape)))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=4095, | ||
dv_soma=4095, | ||
vth_soma=1, | ||
vth_dend=1, | ||
up_dur=5 | ||
) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
# run model | ||
test_up_state = [] | ||
for _ in range(num_steps): | ||
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_up_state.append(plat.up_state.get().astype(int)[0]) | ||
plat.stop() | ||
# Gold standard for the test | ||
# UP state active time steps 4 - 9 (5 timesteps) | ||
# this is delayed by one b.c. of the Dense process | ||
expected_up_state = [0, 0, 0, 5, 4, 3, 2, 1, 0, 0] | ||
self.assertListEqual(expected_up_state, test_up_state) | ||
|
||
def test_fixed_dvs(self): | ||
""" | ||
Tests fixed point Plateau voltage decays. | ||
""" | ||
shape = (1,) | ||
num_steps = 10 | ||
spikes_in = [(0, 1)] | ||
sg_dend = create_spike_source(spikes_in, shape[0], num_steps) | ||
sg_soma = create_spike_source(spikes_in, shape[0], num_steps) | ||
dense_dend = Dense(weights=100 * np.diag(np.ones(shape=shape))) | ||
dense_soma = Dense(weights=100 * np.diag(np.ones(shape=shape))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=2048, | ||
dv_soma=1024, | ||
vth_soma=100, | ||
vth_dend=100, | ||
up_dur=10 | ||
) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
sg_soma.s_out.connect(dense_soma.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
dense_soma.a_out.connect(plat.a_soma_in) | ||
# run model | ||
test_v_dend = [] | ||
test_v_soma = [] | ||
for _ in range(num_steps): | ||
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_v_dend.append(plat.v_dend.get().astype(int)[0]) | ||
test_v_soma.append(plat.v_soma.get().astype(int)[0]) | ||
plat.stop() | ||
# Gold standard for the test | ||
# 100<<6 = 6400 -- initial value at time step 2 | ||
expected_v_dend = [ | ||
0, 6400, 3198, 1598, 798, 398, 198, 98, 48, 23 | ||
] | ||
expected_v_soma = [ | ||
0, 6400, 4798, 3597, 2696, 2021, 1515, 1135, 850, 637 | ||
] | ||
self.assertListEqual(expected_v_dend, test_v_dend) | ||
self.assertListEqual(expected_v_soma, test_v_soma) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you check if lintin passes? My guess is that there are a few points that must change, including missing lines at the end of files. Not functionally relevant, but important to keep a clean code base :-) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've run the linting, (flakeheaven and bandit) and they pass on my local copy of the code. Also, I have the empty line at the end of the files locally, but it doesn't seem to show up on the github versions. Does github just not show the empty line at the end? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import unittest | ||
from lava.proc.plateau.process import Plateau | ||
|
||
|
||
class TestPlateauProcess(unittest.TestCase): | ||
"""Tests for Plateau class""" | ||
def test_init(self): | ||
"""Tests instantiation of Plateau""" | ||
plat = Plateau( | ||
shape=(100,), | ||
dv_dend=100, | ||
dv_soma=1, | ||
vth_dend=10, | ||
vth_soma=1, | ||
up_dur=10, | ||
name="Plat" | ||
) | ||
|
||
self.assertEqual(plat.name, "Plat") | ||
self.assertEqual(plat.dv_dend.init, 100) | ||
self.assertEqual(plat.dv_soma.init, 1) | ||
self.assertEqual(plat.vth_dend.init, 10) | ||
self.assertEqual(plat.vth_soma.init, 1) | ||
self.assertEqual(plat.up_dur.init, 10) |
Uh oh!
There was an error while loading. Please reload this page.