Skip to content

Commit 5eb9f03

Browse files
authored
Merge pull request #121 from PKU-NIP-Lab/whole-brain-modeling
Whole brain modeling
2 parents 3086c69 + 0de0693 commit 5eb9f03

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+3548
-1274
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ BrainModels/
1717
book/
1818
docs/examples
1919
docs/apis/jaxsetting.rst
20+
docs/quickstart/data
2021
examples/recurrent_neural_network/neurogym
2122
develop/iconip_paper
2223
develop/benchmark/COBA/results

README.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ runner.run(100.)
150150
Numerical methods for delay differential equations (SDEs).
151151

152152
```python
153-
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)
153+
xdelay = bm.TimeDelay(bm.zeros(1), delay_len=1., before_t0=1., dt=0.01)
154154

155155

156156
@bp.ddeint(method='rk4', state_delays={'x': xdelay})
@@ -191,6 +191,31 @@ runner = bp.dyn.DSRunner(net)
191191
runner(100.)
192192
```
193193

194+
Simulating a whole brain network by using rate models.
195+
196+
```python
197+
import numpy as np
198+
199+
class WholeBrainNet(bp.dyn.Network):
200+
def __init__(self, signal_speed=20.):
201+
super(WholeBrainNet, self).__init__()
202+
203+
self.fhn = bp.dyn.RateFHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
204+
self.syn = bp.dyn.DiffusiveDelayCoupling(self.fhn, self.fhn,
205+
'x->input',
206+
conn_mat=conn_mat,
207+
delay_mat=delay_mat)
208+
209+
def update(self, _t, _dt):
210+
self.syn.update(_t, _dt)
211+
self.fhn.update(_t, _dt)
212+
213+
214+
net = WholeBrainNet()
215+
runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
216+
runner.run(6e3)
217+
```
218+
194219

195220

196221
### 4. Dynamics training level

brainpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.1.1"
3+
__version__ = "2.1.2"
44

55

66
try:
@@ -15,7 +15,7 @@
1515

1616

1717
# fundamental modules
18-
from . import errors, tools
18+
from . import errors, tools, check
1919

2020

2121
# "base" module

brainpy/analysis/highdim/slow_points.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# -*- coding: utf-8 -*-
22

3-
import inspect
43
import time
4+
import warnings
55
from functools import partial
66

7+
from jax import vmap
78
import jax.numpy
89
import numpy as np
910
from jax.scipy.optimize import minimize
1011

1112
import brainpy.math as bm
13+
from brainpy import optimizers as optim
1214
from brainpy.analysis import utils
1315
from brainpy.errors import AnalyzerError
14-
from brainpy import optimizers as optim
1516

1617
__all__ = [
1718
'SlowPointFinder',
@@ -56,15 +57,15 @@ def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True)
5657
if f_loss_batch is None:
5758
if f_type == 'discrete':
5859
self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
59-
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
60+
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
6061
if f_type == 'continuous':
6162
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
62-
self.f_loss_batch = bm.jit(lambda h: bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
63+
self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))
6364

6465
else:
6566
self.f_loss_batch = f_loss_batch
6667
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
67-
self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))
68+
self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))
6869

6970
# essential variables
7071
self._losses = None
@@ -87,8 +88,13 @@ def selected_ids(self):
8788
"""The selected ids of candidate points."""
8889
return self._selected_ids
8990

90-
def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
91-
num_opt=10000, opt_setting=None):
91+
def find_fps_with_gd_method(self,
92+
candidates,
93+
tolerance=1e-5,
94+
num_batch=100,
95+
num_opt=10000,
96+
optimizer=None,
97+
opt_setting=None):
9298
"""Optimize fixed points with gradient descent methods.
9399
94100
Parameters
@@ -104,44 +110,56 @@ def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
104110
Print training information during optimization every so often.
105111
opt_setting: optional, dict
106112
The optimization settings.
113+
114+
.. deprecated:: 2.1.2
115+
Use "optimizer" to set optimization method instead.
116+
117+
optimizer: optim.Optimizer
118+
The optimizer instance.
119+
120+
.. versionadded:: 2.1.2
107121
"""
108122

