Skip to content

Commit fef9f88

Browse files
removed legacy callbacks
1 parent c3fc5bc commit fef9f88

1 file changed

Lines changed: 4 additions & 49 deletions

File tree

ml_genn/ml_genn/compilers/event_prop_compiler.py

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -478,53 +478,6 @@ def on_batch_begin(self, state, batch: int):
478478
self.genn_pop.set_dynamic_param_value("Trial", batch)
479479

480480

481-
class CustomUpdateOnLastTimestep(Callback):
482-
"""Callback that triggers a GeNN custom update
483-
at the start of the last timestep in each example"""
484-
def __init__(self, name: str, example_timesteps: int):
485-
self.name = name
486-
self.example_timesteps = example_timesteps
487-
488-
def create_state(self, compiled_network, **kwargs):
489-
return compiled_network
490-
491-
def on_timestep_begin(self, state, timestep: int):
492-
if timestep == (self.example_timesteps - 1):
493-
logger.debug(f"Running custom update {self.name} "
494-
f"at start of timestep {timestep}")
495-
state.genn_model.custom_update(self.name)
496-
497-
498-
class CustomUpdateOnBatchEndNotFirst(Callback):
499-
"""Callback that triggers a GeNN custom update
500-
at the end of every batch after the first."""
501-
def __init__(self, name: str):
502-
self.name = name
503-
504-
def create_state(self, compiled_network, **kwargs):
505-
return compiled_network
506-
507-
def on_batch_end(self, state, batch, metric_state):
508-
if batch > 0:
509-
logger.debug(f"Running custom update {self.name} "
510-
f"at end of batch {batch}")
511-
state.genn_model.custom_update(self.name)
512-
513-
class CustomUpdateOnFirstBatchEnd(Callback):
514-
"""Callback that triggers a GeNN custom update
515-
at the end of first batch."""
516-
def __init__(self, name: str):
517-
self.name = name
518-
519-
def create_state(self, compiled_network, **kwargs):
520-
return compiled_network
521-
522-
def on_batch_end(self, state, batch, metric_state):
523-
if batch == 0:
524-
logger.debug(f"Running custom update {self.name} "
525-
f"at end of batch {batch}")
526-
state.genn_model.custom_update(self.name)
527-
528481
class EventPropCompiler(Compiler):
529482
"""Compiler for training models using EventProp [Wunderlich2021]_.
530483
@@ -1279,8 +1232,10 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
12791232

12801233
# Add custom uopdate for adjoint limit calculation if required
12811234
if len(compile_state.adjoint_limit_pops_vars) > 0:
1282-
base_train_callbacks.append(CustomUpdateOnBatchEndNotFirst("AbsSumReduceBatch"))
1283-
base_train_callbacks.append(CustomUpdateOnBatchEndNotFirst("ReduceAssign"))
1235+
base_train_callbacks.append(CustomUpdateOnBatchEnd("AbsSumReduceBatch",
1236+
lambda batch: batch > 0))
1237+
base_train_callbacks.append(CustomUpdateOnBatchEnd("ReduceAssign",
1238+
lambda batch: batch > 0))
12841239

12851240
# If spike count reduction is required at end of batch, add callback
12861241
if len(compile_state.spike_count_populations) > 0 and self.full_batch_size > 1:

0 commit comments

Comments
 (0)