Commit af185c5

Fix several bugs, update APIs and docs (#304)
2 parents 4c3433d + 5281ef7 commit af185c5

File tree

18 files changed: +280 -244 lines changed

brainpy/connect/base.py

Lines changed: 188 additions & 123 deletions
Large diffs are not rendered by default.

brainpy/connect/custom_conn.py

Lines changed: 4 additions & 2 deletions
@@ -81,8 +81,8 @@ class CSRConn(TwoEndConnector):
   def __init__(self, indices, inptr):
     super(CSRConn, self).__init__()

-    self.indices = bm.asarray(indices).astype(IDX_DTYPE)
-    self.inptr = bm.asarray(inptr).astype(IDX_DTYPE)
+    self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
+    self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
     self.pre_num = self.inptr.size - 1
     self.max_post = bm.max(self.indices)

@@ -110,3 +110,5 @@ def __init__(self, csr_mat):
     self.csr_mat = csr_mat
     super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
                                         inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
+    self.pre_num = csr_mat.shape[0]
+    self.post_num = csr_mat.shape[1]
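The two added lines matter because a CSR matrix can have empty trailing post columns: deriving post_num from max(indices) under-counts in that case, while csr_mat.shape is always authoritative. A minimal sketch of the fixed behavior, assuming SparseMatConn is re-exported from brainpy.connect (the import path is inferred from the file location):

import brainpy.connect as bc  # import path assumed from the file location
from scipy.sparse import csr_matrix

# A 3 (pre) x 4 (post) connectivity matrix whose last post column is empty.
mat = csr_matrix([[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0]])
conn = bc.SparseMatConn(mat)

# With the added lines, sizes come from the matrix shape rather than from
# max(indices), which would report only 3 post neurons here.
assert conn.pre_num == 3   # csr_mat.shape[0]
assert conn.post_num == 4  # csr_mat.shape[1]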

brainpy/connect/regular_conn.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def build_csr(self):
                        f'same size, but {self.pre_num} != {self.post_num}.')
     ind = np.arange(self.pre_num)
     indptr = np.arange(self.pre_num + 1)
-    return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE),
+    return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))

   def build_mat(self, pre_size=None, post_size=None):
     if self.pre_num != self.post_num:
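The one-line fix replaces a misplaced np.arange with np.asarray: np.arange expects scalar bounds, so passing the already-built indptr array raised an error at build time. A plain-NumPy sketch of what the corrected one-to-one CSR builder returns (IDX_DTYPE stands in for brainpy's index dtype; np.uint32 here is an assumption):

import numpy as np

IDX_DTYPE = np.uint32  # stand-in; brainpy defines the actual index dtype

def build_csr(pre_num):
  # One-to-one wiring: pre neuron i connects to post neuron i, so the CSR
  # indices are 0..pre_num-1 and indptr advances by one per row.
  ind = np.arange(pre_num)
  indptr = np.arange(pre_num + 1)
  return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))

indices, indptr = build_csr(5)
assert indices.tolist() == [0, 1, 2, 3, 4]
assert indptr.tolist() == [0, 1, 2, 3, 4, 5]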

brainpy/dyn/base.py

Lines changed: 7 additions & 7 deletions
@@ -988,15 +988,15 @@ def __init__(
       ltp.register_master(master=self)
     self.ltp: SynLTP = ltp

-  def init_weights(
+  def _init_weights(
       self,
       weight: Union[float, Array, Initializer, Callable],
       comp_method: str,
       sparse_data: str = 'csr'
   ) -> Union[float, Array]:
     if comp_method not in ['sparse', 'dense']:
       raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
-    if sparse_data not in ['csr', 'ij']:
+    if sparse_data not in ['csr', 'ij', 'coo']:
       raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
     if self.conn is None:
       raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
@@ -1014,11 +1014,11 @@ def init_weights(
     if comp_method == 'sparse':
       if sparse_data == 'csr':
         conn_mask = self.conn.require('pre2post')
-      elif sparse_data == 'ij':
+      elif sparse_data in ['ij', 'coo']:
         conn_mask = self.conn.require('post_ids', 'pre_ids')
       else:
         ValueError(f'Unknown sparse data type: {sparse_data}')
-      weight = parameter(weight, conn_mask[1].shape, allow_none=False)
+      weight = parameter(weight, conn_mask[0].shape, allow_none=False)
     elif comp_method == 'dense':
       weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
       conn_mask = self.conn.require('conn_mat')
@@ -1030,7 +1030,7 @@ def init_weights(
       weight = bm.TrainVar(weight)
     return weight, conn_mask

-  def syn2post_with_all2all(self, syn_value, syn_weight):
+  def _syn2post_with_all2all(self, syn_value, syn_weight):
     if bm.ndim(syn_weight) == 0:
       if isinstance(self.mode, BatchingMode):
         post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
@@ -1043,10 +1043,10 @@ def syn2post_with_all2all(self, syn_value, syn_weight):
       post_vs = syn_value @ syn_weight
     return post_vs

-  def syn2post_with_one2one(self, syn_value, syn_weight):
+  def _syn2post_with_one2one(self, syn_value, syn_weight):
     return syn_value * syn_weight

-  def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
+  def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
     if bm.ndim(syn_weight) == 0:
       post_vs = (syn_weight * syn_value) @ conn_mat
     else:
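Two behavioral changes hide in this hunk besides the renames: 'coo' becomes an accepted alias for 'ij', and per-connection weights are now shaped by conn_mask[0] instead of conn_mask[1]. For the 'csr' branch, require('pre2post') yields an (indices, indptr) pair (inferred from the CSRConn diff above), and only indices has one entry per connection; indptr has pre_num + 1 entries, so the old shape was wrong whenever the two differed. (Note the unchanged bare ValueError(...) in the else branch is still constructed without raise, so it remains a no-op.) A small NumPy illustration of the shapes involved:

import numpy as np

# 5 connections spread over 3 presynaptic neurons, in CSR form.
indices = np.array([0, 2, 1, 0, 2], dtype=np.uint32)  # one entry per connection
indptr = np.array([0, 2, 3, 5], dtype=np.uint32)      # pre_num + 1 entries
conn_mask = (indices, indptr)

# Fixed: one weight per connection, shaped like conn_mask[0].
weights = np.full(conn_mask[0].shape, 0.1)
assert weights.shape == (5,)
# The old conn_mask[1].shape would have produced (4,) weights, a mismatch.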

brainpy/dyn/layers/activate.py

Lines changed: 13 additions & 11 deletions
@@ -1,28 +1,30 @@
-from brainpy.dyn.base import DynamicalSystem
-from typing import Optional
-from brainpy.modes import Mode
 from typing import Callable
+from typing import Optional
+
+from brainpy.dyn.base import DynamicalSystem
+from brainpy.modes import Mode, training


 class Activation(DynamicalSystem):
-  r"""Applies a activation to the inputs
+  r"""Applies an activation function to the inputs

   Parameters:
   ----------
-  activate_fun: Callable
+  activate_fun: Callable, function
     The function of Activation
   name: str, Optional
     The name of the object
   mode: Mode
     Enable training this node or not. (default True).
   """

-  def __init__(self,
-               activate_fun: Callable,
-               name: Optional[str] = None,
-               mode: Optional[Mode] = None,
-               **kwargs,
-               ):
+  def __init__(
+      self,
+      activate_fun: Callable,
+      name: Optional[str] = None,
+      mode: Mode = training,
+      **kwargs,
+  ):
     super().__init__(name, mode)
     self.activate_fun = activate_fun
     self.kwargs = kwargs
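Beyond the import reordering and the docstring fix, the substantive change is the default mode: None becomes the training mode, so the layer is trainable out of the box. A minimal usage sketch, assuming Activation is importable from brainpy.dyn.layers (the path is inferred from the file name):

import brainpy.math as bm
from brainpy.dyn.layers import Activation  # path inferred from the file name

# Any callable works as the activation; a ReLU via the numpy-like bm API:
act = Activation(activate_fun=lambda x: bm.maximum(x, 0.))
# `mode` now defaults to `training`, so no explicit Mode is needed.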

brainpy/dyn/layers/linear.py

Lines changed: 10 additions & 8 deletions
@@ -9,7 +9,7 @@
 from brainpy.dyn.base import DynamicalSystem
 from brainpy.errors import MathError
 from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
-from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
+from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
 from brainpy.tools.checking import check_initializer
 from brainpy.types import Array

@@ -201,17 +201,19 @@ class Flatten(DynamicalSystem):
   mode: Mode
     Enable training this node or not. (default True)
   """
-  def __init__(self,
-               name: Optional[str] = None,
-               mode: Optional[Mode] = batching,
-               ):
+
+  def __init__(
+      self,
+      name: Optional[str] = None,
+      mode: Optional[Mode] = batching,
+  ):
     super().__init__(name, mode)

   def update(self, shr, x):
     if isinstance(self.mode, BatchingMode):
       return x.reshape((x.shape[0], -1))
     else:
       return x.flatten()

   def reset_state(self, batch_size=None):
     pass
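The Flatten changes are stylistic (whitespace and signature layout), but the class behavior is worth pinning down: in batching mode the leading batch axis is preserved and everything else is flattened; otherwise the whole array is flattened. A hedged sketch (import path inferred from the file name; the first update argument is the shared-arguments dict, passed empty here):

import brainpy.math as bm
from brainpy.dyn.layers import Flatten  # path inferred from the file name

layer = Flatten()                # default mode is `batching`
x = bm.ones((8, 4, 4))           # batch of 8, each sample 4x4
y = layer.update({}, x)          # shr: shared arguments, unused here
assert y.shape == (8, 16)        # batch axis kept, feature axes flattened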

brainpy/dyn/synapses/abstract_models.py

Lines changed: 16 additions & 16 deletions
@@ -119,7 +119,7 @@ def __init__(
     self.comp_method = comp_method

     # connections and weights
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

     # register delay
     self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
@@ -143,10 +143,10 @@ def update(self, tdi, pre_spike=None):
     # synaptic values onto the post
     if isinstance(self.conn, All2All):
       syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
       syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -160,7 +160,7 @@ def update(self, tdi, pre_spike=None):
         # post_vs *= f2(stp_value)
       else:
         syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
     if self.post_ref_key:
       post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

@@ -296,7 +296,7 @@ def __init__(
       raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

     # connections and weights
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')

     # variables
     self.g = variable_(bm.zeros, self.post.num, mode)
@@ -328,11 +328,11 @@ def update(self, tdi, pre_spike=None):
     if isinstance(self.conn, All2All):
       syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
       if self.stp is not None: syn_value = self.stp(syn_value)
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
       syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
       if self.stp is not None: syn_value = self.stp(syn_value)
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None):
       else:
         syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
         if self.stp is not None: syn_value = self.stp(syn_value)
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
     # updates
     self.g.value = self.integral(self.g.value, t, dt) + post_vs

@@ -487,7 +487,7 @@ def __init__(
                        f'But we got {self.tau_decay}')

     # connections
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

     # variables
     self.h = variable_(bm.zeros, self.pre.num, mode)
@@ -531,16 +531,16 @@ def update(self, tdi, pre_spike=None):
     syn_value = self.g.value
     if self.stp is not None: syn_value = self.stp(syn_value)
     if isinstance(self.conn, All2All):
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
         if isinstance(self.mode, BatchingMode): f = vmap(f)
         post_vs = f(syn_value)
       else:
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

     # output
     return self.output(post_vs)
@@ -829,7 +829,7 @@ def __init__(
     self.stop_spike_gradient = stop_spike_gradient

     # connections and weights
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

     # variables
     self.g = variable_(bm.zeros, self.pre.num, mode)
@@ -872,16 +872,16 @@ def update(self, tdi, pre_spike=None):
     syn_value = self.g.value
     if self.stp is not None: syn_value = self.stp(syn_value)
     if isinstance(self.conn, All2All):
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
         if isinstance(self.mode, BatchingMode): f = vmap(f)
         post_vs = f(syn_value)
       else:
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

     # output
     return self.output(post_vs)
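Every call site here only swaps in the underscored helper names, so the event-driven sparse path is unchanged: for 'csr' data, bm.pre2post_event_sum scatters g_max into postsynaptic slots for each presynaptic spike. A plain-NumPy reference of what that call computes, with the signature inferred from the usage above and a scalar, homogeneous g_max assumed:

import numpy as np

def pre2post_event_sum(spike, conn_mask, post_num, g_max):
  # conn_mask is the (indices, indptr) pair required as 'pre2post'.
  indices, indptr = conn_mask
  out = np.zeros(post_num)
  for i in range(len(indptr) - 1):       # loop over presynaptic neurons
    if spike[i]:                         # event-driven: skip silent neurons
      out[indices[indptr[i]:indptr[i + 1]]] += g_max
  return out

spike = np.array([True, False, True])
indices = np.array([0, 1, 1, 2])         # postsynaptic targets
indptr = np.array([0, 2, 3, 4])          # CSR row pointers
print(pre2post_event_sum(spike, (indices, indptr), 3, 0.5))  # [0.5 0.5 0.5]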

brainpy/dyn/synapses/biological_models.py

Lines changed: 8 additions & 8 deletions
@@ -181,7 +181,7 @@ def __init__(
       raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')

     # connection
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

     # variables
     self.g = variable(bm.zeros, mode, self.pre.num)
@@ -226,16 +226,16 @@ def update(self, tdi, pre_spike=None):
     syn_value = self.g.value
     if self.stp is not None: syn_value = self.stp(syn_value)
     if isinstance(self.conn, All2All):
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
         if isinstance(self.mode, BatchingMode): f = vmap(f)
         post_vs = f(syn_value)
       else:
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

     # output
     return self.output(post_vs)
@@ -526,7 +526,7 @@ def __init__(
     self.stop_spike_gradient = stop_spike_gradient

     # connections and weights
-    self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
+    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

     # variables
     self.g = variable(bm.zeros, mode, self.pre.num)
@@ -575,16 +575,16 @@ def update(self, tdi, pre_spike=None):
     syn_value = self.g.value
     if self.stp is not None: syn_value = self.stp(syn_value)
     if isinstance(self.conn, All2All):
-      post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
-      post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
         if isinstance(self.mode, BatchingMode): f = vmap(f)
         post_vs = f(syn_value)
       else:
-        post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

     # output
     return self.output(post_vs)
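These biological models use sparse_data='ij', so conn_mask here is the (post_ids, pre_ids) pair from the base-class diff, and bm.pre2post_sum(s, post_num, post_ids, pre_ids) accumulates each connection's presynaptic value onto its target. A NumPy reference under that inferred argument order:

import numpy as np

def pre2post_sum(syn_value, post_num, post_ids, pre_ids):
  out = np.zeros(post_num)
  # np.add.at accumulates correctly even when post ids repeat.
  np.add.at(out, post_ids, syn_value[pre_ids])
  return out

g = np.array([0.2, 0.0, 0.7])            # per-pre conductance
post_ids = np.array([0, 1, 1, 2])
pre_ids = np.array([0, 0, 2, 2])
print(pre2post_sum(g, 3, post_ids, pre_ids))  # [0.2 0.9 0.7]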

brainpy/math/setting.py

Lines changed: 14 additions & 2 deletions
@@ -3,13 +3,14 @@
 import os
 import re

-from jax import dtypes, config, numpy as jnp
+from jax import dtypes, config, numpy as jnp, devices
 from jax.lib import xla_bridge

 __all__ = [
   'enable_x64',
   'disable_x64',
   'set_platform',
+  'get_platform',
   'set_host_device_count',

   # device memory
@@ -92,7 +93,7 @@ def disable_x64():
   config.update("jax_enable_x64", False)


-def set_platform(platform):
+def set_platform(platform: str):
   """
   Changes platform to CPU, GPU, or TPU. This utility only takes
   effect at the beginning of your program.
@@ -101,6 +102,17 @@ def set_platform(platform):
   config.update("jax_platform_name", platform)


+def get_platform() -> str:
+  """Get the computing platform.
+
+  Returns
+  -------
+  platform: str
+    Either 'cpu', 'gpu' or 'tpu'.
+  """
+  return devices()[0].platform
+
+
 def set_host_device_count(n):
   """
   By default, XLA considers all CPU cores as one device. This utility tells XLA
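get_platform complements set_platform by reading back which backend JAX actually selected, via the first entry of jax.devices(). A short usage sketch, assuming both are re-exported under brainpy.math (consistent with the module path):

import brainpy.math as bm

bm.set_platform('cpu')      # only takes effect at the start of the program
print(bm.get_platform())    # -> 'cpu'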

docs/conf.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@
   'sphinx.ext.autosummary',
   'sphinx.ext.intersphinx',
   'sphinx.ext.mathjax',
-  'sphinx-mathjax-offline',
+  # 'sphinx-mathjax-offline',
   'sphinx.ext.napoleon',
   'sphinx.ext.viewcode',
   'sphinx_autodoc_typehints',
