Skip to content

Commit 2e2d6f4

Browse files
authored
Merge pull request #316 from chaoming0625/master
Ready for publish
2 parents dbdc5c6 + 7cf0ad3 commit 2e2d6f4

File tree

9 files changed

+300
-64
lines changed

9 files changed

+300
-64
lines changed

brainpy/math/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@
5353
from . import surrogate
5454
from .surrogate.compt import *
5555

56-
# JAX transformations for Variable and class objects
56+
# Variable and Objects for object-oriented JAX transformations
5757
from .object_transform import *
5858

59-
6059
# environment settings
6160
from .modes import *
6261
from .environment import *

brainpy/math/ndarray.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33

44
import warnings
5-
from typing import Optional, Tuple
5+
from typing import Optional, Tuple as TupleType
66

77
import numpy as np
88
from jax import numpy as jnp
99
from jax.tree_util import register_pytree_node
1010

11+
1112
from brainpy.errors import MathError
1213

1314
__all__ = [
@@ -997,7 +998,7 @@ def __init__(
997998
f'but the batch axis is set to be {batch_axis}.')
998999

9991000
@property
1000-
def shape_nb(self) -> Tuple[int, ...]:
1001+
def shape_nb(self) -> TupleType[int, ...]:
10011002
"""Shape without batch axis."""
10021003
shape = list(self.value.shape)
10031004
if self.batch_axis is not None:
@@ -1562,7 +1563,6 @@ class BatchVariable(Variable):
15621563
pass
15631564

15641565

1565-
15661566
class VariableView(Variable):
15671567
"""A view of a Variable instance.
15681568
@@ -1742,7 +1742,7 @@ def _jaxarray_unflatten(aux_data, flat_contents):
17421742

17431743

17441744
register_pytree_node(Array,
1745-
lambda t: ((t.value,), (t._transform_context, )),
1745+
lambda t: ((t.value,), (t._transform_context,)),
17461746
_jaxarray_unflatten)
17471747

17481748
register_pytree_node(Variable,
@@ -1756,3 +1756,4 @@ def _jaxarray_unflatten(aux_data, flat_contents):
17561756
register_pytree_node(Parameter,
17571757
lambda t: ((t.value,), None),
17581758
lambda aux_data, flat_contents: Parameter(*flat_contents))
1759+

brainpy/math/object_transform/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,15 @@
3030
+ controls.__all__
3131
+ jit.__all__
3232
+ function.__all__
33+
+ base_object.__all__
34+
+ base_transform.__all__
35+
+ collector.__all__
3336
)
3437

3538
from .autograd import *
3639
from .controls import *
3740
from .jit import *
3841
from .function import *
42+
from .base_object import *
43+
from .base_transform import *
44+
from .collector import *

brainpy/math/object_transform/base_object.py

Lines changed: 119 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,39 @@
11
# -*- coding: utf-8 -*-
22

33
import os
4-
import logging
54
import warnings
65
from collections import namedtuple
76
from typing import Any, Tuple, Callable, Sequence, Dict, Union
87

8+
import jax
9+
import numpy as np
10+
from jax._src.tree_util import _registry
11+
from jax.tree_util import register_pytree_node
12+
from jax.tree_util import register_pytree_node_class
13+
from jax.util import safe_zip
14+
915
from brainpy import errors
1016
from .collector import Collector, ArrayCollector
11-
from ..ndarray import Variable, VariableView, TrainVar
17+
from ..ndarray import (Array,
18+
Variable,
19+
VariableView,
20+
TrainVar)
1221

1322
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
1423

24+
1525
__all__ = [
16-
'check_name_uniqueness',
17-
'get_unique_name',
18-
'clear_name_cache',
26+
# naming
27+
'check_name_uniqueness', 'get_unique_name', 'clear_name_cache',
1928

29+
# objects
2030
'BrainPyObject', 'Base', 'FunAsObject',
31+
32+
# variables
33+
'numerical_seq', 'object_seq',
34+
'numerical_dict', 'object_dict',
2135
]
2236

23-
logger = logging.getLogger('brainpy.brainpy_object')
2437

2538
_name2id = dict()
2639
_typed_names = {}
@@ -59,7 +72,7 @@ def clear_name_cache(ignore_warn=False):
5972
_name2id.clear()
6073
_typed_names.clear()
6174
if not ignore_warn:
62-
logger.warning(f'All named models and their ids are cleared.')
75+
warnings.warn(f'All named models and their ids are cleared.', UserWarning)
6376

6477

6578
class BrainPyObject(object):
@@ -78,6 +91,11 @@ class BrainPyObject(object):
7891
_excluded_vars = ()
7992

8093
def __init__(self, name=None):
94+
super().__init__()
95+
cls = self.__class__
96+
if cls not in _registry:
97+
register_pytree_node_class(cls)
98+
8199
# check whether the object has a unique name.
82100
self._name = None
83101
self._name = self.unique_name(name=name)
@@ -91,15 +109,17 @@ def __init__(self, name=None):
91109
# which cannot be accessed by self.xxx
92110
self.implicit_nodes = Collector()
93111

94-
def __setattr__(self, key, value) -> None:
95-
"""Overwrite __setattr__ method for non-changeable Variable setting.
112+
def __setattr__(self, key: str, value: Any) -> None:
113+
"""Overwrite `__setattr__` method for change Variable values.
96114
97115
.. versionadded:: 2.3.1
98116
99117
Parameters
100118
----------
101119
key: str
120+
The attribute.
102121
value: Any
122+
The value.
103123
"""
104124
if key in self.__dict__:
105125
val = self.__dict__[key]
@@ -109,19 +129,24 @@ def __setattr__(self, key, value) -> None:
109129
super().__setattr__(key, value)
110130

111131
def tree_flatten(self):
112-
"""
132+
"""Flattens the object as a PyTree.
133+
134+
The flattening order is determined by attributes added order.
135+
113136
.. versionadded:: 2.3.1
114137
115138
Returns
116139
-------
117-
140+
res: tuple
141+
A tuple of dynamical values and static values.
118142
"""
143+
dts = (BrainPyObject,) + tuple(dynamical_types)
119144
dynamic_names = []
120145
dynamic_values = []
121146
static_names = []
122147
static_values = []
123148
for k, v in self.__dict__.items():
124-
if isinstance(v, (ArrayCollector, BrainPyObject, Variable)):
149+
if isinstance(v, dts):
125150
dynamic_names.append(k)
126151
dynamic_values.append(v)
127152
else:
@@ -531,3 +556,85 @@ def __repr__(self) -> str:
531556
node_string = ", \n".join(nodes)
532557
return (f'{name}(nodes=[{node_string}],\n' +
533558
" " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')
559+
560+
561+
class numerical_seq(list):
562+
"""A list to represent a dynamically changed numerical
563+
sequence in which its element can be changed during JIT compilation.
564+
565+
.. note::
566+
The element must be numerical, like ``bool``, ``int``, ``float``,
567+
``jax.Array``, ``numpy.ndarray``, ``brainpy.math.Array``.
568+
"""
569+
def append(self, element):
570+
if not isinstance(element, (bool, int, float, jax.Array, Array, np.ndarray)):
571+
raise TypeError(f'Each element should be a numerical value.')
572+
573+
def extend(self, iterable) -> None:
574+
for element in iterable:
575+
self.append(element)
576+
577+
578+
register_pytree_node(numerical_seq,
579+
lambda x: (tuple(x), ()),
580+
lambda _, values: numerical_seq(values))
581+
582+
583+
class object_seq(list):
584+
"""A list to represent a sequence of :py:class:`~.BrainPyObject`.
585+
586+
.. note::
587+
The element must be :py:class:`~.BrainPyObject`.
588+
"""
589+
def append(self, element):
590+
if not isinstance(element, BrainPyObject):
591+
raise TypeError(f'Only support {BrainPyObject.__name__}')
592+
593+
def extend(self, iterable) -> None:
594+
for element in iterable:
595+
self.append(element)
596+
597+
598+
register_pytree_node(object_seq,
599+
lambda x: (tuple(x), ()),
600+
lambda _, values: object_seq(values))
601+
602+
603+
class numerical_dict(dict):
604+
"""A dict to represent a dynamically changed numerical
605+
dictionary in which its element can be changed during JIT compilation.
606+
607+
.. note::
608+
Each key must be a string, and each value must be numerical, including
609+
``bool``, ``int``, ``float``, ``jax.Array``, ``numpy.ndarray``,
610+
``brainpy.math.Array``.
611+
"""
612+
def update(self, *args, **kwargs) -> 'numerical_dict':
613+
super().update(*args, **kwargs)
614+
return self
615+
616+
617+
register_pytree_node(numerical_dict,
618+
lambda x: (tuple(x.values()), tuple(x.keys())),
619+
lambda keys, values: numerical_dict(safe_zip(keys, values)))
620+
621+
622+
class object_dict(dict):
623+
"""A dict to represent a dictionary of :py:class:`~.BrainPyObject`.
624+
625+
.. note::
626+
Each key must be a string, and each value must be :py:class:`~.BrainPyObject`.
627+
"""
628+
def update(self, *args, **kwargs) -> 'object_dict':
629+
super().update(*args, **kwargs)
630+
return self
631+
632+
633+
register_pytree_node(object_dict,
634+
lambda x: (tuple(x.values()), tuple(x.keys())),
635+
lambda keys, values: object_dict(safe_zip(keys, values)))
636+
637+
dynamical_types = [Variable,
638+
numerical_seq, numerical_dict,
639+
object_seq, object_dict]
640+

0 commit comments

Comments
 (0)