Skip to content

Commit 8692d3e

Browse files
authored
Merge pull request #341 from chaoming0625/master
Updates
2 parents e0a3ee1 + cac4b1d commit 8692d3e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1881
-1999
lines changed

brainpy/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@
5858
synapses, # synaptic dynamics
5959
synouts, # synaptic output
6060
synplast, # synaptic plasticity
61-
experimental, # experimental model
61+
syn,
6262
)
63-
from brainpy._src.dyn.base import not_pass_shargs
64-
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
65-
Module as Module,
63+
from brainpy._src.dyn.base import not_pass_sha
64+
from brainpy._src.dyn.base import (DynamicalSystem,
65+
DynamicalSystemNS,
6666
Container as Container,
6767
Sequential as Sequential,
6868
Network as Network,
@@ -77,6 +77,8 @@
7777
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
7878
LoopOverTime as LoopOverTime,)
7979
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
80+
from brainpy._src.dyn.context import share
81+
from brainpy._src.dyn.delay import Delay
8082

8183

8284
# Part 4: Training #
@@ -240,3 +242,7 @@
240242
dyn.__dict__['NMDA'] = compat.NMDA
241243
del compat
242244

245+
246+
from brainpy._src import checking
247+
tools.__dict__['checking'] = checking
248+
del checking

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def f_cell(h: Dict):
752752

753753
# call update functions
754754
args = (shared,) + self.args
755-
target.update(*args)
755+
target(*args)
756756

757757
# get new states
758758
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,10 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
225225

226226
if self._can_convert_to_one_eq():
227227
if self.convert_type() == C.x_by_y:
228-
X = self.resolutions[self.y_var].value
228+
X = bm.as_jax(self.resolutions[self.y_var])
229229
else:
230-
X = self.resolutions[self.x_var].value
231-
pars = tuple(self.resolutions[p].value for p in self.target_par_names)
230+
X = bm.as_jax(self.resolutions[self.x_var])
231+
pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names)
232232
mesh_values = jnp.meshgrid(*((X,) + pars))
233233
mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
234234
candidates = mesh_values[0]

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
290290

291291
if self._can_convert_to_one_eq():
292292
if self.convert_type() == C.x_by_y:
293-
candidates = self.resolutions[self.y_var].value
293+
candidates = bm.as_jax(self.resolutions[self.y_var])
294294
else:
295-
candidates = self.resolutions[self.x_var].value
295+
candidates = bm.as_jax(self.resolutions[self.x_var])
296296
else:
297297
if select_candidates == 'fx-nullcline':
298298
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
File renamed without changes.

brainpy/_src/checkpoints/serialization.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,13 @@
3737
get_tensorstore_spec = None
3838

3939
from brainpy._src.math.ndarray import Array
40-
from brainpy._src.math.object_transform.base import Collector
4140
from brainpy.errors import (AlreadyExistsError,
4241
MPACheckpointingRequiredError,
4342
MPARestoreTargetRequiredError,
4443
MPARestoreDataCorruptedError,
4544
MPARestoreTypeNotMatchError,
4645
InvalidCheckpointPath,
4746
InvalidCheckpointError)
48-
from brainpy.tools import DotDict
4947
from brainpy.types import PyTree
5048

5149
__all__ = [
@@ -154,17 +152,27 @@ def from_state_dict(target, state: Dict[str, Any], name: str = '.'):
154152
A copy of the object with the restored state.
155153
"""
156154
ty = _NamedTuple if _is_namedtuple(target) else type(target)
157-
if ty not in _STATE_DICT_REGISTRY:
155+
for t in _STATE_DICT_REGISTRY.keys():
156+
if issubclass(ty, t):
157+
ty = t
158+
break
159+
else:
158160
return state
159161
ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1]
160162
with _record_path(name):
161163
return ty_from_state_dict(target, state)
162164

163165

166+
164167
def to_state_dict(target) -> Dict[str, Any]:
165168
"""Returns a dictionary with the state of the given target."""
166169
ty = _NamedTuple if _is_namedtuple(target) else type(target)
167-
if ty not in _STATE_DICT_REGISTRY:
170+
171+
for t in _STATE_DICT_REGISTRY.keys():
172+
if issubclass(ty, t):
173+
ty = t
174+
break
175+
else:
168176
return target
169177

170178
ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0]
@@ -269,8 +277,9 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
269277

270278
register_serialization_state(Array, _array_dict_state, _restore_array)
271279
register_serialization_state(dict, _dict_state_dict, _restore_dict)
272-
register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
273-
register_serialization_state(Collector, _dict_state_dict, _restore_dict)
280+
# register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
281+
# register_serialization_state(Collector, _dict_state_dict, _restore_dict)
282+
# register_serialization_state(ArrayCollector, _dict_state_dict, _restore_dict)
274283
register_serialization_state(list, _list_state_dict, _restore_list)
275284
register_serialization_state(tuple,
276285
_list_state_dict,
@@ -1221,8 +1230,9 @@ def _save_main_ckpt_file2(target: bytes,
12211230
def save_pytree(
12221231
filename: str,
12231232
target: PyTree,
1224-
overwrite: bool = False,
1233+
overwrite: bool = True,
12251234
async_manager: Optional[AsyncManager] = None,
1235+
verbose: bool = True,
12261236
) -> None:
12271237
"""Save a checkpoint of the model. Suitable for single-host.
12281238
@@ -1250,12 +1260,16 @@ def save_pytree(
12501260
if defined, the save will run without blocking the main
12511261
thread. Only works for single host. Note that an ongoing save will still
12521262
block subsequent saves, to make sure overwrite/keep logic works correctly.
1263+
verbose: bool
1264+
Whether output the print information.
12531265
12541266
Returns
12551267
-------
12561268
out: str
12571269
Filename of saved checkpoint.
12581270
"""
1271+
if verbose:
1272+
print(f'Saving checkpoint into {filename}')
12591273
start_time = time.time()
12601274
# Make sure all saves are finished before the logic of checking and removing
12611275
# outdated checkpoints happens.
@@ -1284,6 +1298,7 @@ def save_main_ckpt_task():
12841298
end_time - start_time)
12851299

12861300

1301+
12871302
def multiprocess_save(
12881303
ckpt_dir: Union[str, os.PathLike],
12891304
target: PyTree,

brainpy/_src/dyn/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
channels, neurons, rates, # neuron related
99
synapses, synouts, synplast, # synapse related
1010
networks,
11-
layers, # ANN related
1211
runners,
1312
transform,
1413
)

0 commit comments

Comments
 (0)