Skip to content

Commit 87c0b86

Browse files
authored
Updating apis for connections and operation registeration (#280)
Updating apis for connections and operation registration
2 parents e5a17c8 + 71e1308 commit 87c0b86

File tree

16 files changed

+529
-369
lines changed

16 files changed

+529
-369
lines changed

brainpy/connect/base.py

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,13 @@ def build_conn(self):
116116
117117
import brainpy as bp
118118
class MyConnector(bp.conn.TwoEndConnector):
119-
def build_mat(self, pre_size, post_size):
119+
def build_mat(self, ):
120120
return conn_matrix
121121
122-
def build_csr(self, pre_size, post_size):
122+
def build_csr(self, ):
123123
return post_ids, inptr
124124
125-
def build_coo(self, pre_size, post_size):
125+
def build_coo(self, ):
126126
return pre_ids, post_ids
127127
128128
"""
@@ -196,8 +196,6 @@ def check(self, structures: Union[Tuple, List, str]):
196196
raise ConnectorError(f'Unknown synapse structure "{n}". '
197197
f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')
198198

199-
200-
201199
def _return_by_mat(self, structures, mat, all_data: dict):
202200
assert mat.ndim == 2
203201
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
@@ -332,70 +330,56 @@ def build_conn(self):
332330
"""
333331
pass
334332

335-
def require(self, *sizes_or_structures):
336-
sizes_or_structures = list(sizes_or_structures)
337-
pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
338-
post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
339-
structures = sizes_or_structures
340-
if isinstance(post_size, str):
341-
structures.insert(0, post_size)
342-
post_size = None
343-
if isinstance(pre_size, str):
344-
structures.insert(0, pre_size)
345-
pre_size = None
346-
347-
version2_style = (pre_size is not None) and (post_size is not None)
348-
if not version2_style:
349-
try:
350-
assert self.pre_num is not None and self.post_num is not None
351-
except AssertionError:
352-
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
353-
f'Please use self.__call__(pre_size, post_size) '
354-
f'before requiring connection data.')
355-
if pre_size is None:
356-
pre_size = self.pre_size
357-
if post_size is None:
358-
post_size = self.post_size
333+
def require(self, *structures):
334+
try:
335+
assert self.pre_num is not None and self.post_num is not None
336+
except AssertionError:
337+
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
338+
f'Please use self.__call__() '
339+
f'before requiring connection data.')
359340

360341
self.check(structures)
361342
if self.is_version2_style:
362343
if len(structures) == 1:
363344
if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
364-
return self.build_csr(pre_size, post_size)
345+
r = self.build_csr()
346+
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
365347
elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
366-
return self.build_mat(pre_size, post_size)
348+
return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
367349
elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
368-
return self.build_coo(pre_size, post_size)[0]
350+
return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
369351
elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
370-
return self.build_coo(pre_size, post_size)[1]
352+
return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
371353
elif len(structures) == 2:
372354
if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
373-
return self.build_coo(pre_size, post_size)
355+
r = self.build_coo()
356+
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
374357

375358
conn_data = dict(csr=None, ij=None, mat=None)
376359
if not hasattr(self.build_coo, 'not_customized'):
377-
conn_data['ij'] = self.build_coo(pre_size, post_size)
360+
conn_data['ij'] = self.build_coo()
378361
elif not hasattr(self.build_csr, 'not_customized'):
379-
conn_data['csr'] = self.build_csr(pre_size, post_size)
362+
conn_data['csr'] = self.build_csr()
380363
elif not hasattr(self.build_mat, 'not_customized'):
381-
conn_data['mat'] = self.build_mat(pre_size, post_size)
364+
conn_data['mat'] = self.build_mat()
365+
382366
else:
383367
conn_data = self.build_conn()
384368
return self.make_returns(structures, conn_data)
385369

386-
def requires(self, *sizes_or_structures):
387-
return self.require(*sizes_or_structures)
370+
def requires(self, *structures):
371+
return self.require(*structures)
388372

389373
@tools.not_customized
390-
def build_mat(self, pre_size=None, post_size=None):
374+
def build_mat(self):
391375
pass
392376

393377
@tools.not_customized
394-
def build_csr(self, pre_size=None, post_size=None):
378+
def build_csr(self):
395379
pass
396380

397381
@tools.not_customized
398-
def build_coo(self, pre_size=None, post_size=None):
382+
def build_coo(self):
399383
pass
400384

401385

@@ -425,7 +409,6 @@ def __call__(self, pre_size, post_size=None):
425409
else:
426410
post_size = tuple(post_size)
427411
self.pre_size, self.post_size = pre_size, post_size
428-
429412
self.pre_num = tools.size2num(self.pre_size)
430413
self.post_num = tools.size2num(self.post_size)
431414
return self

brainpy/connect/custom_conn.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from brainpy import tools
88
from brainpy.errors import ConnectorError
99
from .base import *
10-
from .utils import *
1110

1211
__all__ = [
1312
'MatConn',
@@ -34,11 +33,9 @@ def __call__(self, pre_size, post_size):
3433
assert self.post_num == tools.size2num(post_size)
3534
return self
3635

37-
def build_mat(self, pre_size=None, post_size=None):
38-
pre_num = get_pre_num(self, pre_size)
39-
post_num = get_post_num(self, post_size)
40-
assert self.conn_mat.shape[0] == pre_num
41-
assert self.conn_mat.shape[1] == post_num
36+
def build_mat(self):
37+
assert self.conn_mat.shape[0] == self.pre_num
38+
assert self.conn_mat.shape[1] == self.post_num
4239
return self.conn_mat
4340

4441

@@ -68,14 +65,12 @@ def __call__(self, pre_size, post_size):
6865
f'the maximum id ({self.max_post}) of self.post_ids.')
6966
return self
7067

71-
def build_coo(self, pre_size=None, post_size=None):
72-
pre_num = get_pre_num(self, pre_size)
73-
post_num = get_post_num(self, post_size)
74-
if pre_num <= self.max_pre:
75-
raise ConnectorError(f'pre_num ({pre_num}) should be greater than '
68+
def build_coo(self):
69+
if self.pre_num <= self.max_pre:
70+
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
7671
f'the maximum id ({self.max_pre}) of self.pre_ids.')
77-
if post_num <= self.max_post:
78-
raise ConnectorError(f'post_num ({post_num}) should be greater than '
72+
if self.post_num <= self.max_post:
73+
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
7974
f'the maximum id ({self.max_post}) of self.post_ids.')
8075
return self.pre_ids, self.post_ids
8176

@@ -91,16 +86,12 @@ def __init__(self, indices, inptr):
9186
self.pre_num = self.inptr.size - 1
9287
self.max_post = bm.max(self.indices)
9388

94-
def build_csr(self, pre_size=None, post_size=None):
95-
pre_size = get_pre_size(self, pre_size)
96-
post_size = get_post_size(self, post_size)
97-
pre_num = np.prod(pre_size)
98-
post_num = np.prod(post_size)
99-
if pre_num != self.pre_num:
89+
def build_csr(self):
90+
if self.pre_num != self.pre_num:
10091
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
10192
f'the shape of the sparse matrix.')
102-
if post_num <= self.max_post:
103-
raise ConnectorError(f'post_num ({post_num}) should be greater than '
93+
if self.post_num <= self.max_post:
94+
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
10495
f'the maximum id ({self.max_post}) of self.post_ids.')
10596
return self.indices, self.inptr
10697

0 commit comments

Comments
 (0)