Skip to content

Commit 231b566

Browse files
authored
improve oo-to-function transformation speed (#208)
improve oo-to-function transformation speed
2 parents 683e5cb + 44eac4a commit 231b566

File tree

9 files changed

+30
-572
lines changed

9 files changed

+30
-572
lines changed

brainpy/dyn/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ class DynamicalSystem(Base):
5454

5555
"""Global delay variables. Useful when the same target
5656
variable is used in multiple mappings."""
57-
global_delay_vars: Dict[str, bm.LengthDelay] = dict()
57+
global_delay_vars: Dict[str, bm.LengthDelay] = Collector()
5858

5959
def __init__(self, name=None):
6060
super(DynamicalSystem, self).__init__(name=name)
6161

6262
# local delay variables
63-
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()
63+
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
6464

6565
def __repr__(self):
6666
return f'{self.__class__.__name__}(name={self.name})'
@@ -334,15 +334,17 @@ def reset(self):
334334

335335
@classmethod
336336
def has(cls, **children_cls):
337-
"""
337+
"""The aggressive operation to gather master and children classes.
338338
339339
Parameters
340340
----------
341341
children_cls
342+
The children classes.
342343
343344
Returns
344345
-------
345-
346+
wrapper: ContainerWrapper
347+
A wrapper which has master and its children classes.
346348
"""
347349
return ContainerWrapper(master=cls, **children_cls)
348350

brainpy/math/controls.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def call(xs=None, length=None):
169169
turn_off_global_jit()
170170
except UnexpectedTracerError as e:
171171
turn_off_global_jit()
172-
for v, d in zip(dyn_vars, init_values): v.value = d
172+
for v, d in zip(dyn_vars, init_values): v._value = d
173173
raise errors.JaxTracerError(variables=dyn_vars) from e
174-
for v, d in zip(dyn_vars, dyn_values): v.value = d
174+
for v, d in zip(dyn_vars, dyn_values): v._value = d
175175
return tree_unflatten(tree, out_values), results
176176

177177
else:
@@ -189,7 +189,7 @@ def call(xs):
189189
turn_off_global_jit()
190190
for v, d in zip(dyn_vars, init_values): v._value = d
191191
raise e
192-
for v, d in zip(dyn_vars, dyn_values): v.value = d
192+
for v, d in zip(dyn_vars, dyn_values): v._value = d
193193
return tree_unflatten(tree, out_values)
194194

195195
return call
@@ -271,7 +271,7 @@ def call(x=None):
271271
turn_off_global_jit()
272272
for v, d in zip(dyn_vars, dyn_init): v._value = d
273273
raise e
274-
for v, d in zip(dyn_vars, dyn_values): v.value = d
274+
for v, d in zip(dyn_vars, dyn_values): v._value = d
275275

276276
return call
277277

@@ -359,7 +359,7 @@ def call(pred, x=None):
359359
turn_off_global_jit()
360360
for v, d in zip(dyn_vars, old_values): v._value = d
361361
raise e
362-
for v, d in zip(dyn_vars, dyn_values): v.value = d
362+
for v, d in zip(dyn_vars, dyn_values): v._value = d
363363
return res
364364

365365
else:
@@ -477,7 +477,7 @@ def _false_fun(op):
477477
turn_off_global_jit()
478478
for v, d in zip(dyn_vars, old_values): v._value = d
479479
raise e
480-
for v, d in zip(dyn_vars, dyn_values): v.value = d
480+
for v, d in zip(dyn_vars, dyn_values): v._value = d
481481
else:
482482
turn_on_global_jit()
483483
res = lax.cond(pred, true_fun, false_fun, operands)
@@ -663,7 +663,7 @@ def fun2scan(dyn_vals, x):
663663
turn_off_global_jit()
664664
for v, d in zip(dyn_vars, init_vals): v._value = d
665665
raise e
666-
for v, d in zip(dyn_vars, dyn_vals): v.value = d
666+
for v, d in zip(dyn_vars, dyn_vals): v._value = d
667667
return out_vals
668668

669669

@@ -729,4 +729,4 @@ def _cond_fun(op):
729729
turn_off_global_jit()
730730
for v, d in zip(dyn_vars, dyn_init): v._value = d
731731
raise e
732-
for v, d in zip(dyn_vars, dyn_values): v.value = d
732+
for v, d in zip(dyn_vars, dyn_values): v._value = d

brainpy/math/jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def call(*args, **kwargs):
5656
turn_off_global_jit()
5757
for key, v in vars.items(): v._value = variable_data[key]
5858
raise e
59-
vars.assign(changes)
59+
for key, v in vars.items(): v._value = changes[key]
6060
return out
6161

6262
return change_func_name(name=f_name, f=call) if f_name else call

docs/apis/dyn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
auto/dyn/neurons
1414
auto/dyn/synapses
1515
auto/dyn/rates
16+
auto/dyn/others
1617
auto/dyn/runners

docs/auto_generater.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def generate_dyn_docs(path='apis/auto/dyn/'):
255255
module_and_name = [
256256
('biological_models', 'Biological Models'),
257257
('fractional_models', 'Fractional-order Models'),
258-
('input_models', 'Input Models'),
259258
('reduced_models', 'Reduced Models'),
260259
]
261260
write_submodules(module_name='brainpy.dyn.neurons',
@@ -278,14 +277,23 @@ def generate_dyn_docs(path='apis/auto/dyn/'):
278277
module_and_name = [
279278
('populations', 'Population Models'),
280279
('couplings', 'Coupling Models'),
281-
('noises', 'Noise Models'),
282280
]
283281
write_submodules(module_name='brainpy.dyn.rates',
284282
filename=os.path.join(path, 'rates.rst'),
285283
header='Rate Models',
286284
submodule_names=[a[0] for a in module_and_name],
287285
section_names=[a[1] for a in module_and_name])
288286

287+
module_and_name = [
288+
('noises', 'Noise Models'),
289+
('inputs', 'Input Models'),
290+
]
291+
write_submodules(module_name='brainpy.dyn.others',
292+
filename=os.path.join(path, 'others.rst'),
293+
header='Helper Models',
294+
submodule_names=[a[0] for a in module_and_name],
295+
section_names=[a[1] for a in module_and_name])
296+
289297
write_module(module_name='brainpy.dyn.runners',
290298
filename=os.path.join(path, 'runners.rst'),
291299
header='Runners')

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ The code of BrainPy is open-sourced at GitHub:
5959

6060
tutorial_toolbox/ode_numerical_solvers
6161
tutorial_toolbox/sde_numerical_solvers
62-
tutorial_toolbox/dde_numerical_solvers
6362
tutorial_toolbox/fde_numerical_solvers
63+
tutorial_toolbox/dde_numerical_solvers
6464
tutorial_toolbox/joint_equations
6565
tutorial_toolbox/synaptic_connections
6666
tutorial_toolbox/synaptic_weights

docs/tutorial_toolbox/fde_numerical_solvers.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"source": [
4747
"Factional differential equations have several definitions. It can be defined in a variety of different ways that do often do not all lead to the same result even for smooth functions. In neuroscience, we usually use the following two definitions:\n",
4848
"\n",
49-
"- Riemann–Liouville fractional derivative\n",
49+
"- Grünwald-Letnikov derivative\n",
5050
"- Caputo fractional derivative\n",
5151
"\n",
5252
"See [Fractional calculus - Wikipedia](https://en.wikipedia.org/wiki/Fractional_calculus) for more details."
@@ -421,7 +421,7 @@
421421
{
422422
"cell_type": "markdown",
423423
"source": [
424-
"## Methods for Riemann–Liouville FDEs"
424+
"## Methods for Grünwald-Letnikov FDEs"
425425
],
426426
"metadata": {
427427
"collapsed": false,

0 commit comments

Comments
 (0)