@@ -119,7 +119,7 @@ def __init__(
119119 self .comp_method = comp_method
120120
121121 # connections and weights
122- self .g_max , self .conn_mask = self .init_weights (g_max , comp_method = comp_method , sparse_data = 'csr' )
122+ self .g_max , self .conn_mask = self ._init_weights (g_max , comp_method = comp_method , sparse_data = 'csr' )
123123
124124 # register delay
125125 self .delay_step = self .register_delay (f"{ self .pre .name } .spike" , delay_step , self .pre .spike )
@@ -143,10 +143,10 @@ def update(self, tdi, pre_spike=None):
143143 # synaptic values onto the post
144144 if isinstance (self .conn , All2All ):
145145 syn_value = self .stp (bm .asarray (pre_spike , dtype = bm .dftype ()))
146- post_vs = self .syn2post_with_all2all (syn_value , self .g_max )
146+ post_vs = self ._syn2post_with_all2all (syn_value , self .g_max )
147147 elif isinstance (self .conn , One2One ):
148148 syn_value = self .stp (bm .asarray (pre_spike , dtype = bm .dftype ()))
149- post_vs = self .syn2post_with_one2one (syn_value , self .g_max )
149+ post_vs = self ._syn2post_with_one2one (syn_value , self .g_max )
150150 else :
151151 if self .comp_method == 'sparse' :
152152 f = lambda s : bm .pre2post_event_sum (s , self .conn_mask , self .post .num , self .g_max )
@@ -160,7 +160,7 @@ def update(self, tdi, pre_spike=None):
160160 # post_vs *= f2(stp_value)
161161 else :
162162 syn_value = self .stp (bm .asarray (pre_spike , dtype = bm .dftype ()))
163- post_vs = self .syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
163+ post_vs = self ._syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
164164 if self .post_ref_key :
165165 post_vs = post_vs * (1. - getattr (self .post , self .post_ref_key ))
166166
@@ -296,7 +296,7 @@ def __init__(
296296 raise ValueError (f'"tau" must be a scalar or a tensor with size of 1. But we got { self .tau } ' )
297297
298298 # connections and weights
299- self .g_max , self .conn_mask = self .init_weights (g_max , comp_method , sparse_data = 'csr' )
299+ self .g_max , self .conn_mask = self ._init_weights (g_max , comp_method , sparse_data = 'csr' )
300300
301301 # variables
302302 self .g = variable_ (bm .zeros , self .post .num , mode )
@@ -328,11 +328,11 @@ def update(self, tdi, pre_spike=None):
328328 if isinstance (self .conn , All2All ):
329329 syn_value = bm .asarray (pre_spike , dtype = bm .dftype ())
330330 if self .stp is not None : syn_value = self .stp (syn_value )
331- post_vs = self .syn2post_with_all2all (syn_value , self .g_max )
331+ post_vs = self ._syn2post_with_all2all (syn_value , self .g_max )
332332 elif isinstance (self .conn , One2One ):
333333 syn_value = bm .asarray (pre_spike , dtype = bm .dftype ())
334334 if self .stp is not None : syn_value = self .stp (syn_value )
335- post_vs = self .syn2post_with_one2one (syn_value , self .g_max )
335+ post_vs = self ._syn2post_with_one2one (syn_value , self .g_max )
336336 else :
337337 if self .comp_method == 'sparse' :
338338 f = lambda s : bm .pre2post_event_sum (s , self .conn_mask , self .post .num , self .g_max )
@@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None):
343343 else :
344344 syn_value = bm .asarray (pre_spike , dtype = bm .dftype ())
345345 if self .stp is not None : syn_value = self .stp (syn_value )
346- post_vs = self .syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
346+ post_vs = self ._syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
347347 # updates
348348 self .g .value = self .integral (self .g .value , t , dt ) + post_vs
349349
@@ -487,7 +487,7 @@ def __init__(
487487 f'But we got { self .tau_decay } ' )
488488
489489 # connections
490- self .g_max , self .conn_mask = self .init_weights (g_max , comp_method , sparse_data = 'ij' )
490+ self .g_max , self .conn_mask = self ._init_weights (g_max , comp_method , sparse_data = 'ij' )
491491
492492 # variables
493493 self .h = variable_ (bm .zeros , self .pre .num , mode )
@@ -531,16 +531,16 @@ def update(self, tdi, pre_spike=None):
531531 syn_value = self .g .value
532532 if self .stp is not None : syn_value = self .stp (syn_value )
533533 if isinstance (self .conn , All2All ):
534- post_vs = self .syn2post_with_all2all (syn_value , self .g_max )
534+ post_vs = self ._syn2post_with_all2all (syn_value , self .g_max )
535535 elif isinstance (self .conn , One2One ):
536- post_vs = self .syn2post_with_one2one (syn_value , self .g_max )
536+ post_vs = self ._syn2post_with_one2one (syn_value , self .g_max )
537537 else :
538538 if self .comp_method == 'sparse' :
539539 f = lambda s : bm .pre2post_sum (s , self .post .num , * self .conn_mask )
540540 if isinstance (self .mode , BatchingMode ): f = vmap (f )
541541 post_vs = f (syn_value )
542542 else :
543- post_vs = self .syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
543+ post_vs = self ._syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
544544
545545 # output
546546 return self .output (post_vs )
@@ -829,7 +829,7 @@ def __init__(
829829 self .stop_spike_gradient = stop_spike_gradient
830830
831831 # connections and weights
832- self .g_max , self .conn_mask = self .init_weights (g_max , comp_method , sparse_data = 'ij' )
832+ self .g_max , self .conn_mask = self ._init_weights (g_max , comp_method , sparse_data = 'ij' )
833833
834834 # variables
835835 self .g = variable_ (bm .zeros , self .pre .num , mode )
@@ -872,16 +872,16 @@ def update(self, tdi, pre_spike=None):
872872 syn_value = self .g .value
873873 if self .stp is not None : syn_value = self .stp (syn_value )
874874 if isinstance (self .conn , All2All ):
875- post_vs = self .syn2post_with_all2all (syn_value , self .g_max )
875+ post_vs = self ._syn2post_with_all2all (syn_value , self .g_max )
876876 elif isinstance (self .conn , One2One ):
877- post_vs = self .syn2post_with_one2one (syn_value , self .g_max )
877+ post_vs = self ._syn2post_with_one2one (syn_value , self .g_max )
878878 else :
879879 if self .comp_method == 'sparse' :
880880 f = lambda s : bm .pre2post_sum (s , self .post .num , * self .conn_mask )
881881 if isinstance (self .mode , BatchingMode ): f = vmap (f )
882882 post_vs = f (syn_value )
883883 else :
884- post_vs = self .syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
884+ post_vs = self ._syn2post_with_dense (syn_value , self .g_max , self .conn_mask )
885885
886886 # output
887887 return self .output (post_vs )
0 commit comments