109123
# optimization settings
110124
if opt_setting is None:
111-
opt_method = optim.Adam
112-
opt_lr = optim.ExponentialDecay(0.2, 1, 0.9999)
113-
opt_setting = {'beta1': 0.9,
114-
'beta2': 0.999,
115-
'eps': 1e-8,
116-
'name': None}
125+
if optimizer is None:
126+
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
127+
beta1=0.9, beta2=0.999, eps=1e-8)
128+
else:
129+
assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of '
130+
f'{optim.Optimizer.__name__}, '
131+
f'while we got {type(optimizer)}')
117132
else:
133+
warnings.warn('Please use "optimizer" to set optimization method. '
134+
'"opt_setting" is deprecated since version 2.1.2. ',
135+
DeprecationWarning)
136+
118137
assert isinstance(opt_setting, dict)
119138
assert 'method' in opt_setting
120139
assert 'lr' in opt_setting
121140
opt_method = opt_setting.pop('method')
122141
if isinstance(opt_method, str):
123142
assert opt_method in optim.__dict__
124143
opt_method = getattr(optim, opt_method)
125-
assert isinstance(opt_method, type)
126-
if optim.Optimizer not in inspect.getmro(opt_method):
127-
raise ValueError
144+
assert issubclass(opt_method, optim.Optimizer)
128145
opt_lr = opt_setting.pop('lr')
129146
assert isinstance(opt_lr, (int, float, optim.Scheduler))
130147
opt_setting = opt_setting
148+
optimizer = opt_method(lr=opt_lr, **opt_setting)
131149

132150
if self.verbose:
133-
print(f"Optimizing with {opt_method.__name__} to find fixed points:")
151+
print(f"Optimizing with {optimizer.__name__} to find fixed points:")
134152

135153
# set up optimization
136154
fixed_points = bm.Variable(bm.asarray(candidates))
137155
grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
138156
grad_vars={'a': fixed_points}, return_value=True)
139-
opt = opt_method(train_vars={'a': fixed_points}, lr=opt_lr, **opt_setting)
140-
dyn_vars = opt.vars() + {'_a': fixed_points}
157+
optimizer.register_vars({'a': fixed_points})
158+
dyn_vars = optimizer.vars() + {'_a': fixed_points}
141159

142160
def train(idx):
143161
gradients, loss = grad_f()
144-
opt.update(gradients)
162+
optimizer.update(gradients)
145163
return loss
146164

147165
@partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
@@ -191,7 +209,7 @@ def find_fps_with_opt_solver(self, candidates, opt_method=None):
191209
opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
192210
if self.verbose:
193211
print(f"Optimizing to find fixed points:")
194-
f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
212+
f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
195213
res = f_opt(bm.as_device_array(candidates))
196214
valid_ids = jax.numpy.where(res.success)[0]
197215
self._fixed_points = np.asarray(res.x[valid_ids])

brainpy/analysis/lowdim/lowdim_analyzer.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import partial
44

55
import numpy as np
6+
from jax import vmap
67
from jax import numpy as jnp
78
from jax.scipy.optimize import minimize
89

@@ -262,7 +263,7 @@ def F_fx(self):
262263
@property
263264
def F_vmap_fx(self):
264265
if C.F_vmap_fx not in self.analyzed_results:
265-
self.analyzed_results[C.F_vmap_fx] = bm.jit(bm.vmap(self.F_fx), device=self.jit_device)
266+
self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
266267
return self.analyzed_results[C.F_vmap_fx]
267268

268269
@property
@@ -289,7 +290,7 @@ def F_vmap_fp_aux(self):
289290
# ---
290291
# "X": a two-dimensional matrix: (num_batch, num_var)
291292
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
292-
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(bm.vmap(self.F_fixed_point_aux))
293+
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
293294
return self.analyzed_results[C.F_vmap_fp_aux]
294295

