@@ -478,53 +478,6 @@ def on_batch_begin(self, state, batch: int):
478478 self .genn_pop .set_dynamic_param_value ("Trial" , batch )
479479
480480
481- class CustomUpdateOnLastTimestep (Callback ):
482- """Callback that triggers a GeNN custom update
483- at the start of the last timestep in each example"""
484- def __init__ (self , name : str , example_timesteps : int ):
485- self .name = name
486- self .example_timesteps = example_timesteps
487-
488- def create_state (self , compiled_network , ** kwargs ):
489- return compiled_network
490-
491- def on_timestep_begin (self , state , timestep : int ):
492- if timestep == (self .example_timesteps - 1 ):
493- logger .debug (f"Running custom update { self .name } "
494- f"at start of timestep { timestep } " )
495- state .genn_model .custom_update (self .name )
496-
497-
498- class CustomUpdateOnBatchEndNotFirst (Callback ):
499- """Callback that triggers a GeNN custom update
500- at the end of every batch after the first."""
501- def __init__ (self , name : str ):
502- self .name = name
503-
504- def create_state (self , compiled_network , ** kwargs ):
505- return compiled_network
506-
507- def on_batch_end (self , state , batch , metric_state ):
508- if batch > 0 :
509- logger .debug (f"Running custom update { self .name } "
510- f"at end of batch { batch } " )
511- state .genn_model .custom_update (self .name )
512-
513- class CustomUpdateOnFirstBatchEnd (Callback ):
514- """Callback that triggers a GeNN custom update
515- at the end of first batch."""
516- def __init__ (self , name : str ):
517- self .name = name
518-
519- def create_state (self , compiled_network , ** kwargs ):
520- return compiled_network
521-
522- def on_batch_end (self , state , batch , metric_state ):
523- if batch == 0 :
524- logger .debug (f"Running custom update { self .name } "
525- f"at end of batch { batch } " )
526- state .genn_model .custom_update (self .name )
527-
528481class EventPropCompiler (Compiler ):
529482 """Compiler for training models using EventProp [Wunderlich2021]_.
530483
@@ -1279,8 +1232,10 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
12791232
12801233 # Add custom uopdate for adjoint limit calculation if required
12811234 if len (compile_state .adjoint_limit_pops_vars ) > 0 :
1282- base_train_callbacks .append (CustomUpdateOnBatchEndNotFirst ("AbsSumReduceBatch" ))
1283- base_train_callbacks .append (CustomUpdateOnBatchEndNotFirst ("ReduceAssign" ))
1235+ base_train_callbacks .append (CustomUpdateOnBatchEnd ("AbsSumReduceBatch" ,
1236+ lambda batch : batch > 0 ))
1237+ base_train_callbacks .append (CustomUpdateOnBatchEnd ("ReduceAssign" ,
1238+ lambda batch : batch > 0 ))
12841239
12851240 # If spike count reduction is required at end of batch, add callback
12861241 if len (compile_state .spike_count_populations ) > 0 and self .full_batch_size > 1 :
0 commit comments