
Commit 40c2632

Merge pull request #542 from chaoming0625/master
[dyn] add `save_state`, `load_state`, `reset_state`, and `clear_input` helpers
2 parents 8e201f6 + 178a7cc commit 40c2632


9 files changed with 119 additions and 48 deletions

brainpy/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,8 @@
 # shared parameters
 from brainpy._src.context import (share as share)
 from brainpy._src.helpers import (reset_state as reset_state,
+                                  save_state as save_state,
+                                  load_state as load_state,
                                   clear_input as clear_input)
 
 
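With the two added aliases, all four helpers become importable directly from the package root. A minimal sketch (not part of the diff) of what this enables:

import brainpy as bp

# after this change, every helper resolves at the top level
assert all(hasattr(bp, name) for name in
           ('reset_state', 'save_state', 'load_state', 'clear_input'))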

brainpy/_src/dependency_check.py

Lines changed: 2 additions & 9 deletions
@@ -12,28 +12,21 @@
 _minimal_taichi_version = (1, 7, 0)
 
 taichi = None
-has_import_ti = False
 brainpylib_cpu_ops = None
 brainpylib_gpu_ops = None
 
 
 def import_taichi():
-  global taichi, has_import_ti
-  if not has_import_ti:
+  global taichi
+  if taichi is None:
    try:
      import taichi as taichi  # noqa
-      has_import_ti = True
    except ModuleNotFoundError:
      raise ModuleNotFoundError(
        'Taichi is needed. Please install taichi through:\n\n'
        '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
      )
 
-  if taichi is None:
-    raise ModuleNotFoundError(
-      'Taichi is needed. Please install taichi through:\n\n'
-      '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
-    )
  if taichi.__version__ < _minimal_taichi_version:
    raise RuntimeError(
      f'We need taichi>={_minimal_taichi_version}. '
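The refactor drops the separate `has_import_ti` flag: the module-level `taichi` handle itself now records whether the import has already happened. A standalone sketch of the same lazy-import pattern (hypothetical `import_numpy` example, not code from this commit):

_np = None  # module-level cache; doubles as the "already imported" flag

def import_numpy():
  global _np
  if _np is None:  # first call performs the real import
    try:
      import numpy as _np  # bound to the module-level name via `global`
    except ModuleNotFoundError:
      raise ModuleNotFoundError('numpy is needed. Please `pip install numpy`.')
  return _np  # later calls return the cached module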

brainpy/_src/helpers.py

Lines changed: 48 additions & 1 deletion
@@ -1,8 +1,14 @@
-from .dynsys import DynamicalSystem, DynView
+from typing import Dict
+
 from brainpy._src.dyn.base import IonChaDyn
+from brainpy._src.dynsys import DynamicalSystem, DynView
+from brainpy._src.math.object_transform.base import StateLoadResult
+
 
 __all__ = [
   'reset_state',
+  'load_state',
+  'save_state',
   'clear_input',
 ]
 
@@ -30,3 +36,44 @@ def clear_input(target: DynamicalSystem, *args, **kwargs):
   """
   for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values():
     node.clear_input(*args, **kwargs)
+
+
+def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
+  """Copy parameters and buffers from :attr:`state_dict` into
+  this module and its descendants.
+
+  Args:
+    target: DynamicalSystem. The dynamical system to load its states.
+    state_dict: dict. A dict containing parameters and persistent buffers.
+
+  Returns:
+  -------
+  ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+
+  * **missing_keys** is a list of str containing the missing keys
+  * **unexpected_keys** is a list of str containing the unexpected keys
+  """
+  nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()
+  missing_keys = []
+  unexpected_keys = []
+  for name, node in nodes.items():
+    r = node.load_state(state_dict[name], **kwargs)
+    if r is not None:
+      missing, unexpected = r
+      missing_keys.extend([f'{name}.{key}' for key in missing])
+      unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
+  return StateLoadResult(missing_keys, unexpected_keys)
+
+
+def save_state(target: DynamicalSystem, **kwargs) -> Dict:
+  """Save all states in the ``target`` as a dictionary for later disk serialization.
+
+  Args:
+    target: DynamicalSystem. The node to save its states.
+
+  Returns:
+    Dict. The state dict for serialization.
+  """
+  nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()  # retrieve all nodes
+  return {key: node.save_state(**kwargs) for key, node in nodes.items()}
+
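A hedged usage sketch (not part of the commit) of the new module-level helpers. The network class below (`bp.dyn.Lif`) and the exact `reset_state` signature are assumptions; the `save_state`/`load_state` round-trip follows the contracts shown above:

import brainpy as bp

net = bp.dyn.Lif(10)                 # any DynamicalSystem works the same way

state = bp.save_state(net)           # Dict: {node_name: {var_name: value, ...}}
bp.reset_state(net)                  # assumed to take the target, like clear_input
result = bp.load_state(net, state)   # StateLoadResult(missing_keys, unexpected_keys)
print(result.missing_keys, result.unexpected_keys)

bp.clear_input(net)                  # clear accumulated inputs on every node

The returned dict can then be written to disk with any format that handles arrays (e.g. msgpack or numpy archives); the helpers themselves only build and consume the in-memory mapping.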

brainpy/_src/math/object_transform/base.py

Lines changed: 10 additions & 2 deletions
@@ -478,6 +478,14 @@ def unique_name(self, name=None, type_=None):
     check_name_uniqueness(name=name, obj=self)
     return name
 
+  def save_state(self, **kwargs) -> Dict:
+    """Save states as a dictionary. """
+    return self.__save_state__(**kwargs)
+
+  def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
+    """Load states from a dictionary."""
+    return self.__load_state__(state_dict, **kwargs)
+
   def __save_state__(self, **kwargs) -> Dict:
     """Save states. """
     return self.vars(include_self=True, level=0).unique().dict()
@@ -502,7 +510,7 @@ def state_dict(self, **kwargs) -> dict:
       A dictionary containing a whole state of the module.
     """
     nodes = self.nodes()  # retrieve all nodes
-    return {key: node.__save_state__(**kwargs) for key, node in nodes.items()}
+    return {key: node.save_state(**kwargs) for key, node in nodes.items()}
 
   def load_state_dict(
       self,
@@ -544,7 +552,7 @@ def load_state_dict(
     missing_keys = []
     unexpected_keys = []
     for name, node in nodes.items():
-      r = node.__load_state__(state_dict[name], **kwargs)
+      r = node.load_state(state_dict[name], **kwargs)
       if r is not None:
         missing, unexpected = r
         missing_keys.extend([f'{name}.{key}' for key in missing])
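The new public `save_state`/`load_state` wrappers give subclasses one obvious override point, and `state_dict`/`load_state_dict` now route through them. A minimal sketch under that reading (hypothetical `Counter` class, not from the commit):

import brainpy.math as bm

class Counter(bm.BrainPyObject):
  def __init__(self):
    super().__init__()
    self.count = bm.Variable(bm.zeros(1))

  def save_state(self, **kwargs):
    # customize serialization without touching __save_state__
    return {'count': self.count.value}

  def load_state(self, state_dict, **kwargs):
    self.count.value = state_dict['count']
    return (), ()  # (missing_keys, unexpected_keys)

obj = Counter()
obj.count.value += 1
snapshot = obj.state_dict()     # dispatches to save_state()
obj.load_state_dict(snapshot)   # dispatches to load_state()
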
Lines changed: 36 additions & 29 deletions
@@ -1,11 +1,17 @@
 import jax
 import jax.numpy as jnp
-import taichi as ti
+import taichi as taichi
+import pytest
+import platform
 
 import brainpy.math as bm
 
 bm.set_platform('cpu')
 
+if not platform.platform().startswith('Windows'):
+  pytest.skip(allow_module_level=True)
+
+
 # @ti.kernel
 # def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
 #                   vector: ti.types.ndarray(ndim=1),
@@ -19,43 +25,44 @@
 #     for j in range(num_cols):
 #       out[indices[i, j]] += weight_0
 
-@ti.func
-def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
-  return weight[0]
+@taichi.func
+def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
+  return weight[0]
 
-@ti.func
-def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
-  out[index] += weight_val
 
-@ti.kernel
-def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
-                  vector: ti.types.ndarray(ndim=1),
-                  weight: ti.types.ndarray(ndim=1),
-                  out: ti.types.ndarray(ndim=1)):
-  weight_val = get_weight(weight)
-  num_rows, num_cols = indices.shape
-  ti.loop_config(serialize=True)
-  for i in range(num_rows):
-    if vector[i]:
-      for j in range(num_cols):
-        update_output(out, indices[i, j], weight_val)
+@taichi.func
+def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
+  out[index] += weight_val
 
 
-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+@taichi.kernel
+def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
+                  vector: taichi.types.ndarray(ndim=1),
+                  weight: taichi.types.ndarray(ndim=1),
+                  out: taichi.types.ndarray(ndim=1)):
+  weight_val = get_weight(weight)
+  num_rows, num_cols = indices.shape
+  taichi.loop_config(serialize=True)
+  for i in range(num_rows):
+    if vector[i]:
+      for j in range(num_cols):
+        update_output(out, indices[i, j], weight_val)
 
 
-# def test_taichi_op_register():
-#   s = 1000
-#   indices = bm.random.randint(0, s, (s, 1000))
-#   vector = bm.random.rand(s) < 0.1
-#   weight = bm.array([1.0])
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+
 
-# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+def test_taichi_op_register():
+  s = 1000
+  indices = bm.random.randint(0, s, (s, 1000))
+  vector = bm.random.rand(s) < 0.1
+  weight = bm.array([1.0])
 
-# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
 
-# print(out)
-# bm.clear_buffer_memory()
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
 
+  print(out)
+  bm.clear_buffer_memory()
 
 # test_taichi_op_register()

brainpy/_src/math/scales.py

Lines changed: 16 additions & 5 deletions
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 
 
+from typing import Sequence, Union
+
 __all__ = [
   'Scaling',
   'IdScaling',
@@ -13,11 +15,20 @@ def __init__(self, scale, bias):
     self.bias = bias
 
   @classmethod
-  def transform(cls, V_range: list, scaled_V_range: list):
-    '''
-    V_range: [V_min, V_max]
-    scaled_V_range: [scaled_V_min, scaled_V_max]
-    '''
+  def transform(
+      cls,
+      V_range: Sequence[Union[float, int]],
+      scaled_V_range: Sequence[Union[float, int]] = (0., 1.)
+  ) -> 'Scaling':
+    """Transform the membrane potential range to a ``Scaling`` instance.
+
+    Args:
+      V_range: [V_min, V_max]
+      scaled_V_range: [scaled_V_min, scaled_V_max]
+
+    Returns:
+      The instanced scaling object.
+    """
     V_min, V_max = V_range
     scaled_V_min, scaled_V_max = scaled_V_range
     scale = (V_max - V_min) / (scaled_V_max - scaled_V_min)
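A brief usage sketch for the reworked classmethod (not in the diff; it assumes `Scaling` is re-exported under `brainpy.math`, otherwise import it from the private path shown above):

import brainpy.math as bm

# map the membrane-potential range [-80 mV, 20 mV] onto the default (0., 1.)
s = bm.Scaling.transform([-80., 20.])

# or onto an explicit target range
s2 = bm.Scaling.transform([-80., 20.], scaled_V_range=(-1., 1.))

With the default target range the resulting scale is (20 - (-80)) / (1 - 0) = 100, following the formula in the hunk above.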

docs/apis/brainpy.rst

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
    :local:
    :depth: 1
 
+
 Numerical Differential Integration
 ----------------------------------
 
@@ -77,5 +78,9 @@ Dynamical System Helpers
    :template: classtemplate.rst
 
    LoopOverTime
+   reset_state
+   save_state
+   load_state
+   clear_input
 
 
requirements-dev.txt

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@ jaxlib
 matplotlib>=3.4
 msgpack
 tqdm
-taichi
 
 # test requirements
 pytest

requirements-doc.txt

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@ jaxlib
 matplotlib>=3.4
 scipy>=1.1.0
 numba
-taichi
 
 # document requirements
 pandoc
