Skip to content

Commit fd216d8

Browse files
author
Alexander Ororbia
committed
revised tests to no longer use weight_distribution/revisions throughout
1 parent a8ff7d7 commit fd216d8

13 files changed

+28
-25
lines changed

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _calc_update(
7070

7171
dW = dW + prior_lmbda * dW_reg
7272

73-
if mask!=None:
73+
if mask != None:
7474
dW = dW * mask
7575

7676
return dW * signVal, db * signVal
@@ -95,12 +95,12 @@ def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True):
9595
"""
9696
_W = W
9797
if w_bound > 0.:
98-
if is_nonnegative == True:
98+
if is_nonnegative:
9999
_W = jnp.clip(_W, 0., w_bound)
100100
else:
101101
_W = jnp.clip(_W, -w_bound, w_bound)
102102

103-
if block_mask!=None:
103+
if block_mask != None:
104104
_W = _W * block_mask
105105

106106
return _W

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_in
2727

2828
shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
2929

30-
weights[start_i: end_i,
31-
start_j: end_j] = weight_init(shape_, key[2])
30+
## FIXME: this line below might be wonky...
31+
weights.at[start_i: end_i, start_j: end_j].set( weight_init(shape_, key[2]) )
3232
# weights[start_i : end_i,
3333
# start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
34-
if si!=0:
35-
weights[:si,:] = 0.
36-
weights[-si:,:] = 0.
37-
if sj!=0:
38-
weights[:,:sj] = 0.
39-
weights[:, -sj:] = 0.
34+
if si != 0:
35+
weights.at[:si,:].set(0.) ## FIXME: this setter line might be wonky...
36+
weights.at[-si:,:].set(0.) ## FIXME: this setter line might be wonky...
37+
if sj != 0:
38+
weights.at[:,:sj].set(0.) ## FIXME: this setter line might be wonky...
39+
weights.at[:, -sj:].set(0.) ## FIXME: this setter line might be wonky...
4040

4141
return weights
4242

@@ -109,7 +109,8 @@ def __init__(
109109
tmp_key, *subkeys = random.split(self.key.get(), 4)
110110
if self.weight_init is None:
111111
info(self.name, "is using default weight initializer!")
112-
self.weight_init = {"dist": "fan_in_gaussian"}
112+
#self.weight_init = {"dist": "fan_in_gaussian"}
113+
self.weight_init = DistributionGenerator.fan_in_gaussian()
113114

114115
weights = _create_multi_patch_synapses(
115116
key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,

tests/components/synapses/convolution/test_hebbianConvSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
np.random.seed(42)
44

55
from ngclearn import Context, MethodProcess
6-
import ngclearn.utils.weight_distribution as dist
6+
from ngclearn.utils.distribution_generator import DistributionGenerator as dist
77
from ngclearn.components.synapses.convolution.hebbianConvSynapse import HebbianConvSynapse
88
from numpy.testing import assert_array_equal
99

tests/components/synapses/convolution/test_hebbianDeconvSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
np.random.seed(42)
44

55
from ngclearn import Context, MethodProcess
6-
import ngclearn.utils.weight_distribution as dist
6+
from ngclearn.utils.distribution_generator import DistributionGenerator as dist
77
from ngclearn.components.synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
88
from numpy.testing import assert_array_equal
99

tests/components/synapses/convolution/test_traceSTDPConvSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
from ngclearn.utils.distribution_generator import DistributionGenerator as dist
88
from ngclearn.components.synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
from ngclearn.utils.distribution_generator import DistributionGenerator as dist
88
from ngclearn.components.synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/hebbian/test_BCMSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
88
from ngclearn.components.synapses.hebbian.BCMSynapse import BCMSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/hebbian/test_eventSTDPSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
import ngclearn.utils.weight_distribution as dist
7+
#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
88
from ngclearn.components.synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
99
from numpy.testing import assert_array_equal
1010

tests/components/synapses/hebbian/test_expSTDPSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
np.random.seed(42)
66

77
from ngclearn import Context, MethodProcess
8-
import ngclearn.utils.weight_distribution as dist
8+
#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
99
from ngclearn.components.synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
1010
from numpy.testing import assert_array_equal
1111

tests/components/synapses/hebbian/test_traceSTDPSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
np.random.seed(42)
55

66
from ngclearn import Context, MethodProcess
7-
#import ngclearn.utils.weight_distribution as dist
7+
#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
88
from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
99
from numpy.testing import assert_array_equal
1010

0 commit comments

Comments
 (0)