295296
@property
@@ -308,7 +309,7 @@ def F_vmap_fp_opt(self):
308309
# ---
309310
# "X": a two-dimensional matrix: (num_batch, num_var)
310311
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
311-
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(bm.vmap(self.F_fixed_point_opt))
312+
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
312313
return self.analyzed_results[C.F_vmap_fp_opt]
313314

314315
def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
@@ -501,7 +502,7 @@ def F_y_by_x_in_fy(self):
501502
@property
502503
def F_vmap_fy(self):
503504
if C.F_vmap_fy not in self.analyzed_results:
504-
self.analyzed_results[C.F_vmap_fy] = bm.jit(bm.vmap(self.F_fy), device=self.jit_device)
505+
self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
505506
return self.analyzed_results[C.F_vmap_fy]
506507

507508
@property
@@ -663,7 +664,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
663664

664665
if self.F_x_by_y_in_fx is not None:
665666
utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
666-
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fx), device=self.jit_device)
667+
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
667668
for j, pars in enumerate(par_seg):
668669
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
669670
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -679,7 +680,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
679680

680681
elif self.F_y_by_x_in_fx is not None:
681682
utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
682-
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fx), device=self.jit_device)
683+
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
683684
for j, pars in enumerate(par_seg):
684685
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
685686
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -697,9 +698,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
697698
utils.output("I am evaluating fx-nullcline by optimization ...")
698699
# auxiliary functions
699700
f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
700-
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
701-
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
702-
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
701+
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
702+
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
703+
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
703704

704705
# num segments
705706
for _j, Ps in enumerate(par_seg):
@@ -756,7 +757,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
756757

757758
if self.F_x_by_y_in_fy is not None:
758759
utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
759-
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fy), device=self.jit_device)
760+
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
760761
for j, pars in enumerate(par_seg):
761762
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
762763
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -772,7 +773,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
772773

773774
elif self.F_y_by_x_in_fy is not None:
774775
utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
775-
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fy), device=self.jit_device)
776+
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
776777
for j, pars in enumerate(par_seg):
777778
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
778779
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -791,9 +792,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
791792

792793
# auxiliary functions
793794
f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
794-
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
795-
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
796-
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
795+
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
796+
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
797+
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
797798

798799
for j, Ps in enumerate(par_seg):
799800
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
@@ -841,7 +842,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
841842
xs = self.resolutions[self.x_var].value
842843
ys = self.resolutions[self.y_var].value
843844
P = tuple(self.resolutions[p].value for p in self.target_par_names)
844-
f_select = bm.jit(bm.vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
845+
f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
845846

846847
# num seguments
847848
if isinstance(num_segments, int):
@@ -921,10 +922,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
921922

922923
if self.convert_type() == C.x_by_y:
923924
num_seg = len(self.resolutions[self.y_var])
924-
f_vmap = bm.jit(bm.vmap(self.F_y_convert[1]))
925+
f_vmap = bm.jit(vmap(self.F_y_convert[1]))
925926
else:
926927
num_seg = len(self.resolutions[self.x_var])
927-
f_vmap = bm.jit(bm.vmap(self.F_x_convert[1]))
928+
f_vmap = bm.jit(vmap(self.F_x_convert[1]))
928929
# get the signs
929930
signs = jnp.sign(f_vmap(candidates, *args))
930931
signs = signs.reshape((num_seg, -1))
@@ -954,10 +955,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
954955
# get another value
955956
if self.convert_type() == C.x_by_y:
956957
y_values = fps
957-
x_values = bm.jit(bm.vmap(self.F_y_convert[0]))(y_values, *args)
958+
x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
958959
else:
959960
x_values = fps
960-
y_values = bm.jit(bm.vmap(self.F_x_convert[0]))(x_values, *args)
961+
y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
961962
fps = jnp.stack([x_values, y_values]).T
962963
return fps, selected_ids, args
963964

0 commit comments

Comments
 (0)