33import sympy
44
55from string import Template
6- from typing import Mapping , Union , Tuple
6+ from typing import Mapping , Sequence , Set , Union , Tuple
77from pygenn import (CustomUpdateVarAccess , SynapseMatrixType ,
88 VarAccess , VarAccessMode )
99
1313from .ground_truths import GroundTruth
1414from .. import Connection , InputLayer , Layer , Population , Network
1515from ..callbacks import (Callback , CustomUpdateOnBatchBegin ,
16- CustomUpdateOnBatchEnd , CustomUpdateOnTimestepEnd )
16+ CustomUpdateOnBatchEnd , CustomUpdateOnEpochBegin ,
17+ CustomUpdateOnTimestepBegin ,
18+ CustomUpdateOnTimestepEnd )
1719from ..communicators import Communicator
1820from ..connection import Connection
1921from ..losses import (Loss , MeanSquareError , PerNeuronMeanSquareError ,
3840from .deep_r import add_deep_r
3941from ..utils .auto_tools import solve_ode
4042from ..utils .module import get_object , get_object_mapping
43+ from ..utils .network import get_underlying_conn
4144from ..utils .value import is_value_constant
4245
4346from .compiler import softmax_1_model , softmax_2_model
@@ -294,8 +297,8 @@ def _add_required_wum_psm_parameters(model: AutoSynapseModel,
294297 WeightUpdateModel .has_psm_var_ref )
295298
296299class CompileState :
297- def __init__ (self , network : Network , losses , optimisers ,
298- supported_matrix_type , backend_name ):
300+ def __init__ (self , network : Network , losses , optimisers , deep_r_conns ,
301+ deep_r_record_rewiring , supported_matrix_type , backend_name ):
299302 self .backend_name = backend_name
300303 self ._neuron_reset_vars = []
301304 self ._synapse_reset_vars = []
@@ -307,7 +310,11 @@ def __init__(self, network: Network, losses, optimisers,
307310 self .update_trial_pops = []
308311 self .adjoint_limit_pops_vars = []
309312 self .optimisers = {}
310-
313+
314+ self .deep_r_conns = set (get_underlying_conn (c ) for c in deep_r_conns )
315+ self .deep_r_record_rewiring = {get_underlying_conn (c ): k
316+ for c , k in deep_r_record_rewiring .items ()}
317+
311318 # Build list of output populations
312319 readouts = [p for p in network .populations
313320 if p .neuron .readout is not None ]
@@ -530,6 +537,7 @@ class EventPropCompiler(Compiler):
530537 spike number, strength for overshoot)
531538 reg_nu_upper: Target number of hidden neuron
532539 spikes used for regularisation
540+ grad_limit: TODO
533541 max_spikes: What is the maximum number of spikes each
534542 neuron (input and hidden) can emit each
535543 trial? This is used to allocate memory
@@ -541,8 +549,9 @@ class EventPropCompiler(Compiler):
541549 per_timestep_loss: Should we use the per-timestep or
542550 per-trial loss functions described above?
543551 dt: Simulation timestep [ms]
544- ttfs_alpha TODO
545- softmax_temperature TODO
552+ ttfs_alpha: TODO
553+ softmax_temperature: TODO
554+ deep_r_l1_strength: TODO
546555 batch_size: What batch size should be used for
547556 training? In our experience, EventProp works
548557 best with modest batch sizes (32-128)
@@ -565,14 +574,13 @@ class EventPropCompiler(Compiler):
565574 """
566575
567576 def __init__ (self , example_timesteps : int , losses ,
568- reg_lambda : Union [float , Tuple [float , float ]] = 0.0 , reg_nu_upper : float = 0.0 ,
569- grad_limit : float = 100.0 ,
570- max_spikes : int = 500 ,
571- strict_buffer_checking : bool = False ,
577+ reg_lambda : Union [float , Tuple [float , float ]] = 0.0 ,
578+ reg_nu_upper : float = 0.0 , grad_limit : float = 100.0 ,
579+ max_spikes : int = 500 , strict_buffer_checking : bool = False ,
572580 per_timestep_loss : bool = False , dt : float = 1.0 ,
573581 ttfs_alpha : float = 0.01 , softmax_temperature : float = 1.0 ,
574- batch_size : int = 1 , rng_seed : int = 0 ,
575- kernel_profiling : bool = False ,
582+ deep_r_l1_strength : float = 0.01 , batch_size : int = 1 ,
583+ rng_seed : int = 0 , kernel_profiling : bool = False ,
576584 communicator : Communicator = None ,
577585 ** genn_kwargs ):
578586 supported_matrix_types = [SynapseMatrixType .TOEPLITZ ,
@@ -625,6 +633,7 @@ def __init__(self, example_timesteps: int, losses,
625633 self .per_timestep_loss = per_timestep_loss
626634 self .ttfs_alpha = ttfs_alpha
627635 self .softmax_temperature = softmax_temperature
636+ self .deep_r_l1_strength = deep_r_l1_strength
628637
629638
630639 def pre_compile (self , network : Network ,
@@ -633,14 +642,28 @@ def pre_compile(self, network: Network,
633642 # to training all weights using the adam optimiser with default params
634643 optimisers = kwargs .get ("optimisers" ,
635644 {"all_connections" : {"weight" : "adam" }})
636-
645+
637646 # Check dictionary has been provided
638647 if not isinstance (optimisers , Mapping ):
639- raise RuntimeError ("optimisers should be "
648+ raise RuntimeError ("' optimisers' should be "
640649 "specified as a dictionary" )
641650
651+ # Get sequence of connections to apply Deep-R to and
652+ # check provided value is indeed a sequence
653+ deep_r_conns = kwargs .get ("deep_r_conns" , [])
654+ if not isinstance (deep_r_conns , (Set , Sequence )):
655+ raise RuntimeError ("'deep_r_conns' should be "
656+ "specified as a sequence" )
657+
658+ # Get dictionary of connections to names to record their deep-r
659+ # rewiring stats to and check provided value is indeed a mapping
660+ deep_r_record_rewiring = kwargs .get ("deep_r_record_rewiring" , {})
661+ if not isinstance (deep_r_record_rewiring , Mapping ):
662+ raise RuntimeError ("'deep_r_record_rewiring' should be "
663+ "specified as a dictionary" )
642664
643665 return CompileState (network , self .losses , optimisers ,
666+ deep_r_conns , deep_r_record_rewiring ,
644667 self .supported_matrix_type ,
645668 genn_model .backend_name )
646669
@@ -1063,7 +1086,7 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
10631086 # If connection is in list of those to use Deep-R on
10641087 gradient_var_ref = create_wu_var_ref (genn_pop , "weightGradient" )
10651088 weight_var_ref = create_wu_var_ref (genn_pop , "weight" )
1066- if c in self .deep_r_conns :
1089+ if k in compile_state .deep_r_conns :
10671090 # Add infrastructure
10681091 deep_r_2_ccu = add_deep_r (genn_pop , genn_model , self ,
10691092 self .deep_r_l1_strength ,
@@ -1072,14 +1095,14 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
10721095
10731096 # If we should record rewirings from
10741097 # this connection, add to list with key
1075- if c in self . deep_r_record_rewirings :
1098+ if k in compile_state . deep_r_record_rewiring :
10761099 deep_r_record_rewirings_ccus .append (
1077- (deep_r_2_ccu , self .deep_r_record_rewirings [c ]))
1100+ (deep_r_2_ccu ,
1101+ compile_state .deep_r_record_rewiring [k ]))
10781102
10791103 # Create weight optimiser custom update
10801104 cu_weight = self ._create_optimiser_custom_update (
1081- f"Weight{ i } " , weight_var_ref ,
1082- create_wu_var_ref (genn_pop , gradient_var_ref ),
1105+ f"Weight{ i } " , weight_var_ref , gradient_var_ref ,
10831106 vars ["weight" ], genn_model )
10841107
10851108 # Add custom update to list of optimisers
@@ -1187,6 +1210,7 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
11871210 base_validate_callbacks = []
11881211
11891212 # If Deep-R and L1 regularisation are required, add callback
1213+ deep_r_required = (len (compile_state .deep_r_conns ) > 0 )
11901214 if deep_r_required and self .deep_r_l1_strength > 0.0 :
11911215 base_train_callbacks .append (
11921216 CustomUpdateOnBatchEnd ("DeepRL1" , lambda batch : batch > 0 ))
0 commit comments