Skip to content

Commit dd238fc

Browse files
authored
[documentation] update documentation to brainpy>=2.4.0 (#361)
[documentation] Update documentation to brainpy>=2.4.0
2 parents 3d63531 + be3c613 commit dd238fc

File tree

61 files changed

+5429
-3568
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+5429
-3568
lines changed

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def __init__(
234234
if f_loss_batch is not None:
235235
raise UnsupportedError('"f_loss_batch" is no longer supported, please '
236236
'use "f_loss" instead.')
237+
if fun_inputs is not None:
238+
raise UnsupportedError('"fun_inputs" is no longer supported.')
237239
if f_loss is None:
238240
f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square
239241
self.f_loss = f_loss

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from functools import partial
44

5+
import jax
56
import jax.numpy as jnp
67
from jax import vmap
78
import numpy as np

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3+
import jax
34
import jax.numpy as jnp
45
import numpy as np
56
from jax import vmap

brainpy/_src/checkpoints/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from brainpy import errors
1010
import brainpy.math as bm
11-
from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
11+
from brainpy._src.math.object_transform.base import BrainPyObject
12+
from brainpy._src.math.object_transform.collectors import ArrayCollector
1213

1314

1415
logger = logging.getLogger('brainpy.brainpy_object.io')

brainpy/_src/checkpoints/serialization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ def _record_saved_duration(checkpoint_start_time: float):
951951
# Note: for the very first checkpoint, this is the interval between program
952952
# init and the current checkpoint start time.
953953
duration_since_last_checkpoint = checkpoint_start_time - _LAST_CHECKPOINT_WRITE_TIME
954-
if jax.version.__version_info__ > (0, 3, 25):
954+
if monitoring is not None:
955955
monitoring.record_event_duration_secs(
956956
'/jax/checkpoint/write/duration_since_last_checkpoint_secs',
957957
duration_since_last_checkpoint)
@@ -1151,7 +1151,7 @@ def save_main_ckpt_task():
11511151
else:
11521152
save_main_ckpt_task()
11531153
end_time = time.time()
1154-
if jax.version.__version_info__ > (0, 3, 25):
1154+
if monitoring is not None:
11551155
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
11561156
end_time - start_time)
11571157
return ckpt_path
@@ -1281,7 +1281,7 @@ def save_main_ckpt_task():
12811281
else:
12821282
save_main_ckpt_task()
12831283
end_time = time.time()
1284-
if jax.version.__version_info__ > (0, 3, 25):
1284+
if monitoring is not None:
12851285
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
12861286
end_time - start_time)
12871287

@@ -1390,7 +1390,7 @@ def save_main_ckpt_task():
13901390
keep, overwrite, keep_every_n_steps, start_time, async_manager)
13911391

13921392
end_time = time.time()
1393-
if jax.version.__version_info__ > (0, 3, 25):
1393+
if monitoring is not None:
13941394
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
13951395
end_time - start_time)
13961396
return ckpt_path
@@ -1553,7 +1553,7 @@ def read_chunk(i):
15531553
restored_checkpoint = from_state_dict(target, state_dict)
15541554

15551555
end_time = time.time()
1556-
if jax.version.__version_info__ > (0, 3, 25):
1556+
if monitoring is not None:
15571557
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time)
15581558

15591559
return restored_checkpoint
@@ -1616,7 +1616,7 @@ def read_chunk(i):
16161616

16171617
state_dict = msgpack_restore(checkpoint_contents)
16181618
end_time = time.time()
1619-
if jax.version.__version_info__ > (0, 3, 25):
1619+
if monitoring is not None:
16201620
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time)
16211621

16221622
return state_dict

brainpy/_src/dyn/base.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,16 +1475,10 @@ def __getitem__(self, key: Union[int, slice, str]):
14751475
elif isinstance(key, slice):
14761476
return Sequential(*(self.__all_nodes()[key]))
14771477
elif isinstance(key, int):
1478-
key = self.__format_key(key)
1479-
return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
1478+
return self.__all_nodes()[key]
14801479
elif isinstance(key, (tuple, list)):
1481-
nodes = []
1482-
for i in key:
1483-
if isinstance(i, int):
1484-
i = self.__format_key(i)
1485-
assert isinstance(i, str)
1486-
nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
1487-
return Sequential(*nodes)
1480+
_all_nodes = self.__all_nodes()
1481+
return Sequential(*[_all_nodes[k] for k in key])
14881482
else:
14891483
raise KeyError(f'Unknown type of key: {type(key)}')
14901484

brainpy/_src/dyn/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def save(self, *args, **kwargs) -> None:
6565
for identifier, data in kwargs.items():
6666
self._arguments[identifier] = data
6767

68+
def __setitem__(self, key, value):
69+
"""Enable setting the shared item by ``bp.share[key] = value``."""
70+
self.save(key, value)
71+
72+
def __getitem__(self, item):
73+
"""Enable loading the shared parameter by ``bp.share[key]``."""
74+
return self.load(item)
75+
6876
def get_shargs(self) -> DotDict:
6977
"""Get all shared arguments in the global context."""
7078
return self._arguments.copy()

brainpy/_src/dyn/neurons/biological_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ def __init__(
456456
self.input_var = input_var
457457

458458
# initializers
459-
self._W_initializer = check.is_initializer(V_initializer, allow_none=False)
460-
self._V_initializer = check.is_initializer(W_initializer, allow_none=False)
459+
self._W_initializer = check.is_initializer(W_initializer, allow_none=False)
460+
self._V_initializer = check.is_initializer(V_initializer, allow_none=False)
461461

462462
# variables
463463
self.reset_state(self.mode)
@@ -491,7 +491,7 @@ def dW(self, W, t, V):
491491

492492
@property
493493
def derivative(self):
494-
return JointEq([self.dV, self.dW])
494+
return JointEq(self.dV, self.dW)
495495

496496
def update(self, x=None):
497497
t = share.load('t')

brainpy/_src/dyn/neurons/input_groups.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ def __init__(
131131
self.num_times = len(times)
132132

133133
# data about times and indices
134-
self.times = jnp.asarray(times)
135-
self.indices = jnp.asarray(indices, dtype=bm.int_)
134+
self.times = bm.asarray(times)
135+
self.indices = bm.asarray(indices, dtype=bm.int_)
136136
if need_sort:
137-
sort_idx = jnp.argsort(self.times)
137+
sort_idx = bm.argsort(self.times)
138138
self.indices.value = self.indices[sort_idx]
139139
self.times.value = self.times[sort_idx]
140140

@@ -144,7 +144,7 @@ def __init__(
144144
# functions
145145
def cond_fun(t):
146146
i = self.i.value
147-
return jnp.logical_and(i < self.num_times, t >= self.times[i])
147+
return bm.logical_and(i < self.num_times, t >= self.times[i])
148148

149149
def body_fun(t):
150150
i = self.i.value

brainpy/_src/dyn/runners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,6 @@ def _step_func_predict(self, shared_args, t, i, x):
640640
shared = tools.DotDict(t=t, i=i, dt=self.dt)
641641
shared.update(shared_args)
642642
share.save(**shared)
643-
self.target.clear_input()
644643
self._step_func_input(shared)
645644

646645
# dynamics update step
@@ -655,6 +654,7 @@ def _step_func_predict(self, shared_args, t, i, x):
655654
if self.progress_bar:
656655
id_tap(lambda *arg: self._pbar.update(), ())
657656
share.clear_shargs()
657+
self.target.clear_input()
658658

659659
if self._memory_efficient:
660660
id_tap(self._step_mon_on_cpu, mon)

0 commit comments

Comments
 (0)