Skip to content

Commit bbd0337

Browse files
Deep-R
1 parent fef9f88 commit bbd0337

5 files changed

Lines changed: 89 additions & 67 deletions

File tree

examples/eprop/latency_mnist_deep_r.py renamed to examples/event_prop/latency_mnist_deep_r.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,65 @@
1-
import matplotlib.pyplot as plt
2-
import numpy as np
1+
import logging
32
import matplotlib.pyplot as plt
43
import mnist
4+
import numpy as np
55

66
from ml_genn import InputLayer, Layer, SequentialNetwork
77
from ml_genn.callbacks import Checkpoint
8-
from ml_genn.compilers import EPropCompiler, InferenceCompiler
9-
from ml_genn.connectivity import Dense,FixedProbability
8+
from ml_genn.compilers import EventPropCompiler, InferenceCompiler
9+
from ml_genn.connectivity import Dense, FixedProbability
1010
from ml_genn.initializers import Normal
1111
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
12+
from ml_genn.optimisers import Adam
1213
from ml_genn.serialisers import Numpy
14+
from ml_genn.synapses import Exponential
1315

1416
from time import perf_counter
15-
from ml_genn.utils.data import (calc_latest_spike_time, calc_max_spikes,
16-
log_latency_encode_data)
17-
18-
from ml_genn.compilers.eprop_compiler import default_params
17+
from ml_genn.utils.data import linear_latency_encode_data
1918

2019
NUM_INPUT = 784
2120
NUM_HIDDEN = 128
2221
NUM_OUTPUT = 10
23-
BATCH_SIZE = 128
22+
BATCH_SIZE = 32
2423
NUM_EPOCHS = 10
24+
EXAMPLE_TIME = 20.0
25+
DT = 1.0
2526
SPARSITY = 0.1
2627
TRAIN = True
2728
KERNEL_PROFILING = False
2829
PLOT_REWIRING = True
2930

3031
mnist.datasets_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
3132
labels = mnist.train_labels() if TRAIN else mnist.test_labels()
32-
spikes = log_latency_encode_data(
33+
spikes = linear_latency_encode_data(
3334
mnist.train_images() if TRAIN else mnist.test_images(),
34-
20.0, 51)
35+
EXAMPLE_TIME - (2.0 * DT), 2.0 * DT)
3536

3637
serialiser = Numpy("latency_mnist_checkpoints")
37-
network = SequentialNetwork(default_params)
38+
network = SequentialNetwork()
3839
with network:
3940
# Populations
40-
input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * calc_max_spikes(spikes)),
41+
input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * NUM_INPUT),
4142
NUM_INPUT)
42-
hidden = Layer(FixedProbability(SPARSITY, Normal(sd=1.0 / np.sqrt(NUM_INPUT))),
43-
LeakyIntegrateFire(v_thresh=0.61, tau_mem=20.0,
44-
tau_refrac=5.0),
45-
NUM_HIDDEN)
46-
output = Layer(Dense(Normal(sd=1.0 / np.sqrt(NUM_HIDDEN))),
47-
LeakyIntegrate(tau_mem=20.0, readout="sum_var"),
48-
NUM_OUTPUT)
43+
hidden = Layer(FixedProbability(SPARSITY, Normal(mean=0.0, sd=0.8)),
44+
LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0),
45+
NUM_HIDDEN, Exponential(5.0))
46+
output = Layer(Dense(Normal(mean=0.2, sd=0.37)),
47+
LeakyIntegrate(tau_mem=20.0, readout="avg_var"),
48+
NUM_OUTPUT, Exponential(5.0))
4949

50-
max_example_timesteps = int(np.ceil(calc_latest_spike_time(spikes)))
50+
max_example_timesteps = int(np.ceil(EXAMPLE_TIME / DT))
5151
if TRAIN:
52-
compiler = EPropCompiler(example_timesteps=max_example_timesteps,
53-
losses="sparse_categorical_crossentropy",
54-
optimiser="adam", batch_size=BATCH_SIZE,
55-
deep_r_conns=[hidden], deep_r_l1_strength=1E-8,
56-
deep_r_record_rewirings=({} if not PLOT_REWIRING
57-
else {hidden: "in_hid_rewiring"}),
58-
kernel_profiling=KERNEL_PROFILING)
59-
compiled_net = compiler.compile(network)
52+
compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
53+
losses="sparse_categorical_crossentropy",
54+
batch_size=BATCH_SIZE, dt=DT, deep_r_l1_strength=1E-8,
55+
kernel_profiling=KERNEL_PROFILING)
56+
compiled_net = compiler.compile(network, optimisers={"all_connections": {"weight": Adam(1e-2)}},
57+
deep_r_conns=[hidden],
58+
deep_r_record_rewiring=({} if not PLOT_REWIRING
59+
else {hidden: "in_hid_rewiring"}))
6060

