Skip to content

Commit d8d23db

Browse files
authored
Merge pull request #342 from chaoming0625/master
Formalize new styles of neuron and synapse models
2 parents 8692d3e + f21aff0 commit d8d23db

File tree

27 files changed

+665
-608
lines changed

27 files changed

+665
-608
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu
3030

3131
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
3232
- **[brainpylib](https://github.com/brainpy/brainpylib)**: Efficient operators for the sparse and event-driven computation.
33-
- **[BrainPyExamples](https://github.com/brainpy/BrainPyExamples)**: Comprehensive examples of BrainPy computation.
34-
- **[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale)**: One solution for the large-scale brain modeling.
33+
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
34+
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
3535

3636

3737

brainpy/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,16 @@
5858
synapses, # synaptic dynamics
5959
synouts, # synaptic output
6060
synplast, # synaptic plasticity
61-
syn,
61+
experimental,
6262
)
63-
from brainpy._src.dyn.base import not_pass_sha
63+
from brainpy._src.dyn.base import not_pass_shared
6464
from brainpy._src.dyn.base import (DynamicalSystem,
6565
DynamicalSystemNS,
6666
Container as Container,
6767
Sequential as Sequential,
6868
Network as Network,
6969
NeuGroup as NeuGroup,
70+
NeuGroupNS as NeuGroupNS,
7071
SynConn as SynConn,
7172
SynOut as SynOut,
7273
SynSTP as SynSTP,
@@ -77,8 +78,7 @@
7778
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
7879
LoopOverTime as LoopOverTime,)
7980
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
80-
from brainpy._src.dyn.context import share
81-
from brainpy._src.dyn.delay import Delay
81+
from brainpy._src.dyn.context import share, Delay
8282

8383

8484
# Part 4: Training #

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def f_loss():
349349
def train(idx):
350350
gradients, loss = grad_f()
351351
optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients})
352+
optimizer.lr.step_epoch()
352353
return loss
353354

354355
def batch_train(start_i, n_batch):

brainpy/_src/dyn/base.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
SLICE_VARS = 'slice_vars'
5151

5252

53-
def not_pass_sha(func: Callable):
53+
def not_pass_shared(func: Callable):
5454
"""Label the update function as the one without passing shared arguments.
5555
5656
The original update function explicitly requires shared arguments at the first place::
@@ -610,7 +610,8 @@ def __repr__(self):
610610
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules))
611611
return f'{self.__class__.__name__}(\n{entries}\n)'
612612

613-
def update(self, s, x) -> ArrayType:
613+
@not_pass_shared
614+
def update(self, x) -> ArrayType:
614615
"""Update function of a sequential model.
615616
616617
Parameters
@@ -626,12 +627,7 @@ def update(self, s, x) -> ArrayType:
626627
The output tensor.
627628
"""
628629
for m in self._modules:
629-
if isinstance(m, DynamicalSystemNS):
630-
x = m(x)
631-
elif isinstance(m, DynamicalSystem):
632-
x = m(s, x)
633-
else:
634-
x = m(x)
630+
x = m(x)
635631
return x
636632

637633

@@ -665,7 +661,7 @@ def __init__(
665661
mode=mode,
666662
**ds_dict)
667663

668-
@not_pass_sha
664+
@not_pass_shared
669665
def update(self, *args, **kwargs):
670666
"""Step function of a network.
671667
@@ -807,6 +803,11 @@ def __getitem__(self, item):
807803
return NeuGroupView(target=self, index=item)
808804

809805

806+
class NeuGroupNS(NeuGroup):
807+
"""Base class for neuron group without shared arguments passed."""
808+
pass_shared = False
809+
810+
810811
class SynConn(DynamicalSystem):
811812
"""Base class to model two-end synaptic connections.
812813

0 commit comments

Comments
 (0)