
Commit e25144e

fix some bugs (#262)
2 parents (64d1e21 + c177742), commit e25144e

File tree: 17 files changed, +116 / -91 lines


brainpy/analysis/highdim/slow_points.py

Lines changed: 4 additions & 6 deletions
@@ -355,8 +355,7 @@ def train(idx):
       return loss
 
     def batch_train(start_i, n_batch):
-      f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
-      return f(bm.arange(start_i, start_i + n_batch))
+      return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))
 
     # Run the optimization
     if self.verbose:
@@ -369,7 +368,7 @@ def batch_train(start_i, n_batch):
         break
       batch_idx_start = oidx * num_batch
       start_time = time.time()
-      (_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
+      train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
       batch_time = time.time() - start_time
       opt_losses.append(train_losses)
 
@@ -722,8 +721,6 @@ def _generate_ds_cell_function(
     shared = DotDict(t=t, dt=dt, i=0)
 
     def f_cell(h: Dict):
-      target.clear_input()
-
       # update target variables
       for k, v in self.target_vars.items():
         v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -735,6 +732,7 @@ def f_cell(h: Dict):
         v.value = self.excluded_data[k]
 
       # add inputs
+      target.clear_input()
       if f_input is not None:
         f_input(shared)
 
@@ -743,7 +741,7 @@ def f_cell(h: Dict):
       target.update(*args)
 
       # get new states
-      new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
+      new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
               for k, v in self.target_vars.items()}
       return new_h
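
Note on the change above: `bm.make_loop` built a loop function that had to be called separately and returned an `(out_vars, returns)` pair, while `bm.for_loop` runs the body over the operands in one call and returns only the stacked per-step results, which is why the `(_, train_losses)` unpacking disappears. A minimal sketch of the migration, assuming the `(body, dyn_vars, operands)` call order used in this diff; `counter` and `body` are made-up names for illustration:

    import brainpy.math as bm

    counter = bm.Variable(bm.zeros(1))

    def body(i):
      # the body reads/writes `counter`, which is why it is listed in dyn_vars
      counter.value += 1.
      return counter.value * i

    # old: f = bm.make_loop(body, dyn_vars=[counter], has_return=True); _, out = f(bm.arange(5))
    # new: one call, returning the stacked per-step return values
    out = bm.for_loop(body, [counter], bm.arange(5))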

brainpy/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 41 additions & 10 deletions
@@ -354,8 +354,17 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
     if with_return:
       return final_fps, final_pars, jacobians
 
-  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
-                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
+  def plot_limit_cycle_by_sim(
+      self,
+      duration=100,
+      with_plot: bool = True,
+      with_return: bool = False,
+      plot_style: dict = None,
+      tol: float = 0.001,
+      show: bool = False,
+      dt: float = None,
+      offset: float = 1.
+  ):
     global pyplot
     if pyplot is None: from matplotlib import pyplot
     utils.output('I am plotting the limit cycle ...')
@@ -400,10 +409,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
     if len(ps_limit_cycle[0]):
       for i, var in enumerate(self.target_var_names):
         pyplot.figure(var)
-        pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
-                    **plot_style, label='limit cycle (max)')
-        pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
-                    **plot_style, label='limit cycle (min)')
+        pyplot.plot(ps_limit_cycle[0],
+                    ps_limit_cycle[1],
+                    vs_limit_cycle[i]['max'],
+                    **plot_style,
+                    label='limit cycle (max)')
+        pyplot.plot(ps_limit_cycle[0],
+                    ps_limit_cycle[1],
+                    vs_limit_cycle[i]['min'],
+                    **plot_style,
+                    label='limit cycle (min)')
         pyplot.legend()
 
     elif len(self.target_par_names) == 1:
@@ -427,8 +442,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
 
 
 class FastSlow1D(Bifurcation1D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=None, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=None,
+      options: dict = None
+  ):
     super(FastSlow1D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,
@@ -510,8 +533,16 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
 
 
 class FastSlow2D(Bifurcation2D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=0.1, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=0.1,
+      options: dict = None
+  ):
     super(FastSlow2D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,

brainpy/analysis/utils/model.py

Lines changed: 4 additions & 3 deletions
@@ -112,14 +112,14 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
 
     # variables
     assert isinstance(initial_vars, dict)
-    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
+    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
                     for k, v in initial_vars.items()}
     self.register_implicit_vars(initial_vars)
 
     # parameters
     pars = dict() if pars is None else pars
     assert isinstance(pars, dict)
-    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
+    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
                  for k, v in pars.items()]
 
     # integrals
@@ -128,7 +128,8 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
     # runner
     self.runner = DSRunner(self,
                            monitors=list(initial_vars.keys()),
-                           dyn_vars=self.vars().unique(), dt=dt,
+                           dyn_vars=self.vars().unique(),
+                           dt=dt,
                            progress_bar=False)
 
   def update(self, sha):
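
The dtype change swaps the hard-coded `jnp.float_` (NumPy's float64 alias) for `bm.dftype()`, which presumably resolves to BrainPy's configured default float type, so these helper models follow the same precision as the rest of the framework. A minimal sketch mirroring the cast in the diff; the value `v` is hypothetical:

    import jax.numpy as jnp
    import brainpy.math as bm

    v = bm.ones(3)  # any array-like initial value
    # cast with the configured default float dtype instead of forcing float64
    arr = jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())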

brainpy/dyn/base.py

Lines changed: 5 additions & 2 deletions
@@ -343,8 +343,7 @@ def offline_fit(self,
       raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')
 
   def clear_input(self):
-    for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
-      node.clear_input()
+    pass
 
 
 class Container(DynamicalSystem):
@@ -430,6 +429,10 @@ def __getattr__(self, item):
     else:
       return super(Container, self).__getattribute__(item)
 
+  def clear_input(self):
+    for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+      node.clear_input()
+
 
 class Sequential(Container):
   def __init__(
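
With this change, the base `DynamicalSystem.clear_input` becomes a no-op and the recursion over children moves to `Container`, which now visits every child `DynamicalSystem` rather than only `NeuGroup` nodes. Concrete models are still expected to reset their own input buffers; a minimal sketch of such an override, assuming a `NeuGroup` subclass with an `input` variable (the names here are illustrative):

    import brainpy as bp
    import brainpy.math as bm

    class MyGroup(bp.dyn.NeuGroup):
      def __init__(self, size):
        super().__init__(size)
        self.input = bm.Variable(bm.zeros(self.num))

      def update(self, tdi, x=None):
        pass  # placeholder dynamics

      def clear_input(self):
        # called once per step via the containing network's clear_input()
        self.input[:] = 0.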

brainpy/dyn/neurons/biological_models.py

Lines changed: 11 additions & 14 deletions
@@ -244,20 +244,17 @@ def __init__(
 
     # variables
     self.V = variable(self._V_initializer, mode, self.varshape)
-    if self._m_initializer is None:
-      self.m = bm.Variable(self.m_inf(self.V.value))
-    else:
-      self.m = variable(self._m_initializer, mode, self.varshape)
-    if self._h_initializer is None:
-      self.h = bm.Variable(self.h_inf(self.V.value))
-    else:
-      self.h = variable(self._h_initializer, mode, self.varshape)
-    if self._n_initializer is None:
-      self.n = bm.Variable(self.n_inf(self.V.value))
-    else:
-      self.n = variable(self._n_initializer, mode, self.varshape)
-    self.input = variable(bm.zeros, mode, self.varshape)
+    self.m = (bm.Variable(self.m_inf(self.V.value))
+              if m_initializer is None else
+              variable(self._m_initializer, mode, self.varshape))
+    self.h = (bm.Variable(self.h_inf(self.V.value))
+              if h_initializer is None else
+              variable(self._h_initializer, mode, self.varshape))
+    self.n = (bm.Variable(self.n_inf(self.V.value))
+              if n_initializer is None else
+              variable(self._n_initializer, mode, self.varshape))
     self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+    self.input = variable(bm.zeros, mode, self.varshape)
 
     # integral
     if self.noise is None:
@@ -309,7 +306,7 @@ def dV(self, V, t, m, h, n, I_ext):
 
   @property
   def derivative(self):
-    return JointEq([self.dV, self.dm, self.dh, self.dn])
+    return JointEq(self.dV, self.dm, self.dh, self.dn)
 
   def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
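
The `derivative` property now passes the individual derivative functions to `JointEq` positionally instead of wrapping them in a list. A minimal sketch of the positional form, assuming `JointEq` is importable from the top level as `bp.JointEq`; the `dx`/`dy` equations are made up:

    import brainpy as bp

    def dx(x, t, y): return -x + y
    def dy(y, t, x): return -y + x

    # old style: bp.JointEq([dx, dy]); new style: positional arguments
    joint = bp.JointEq(dx, dy)
    integral = bp.odeint(joint, method='rk4')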

brainpy/dyn/runners.py

Lines changed: 3 additions & 5 deletions
@@ -566,8 +566,7 @@ def f_predict(self, shared_args: Dict = None):
 
     monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)
 
-    def _step_func(inputs):
-      t, i, x = inputs
+    def _step_func(t, i, x):
       self.target.clear_input()
       # input step
       shared = DotDict(t=t, i=i, dt=self.dt)
@@ -586,8 +585,7 @@ def _step_func(inputs):
     if self.jit['predict']:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
-      f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
-      run_func = lambda all_inputs: f(all_inputs)[1]
+      run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)
 
     else:
       def run_func(xs):
@@ -601,7 +599,7 @@ def run_func(xs):
         x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))
 
         # step at the i
-        output, mon = _step_func((times[i], indices[i], x))
+        output, mon = _step_func(times[i], indices[i], x)
 
         # append output and monitor
         outputs.append(output)
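
The step-function signature changes because `bm.for_loop` feeds the per-step elements of each operand sequence to the function as separate positional arguments, whereas `bm.make_loop` delivered them packed in a single tuple; the non-JIT fallback is updated to call `_step_func(t, i, x)` the same way. A minimal sketch of that calling convention, assuming (as this diff implies) that a tuple of operand sequences is unpacked across the function's arguments; `acc` and `step` are illustrative names:

    import brainpy.math as bm

    acc = bm.Variable(bm.zeros(1))

    def step(t, i):
      acc.value += t
      return acc.value * i

    times = bm.arange(0., 5., 1.)
    indices = bm.arange(5)
    # call k receives (times[k], indices[k])
    out = bm.for_loop(step, [acc], (times, indices))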

brainpy/inputs/currents.py

Lines changed: 2 additions & 2 deletions
@@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,
 
   def _f(t):
     x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
+    return x.value
 
-  f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
-  noises = f(jnp.arange(t_start, t_end, dt))
+  noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))
 
   t_end = duration if t_end is None else t_end
   i_start = int(t_start / dt)
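
Here `_f` now returns `x.value`, so `bm.for_loop` collects the noise sample at every step, replacing the `out_vars=x` mechanism of `bm.make_loop`. A standalone sketch of the updated pattern with made-up parameter values (the real `ou_process` wraps essentially this loop):

    import jax.numpy as jnp
    import brainpy.math as bm

    mean, sigma, tau, dt, n = 0., 0.5, 10., 0.1, 1
    dt_sqrt = jnp.sqrt(dt)
    rng = bm.random.RandomState(42)
    x = bm.Variable(jnp.full(n, mean))

    def _f(t):
      x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
      return x.value  # returning the state lets for_loop stack it per step

    noises = bm.for_loop(_f, [x, rng], jnp.arange(0., 100., dt))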

brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def f(t):
 
   if show:
     fig = plt.figure()
-    ax = fig.gca(projection='3d')
+    ax = fig.add_subplot(111, projection='3d')
     plt.plot(mon_x, mon_y, mon_z)
     ax.set_xlabel('x')
     ax.set_xlabel('y')
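
Background on this one-line fix: `Figure.gca()` stopped accepting keyword arguments such as `projection` (deprecated in Matplotlib 3.4 and later removed), so the 3D axes are now created explicitly; the same fix appears in the SDE test below. A minimal standalone sketch:

    import matplotlib.pyplot as plt

    fig = plt.figure()
    # portable replacement for the removed fig.gca(projection='3d')
    ax = fig.add_subplot(111, projection='3d')
    ax.plot([0., 1.], [0., 1.], [0., 1.])
    ax.set_xlabel('x')
    plt.show()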

brainpy/integrators/runner.py

Lines changed: 14 additions & 17 deletions
@@ -217,16 +217,12 @@ def __init__(
 
     # build the update step
     if self.jit['predict']:
-      _loop_func = bm.make_loop(
-        self._step,
-        dyn_vars=self.dyn_vars,
-        out_vars={k: self.variables[k] for k in self.monitors.keys()},
-        has_return=True
-      )
+      def _loop_func(times):
+        return bm.for_loop(self._step, self.dyn_vars, times)
     else:
       def _loop_func(times):
-        out_vars = {k: [] for k in self.monitors.keys()}
         returns = {k: [] for k in self.fun_monitors.keys()}
+        returns.update({k: [] for k in self.monitors.keys()})
         for i in range(len(times)):
           _t = times[i]
           _dt = self.dt
@@ -237,9 +233,9 @@ def _loop_func(times):
           self._step(_t)
           # variable monitors
           for k in self.monitors.keys():
-            out_vars[k].append(bm.as_device_array(self.variables[k]))
-          out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
-          return out_vars, returns
+            returns[k].append(bm.as_device_array(self.variables[k]))
+          returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
+          return returns
     self.step_func = _loop_func
 
   def _step(self, t):
@@ -252,11 +248,6 @@ def _step(self, t):
       kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
       self.idx += 1
 
-    # return of function monitors
-    returns = dict()
-    for key, func in self.fun_monitors.items():
-      returns[key] = func(t, self.dt)
-
     # call integrator function
     update_values = self.target(**kwargs)
     if len(self.target.variables) == 1:
@@ -268,6 +259,13 @@ def _step(self, t):
     # progress bar
     if self.progress_bar:
       id_tap(lambda *args: self._pbar.update(), ())
+
+    # return of function monitors
+    returns = dict()
+    for key, func in self.fun_monitors.items():
+      returns[key] = func(t, self.dt)
+    for k in self.monitors.keys():
+      returns[k] = self.variables[k].value
     return returns
 
   def run(self, duration, start_t=None, eval_time=False):
@@ -302,14 +300,13 @@ def run(self, duration, start_t=None, eval_time=False):
                              refresh=True)
     if eval_time:
       t0 = time.time()
-    hists, returns = self.step_func(times)
+    hists = self.step_func(times)
    if eval_time:
       running_time = time.time() - t0
     if self.progress_bar:
       self._pbar.close()
 
     # post-running
-    hists.update(returns)
     times += self.dt
     if self.numpy_mon_after_run:
       times = np.asarray(times)
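
The net effect in this runner is that variable monitors and function monitors are merged into a single `returns` dictionary built inside `_step`, so both the JIT path (`bm.for_loop(self._step, ...)`) and the Python fallback hand `run()` one `hists` mapping and the old `hists.update(returns)` merge disappears. A minimal sketch of the underlying pattern, assuming `bm.for_loop` stacks dict (pytree) outputs leaf-wise, as the updated code relies on; `v` and `_step` are illustrative:

    import brainpy.math as bm

    v = bm.Variable(bm.zeros(1))

    def _step(t):
      v.value += t
      # monitored variables and function monitors share one return dict
      return {'v': v.value, 'two_t': 2. * t}

    hists = bm.for_loop(_step, [v], bm.arange(0., 1., 0.1))
    # hists['v'] and hists['two_t'] are stacked over the ten steps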

brainpy/integrators/sde/tests/test_sde_scalar.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def lorenz_system(method, **kwargs):
   mon3 = bp.math.array(mon3).to_numpy()
 
   fig = plt.figure()
-  ax = fig.gca(projection='3d')
+  ax = fig.add_subplot(111, projection='3d')
   plt.plot(mon1, mon2, mon3)
   ax.set_xlabel('x')
   ax.set_xlabel('y')