6161
with compiled_net:
62+
visualise_examples = [0, 32, 64, 96]
6263
# Evaluate model on numpy dataset
6364
start_time = perf_counter()
6465
callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
@@ -67,22 +68,20 @@
6768
num_epochs=NUM_EPOCHS, shuffle=True,
6869
callbacks=callbacks)
6970
compiled_net.save_connectivity((NUM_EPOCHS - 1,), serialiser)
70-
7171
end_time = perf_counter()
7272
print(f"Accuracy = {100 * metrics[output].result}%")
7373
print(f"Time = {end_time - start_time}s")
7474

7575
if KERNEL_PROFILING:
7676
print(f"Neuron update time = {compiled_net.genn_model.neuron_update_time}")
7777
print(f"Presynaptic update time = {compiled_net.genn_model.presynaptic_update_time}")
78-
print(f"Synapse dynamics time = {compiled_net.genn_model.synapse_dynamics_time}")
7978
print(f"Gradient batch reduce time = {compiled_net.genn_model.get_custom_update_time('GradientBatchReduce')}")
8079
print(f"Gradient learn time = {compiled_net.genn_model.get_custom_update_time('GradientLearn')}")
8180
print(f"Reset time = {compiled_net.genn_model.get_custom_update_time('Reset')}")
82-
print(f"Softmax1 time = {compiled_net.genn_model.get_custom_update_time('Softmax1')}")
83-
print(f"Softmax2 time = {compiled_net.genn_model.get_custom_update_time('Softmax2')}")
84-
print(f"Softmax3 time = {compiled_net.genn_model.get_custom_update_time('Softmax3')}")
85-
81+
print(f"Softmax1 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax1')}")
82+
print(f"Softmax2 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax2')}")
83+
print(f"Softmax3 time = {compiled_net.genn_model.get_custom_update_time('BatchSoftmax3')}")
84+
8685
# Plot rewiring curves
8786
if PLOT_REWIRING:
8887
fig, axis = plt.subplots()
@@ -96,7 +95,8 @@
9695
network.load((NUM_EPOCHS - 1,), serialiser)
9796

9897
compiler = InferenceCompiler(evaluate_timesteps=max_example_timesteps,
99-
batch_size=BATCH_SIZE)
98+
reset_in_syn_between_batches=True,
99+
batch_size=BATCH_SIZE, dt=DT)
100100
compiled_net = compiler.compile(network)
101101

102102
with compiled_net:

ml_genn/ml_genn/callbacks/custom_update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,4 @@ class CustomUpdateOnTimestepEnd(CustomUpdate):
7979
def on_timestep_end(self, state, timestep):
8080
if self._custom_update(state, timestep):
8181
logger.debug(f"Running custom update {self.name} "
82-
f"at start of timestep {timestep}")
82+
f"at start of timestep {timestep}")

ml_genn/ml_genn/compilers/deep_r.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,18 +157,16 @@ def __init__(self, deep_r_2_ccu, key=None):
157157
self.num_failed_rewirings = deep_r_2_ccu.extra_global_params["NumFailedRewirings"]
158158
self.key = key
159159

160-
def set_params(self, data, **kwargs):
161-
# Create empty list to hold recorded data
162-
data[self.key] = []
163-
self._data = data[self.key]
160+
def create_state(self, compiled_network, **kwargs):
161+
return []
164162

165-
def on_batch_end(self, batch, metrics):
163+
def on_batch_end(self, state, batch, metrics):
166164
# Read number of rewirings out of view and add to list
167-
self._data.append((self.num_rewirings.view[0],
168-
self.num_failed_rewirings.view[0]))
165+
state.append((self.num_rewirings.view[0],
166+
self.num_failed_rewirings.view[0]))
169167

170-
def get_data(self):
171-
return self.key, self._data
168+
def get_data(self, state):
169+
return self.key, state
172170

173171

174172
def add_deep_r(synapse_group, genn_model, compiler, l1_strength,

ml_genn/ml_genn/compilers/event_prop_compiler.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sympy
44

55
from string import Template
6-
from typing import Mapping, Union, Tuple
6+
from typing import Mapping, Sequence, Set, Union, Tuple
77
from pygenn import (CustomUpdateVarAccess, SynapseMatrixType,
88
VarAccess, VarAccessMode)
99

@@ -13,7 +13,9 @@
1313
from .ground_truths import GroundTruth
1414
from .. import Connection, InputLayer, Layer, Population, Network
1515
from ..callbacks import (Callback, CustomUpdateOnBatchBegin,
16-
CustomUpdateOnBatchEnd, CustomUpdateOnTimestepEnd)
16+
CustomUpdateOnBatchEnd, CustomUpdateOnEpochBegin,
17+
CustomUpdateOnTimestepBegin,
18+
CustomUpdateOnTimestepEnd)
1719
from ..communicators import Communicator
1820
from ..connection import Connection
1921
from ..losses import (Loss, MeanSquareError, PerNeuronMeanSquareError,
@@ -38,6 +40,7 @@
3840
from .deep_r import add_deep_r
3941
from ..utils.auto_tools import solve_ode
4042
from ..utils.module import get_object, get_object_mapping
43+
from ..utils.network import get_underlying_conn
4144
from ..utils.value import is_value_constant
4245

4346
from .compiler import softmax_1_model, softmax_2_model
@@ -294,8 +297,8 @@ def _add_required_wum_psm_parameters(model: AutoSynapseModel,
294297
WeightUpdateModel.has_psm_var_ref)
295298

296299
class CompileState:
297-
def __init__(self, network: Network, losses, optimisers,
298-
supported_matrix_type, backend_name):
300+
def __init__(self, network: Network, losses, optimisers, deep_r_conns,
301+
deep_r_record_rewiring, supported_matrix_type, backend_name):
299302
self.backend_name = backend_name
300303
self._neuron_reset_vars = []
301304
self._synapse_reset_vars = []
@@ -307,7 +310,11 @@ def __init__(self, network: Network, losses, optimisers,
307310
self.update_trial_pops = []
308311
self.adjoint_limit_pops_vars = []
309312
self.optimisers = {}
310-
313+
314+
self.deep_r_conns = set(get_underlying_conn(c) for c in deep_r_conns)
315+
self.deep_r_record_rewiring = {get_underlying_conn(c): k
316+
for c, k in deep_r_record_rewiring.items()}
317+
311318
# Build list of output populations
312319
readouts = [p for p in network.populations
313320
if p.neuron.readout is not None]
@@ -530,6 +537,7 @@ class EventPropCompiler(Compiler):
530537
spike number, strength for overshoot)
531538
reg_nu_upper: Target number of hidden neuron
532539
spikes used for regularisation
540+
grad_limit: TODO
533541
max_spikes: What is the maximum number of spikes each
534542
neuron (input and hidden) can emit each
535543
trial? This is used to allocate memory
@@ -541,8 +549,9 @@ class EventPropCompiler(Compiler):
541549
per_timestep_loss: Should we use the per-timestep or
542550
per-trial loss functions described above?
543551
dt: Simulation timestep [ms]
544-
ttfs_alpha TODO
545-
softmax_temperature TODO
552+
ttfs_alpha: TODO
553+
softmax_temperature: TODO
554+
deep_r_l1_strength: TODO
546555
batch_size: What batch size should be used for
547556
training? In our experience, EventProp works
548557
best with modest batch sizes (32-128)
@@ -565,14 +574,13 @@ class EventPropCompiler(Compiler):
565574
"""
566575

567576
def __init__(self, example_timesteps: int, losses,
568-
reg_lambda: Union[float, Tuple[float, float]] = 0.0, reg_nu_upper: float = 0.0,
569-
grad_limit: float = 100.0,
570-
max_spikes: int = 500,
571-
strict_buffer_checking: bool = False,
577+
reg_lambda: Union[float, Tuple[float, float]] = 0.0,
578+
reg_nu_upper: float = 0.0, grad_limit: float = 100.0,
579+
max_spikes: int = 500, strict_buffer_checking: bool = False,
572580
per_timestep_loss: bool = False, dt: float = 1.0,
573581
ttfs_alpha: float = 0.01, softmax_temperature: float = 1.0,
574-
batch_size: int = 1, rng_seed: int = 0,
575-
kernel_profiling: bool = False,
582+
deep_r_l1_strength: float = 0.01, batch_size: int = 1,
583+
rng_seed: int = 0, kernel_profiling: bool = False,
576584
communicator: Communicator = None,
577585
**genn_kwargs):
578586
supported_matrix_types = [SynapseMatrixType.TOEPLITZ,
@@ -625,6 +633,7 @@ def __init__(self, example_timesteps: int, losses,
625633
self.per_timestep_loss = per_timestep_loss
626634
self.ttfs_alpha = ttfs_alpha
627635
self.softmax_temperature = softmax_temperature
636+
self.deep_r_l1_strength = deep_r_l1_strength
628637

629638

630639
def pre_compile(self, network: Network,
@@ -633,14 +642,28 @@ def pre_compile(self, network: Network,
633642
# to training all weights using the adam optimiser with default params
634643
optimisers = kwargs.get("optimisers",
635644
{"all_connections": {"weight": "adam"}})
636-
645+
637646
# Check dictionary has been provided
638647
if not isinstance(optimisers, Mapping):
639-
raise RuntimeError("optimisers should be "
648+
raise RuntimeError("'optimisers' should be "
640649
"specified as a dictionary")
641650

651+
# Get sequence of connections to apply Deep-R to and
652+
# check provided value is indeed a sequence
653+
deep_r_conns = kwargs.get("deep_r_conns", [])
654+
if not isinstance(deep_r_conns, (Set, Sequence)):
655+
raise RuntimeError("'deep_r_conns' should be "
656+
"specified as a sequence")
657+
658+
# Get dictionary of connections to names to record their deep-r
659+
# rewiring stats to and check provided value is indeed a mapping
660+
deep_r_record_rewiring = kwargs.get("deep_r_record_rewiring", {})
661+
if not isinstance(deep_r_record_rewiring, Mapping):
662+
raise RuntimeError("'deep_r_record_rewiring' should be "
663+
"specified as a dictionary")
642664

643665
return CompileState(network, self.losses, optimisers,
666+
deep_r_conns, deep_r_record_rewiring,
644667
self.supported_matrix_type,
645668
genn_model.backend_name)
646669

@@ -1063,7 +1086,7 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
10631086
# If connection is in list of those to use Deep-R on
10641087
gradient_var_ref = create_wu_var_ref(genn_pop, "weightGradient")
10651088
weight_var_ref = create_wu_var_ref(genn_pop, "weight")
1066-
if c in self.deep_r_conns:
1089+
if k in compile_state.deep_r_conns:
10671090
# Add infrastructure
10681091
deep_r_2_ccu = add_deep_r(genn_pop, genn_model, self,
10691092
self.deep_r_l1_strength,
@@ -1072,14 +1095,14 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
10721095

10731096
# If we should record rewirings from
10741097
# this connection, add to list with key
1075-
if c in self.deep_r_record_rewirings:
1098+
if k in compile_state.deep_r_record_rewiring:
10761099
deep_r_record_rewirings_ccus.append(
1077-
(deep_r_2_ccu, self.deep_r_record_rewirings[c]))
1100+
(deep_r_2_ccu,
1101+
compile_state.deep_r_record_rewiring[k]))
10781102

10791103
# Create weight optimiser custom update
10801104
cu_weight = self._create_optimiser_custom_update(
1081-
f"Weight{i}", weight_var_ref,
1082-
create_wu_var_ref(genn_pop, gradient_var_ref),
1105+
f"Weight{i}", weight_var_ref, gradient_var_ref,
10831106
vars["weight"], genn_model)
10841107

10851108
# Add custom update to list of optimisers
@@ -1187,6 +1210,7 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
11871210
base_validate_callbacks = []
11881211

11891212
# If Deep-R and L1 regularisation are required, add callback
1213+
deep_r_required = (len(compile_state.deep_r_conns) > 0)
11901214
if deep_r_required and self.deep_r_l1_strength > 0.0:
11911215
base_train_callbacks.append(
11921216
CustomUpdateOnBatchEnd("DeepRL1", lambda batch: batch > 0))

ml_genn/ml_genn/utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def __init__(self, model, param_vals={}, var_vals={}, pre_var_vals={},
265265
self.egp_refs = egp_refs
266266

267267
def process(self):
268-
return (super(CustomConnectivityUpdateModel, self).process()
268+
return (super(CustomConnectivityUpdateModel, self)._process()
269269
+ (self.pre_var_vals,) + (self.post_var_vals,)
270270
+ (self.var_refs,) + (self.pre_var_refs,)
271271
+ (self.post_var_refs,) + (self.egp_refs,))

0 commit comments

Comments
 (0)