11# -*- coding: utf-8 -*-
22
33import os
4- import logging
54import warnings
65from collections import namedtuple
76from typing import Any , Tuple , Callable , Sequence , Dict , Union
87
8+ import jax
9+ import numpy as np
10+ from jax ._src .tree_util import _registry
11+ from jax .tree_util import register_pytree_node
12+ from jax .tree_util import register_pytree_node_class
13+ from jax .util import safe_zip
14+
915from brainpy import errors
1016from .collector import Collector , ArrayCollector
11- from ..ndarray import Variable , VariableView , TrainVar
17+ from ..ndarray import (Array ,
18+ Variable ,
19+ VariableView ,
20+ TrainVar )
1221
1322StateLoadResult = namedtuple ('StateLoadResult' , ['missing_keys' , 'unexpected_keys' ])
1423
24+
1525__all__ = [
16- 'check_name_uniqueness' ,
17- 'get_unique_name' ,
18- 'clear_name_cache' ,
26+ # naming
27+ 'check_name_uniqueness' , 'get_unique_name' , 'clear_name_cache' ,
1928
29+ # objects
2030 'BrainPyObject' , 'Base' , 'FunAsObject' ,
31+
32+ # variables
33+ 'numerical_seq' , 'object_seq' ,
34+ 'numerical_dict' , 'object_dict' ,
2135]
2236
23- logger = logging .getLogger ('brainpy.brainpy_object' )
2437
2538_name2id = dict ()
2639_typed_names = {}
@@ -59,7 +72,7 @@ def clear_name_cache(ignore_warn=False):
5972 _name2id .clear ()
6073 _typed_names .clear ()
6174 if not ignore_warn :
62- logger . warning (f'All named models and their ids are cleared.' )
75+ warnings . warn (f'All named models and their ids are cleared.' , UserWarning )
6376
6477
6578class BrainPyObject (object ):
@@ -78,6 +91,11 @@ class BrainPyObject(object):
7891 _excluded_vars = ()
7992
8093 def __init__ (self , name = None ):
94+ super ().__init__ ()
95+ cls = self .__class__
96+ if cls not in _registry :
97+ register_pytree_node_class (cls )
98+
8199 # check whether the object has a unique name.
82100 self ._name = None
83101 self ._name = self .unique_name (name = name )
@@ -91,15 +109,17 @@ def __init__(self, name=None):
91109 # which cannot be accessed by self.xxx
92110 self .implicit_nodes = Collector ()
93111
94- def __setattr__ (self , key , value ) -> None :
95- """Overwrite __setattr__ method for non-changeable Variable setting .
112+ def __setattr__ (self , key : str , value : Any ) -> None :
113+ """Overwrite ` __setattr__` method for change Variable values .
96114
97115 .. versionadded:: 2.3.1
98116
99117 Parameters
100118 ----------
101119 key: str
120+ The attribute.
102121 value: Any
122+ The value.
103123 """
104124 if key in self .__dict__ :
105125 val = self .__dict__ [key ]
@@ -109,19 +129,24 @@ def __setattr__(self, key, value) -> None:
109129 super ().__setattr__ (key , value )
110130
111131 def tree_flatten (self ):
112- """
132+ """Flattens the object as a PyTree.
133+
134+ The flattening order is determined by attributes added order.
135+
113136 .. versionadded:: 2.3.1
114137
115138 Returns
116139 -------
117-
140+ res: tuple
141+ A tuple of dynamical values and static values.
118142 """
143+ dts = (BrainPyObject ,) + tuple (dynamical_types )
119144 dynamic_names = []
120145 dynamic_values = []
121146 static_names = []
122147 static_values = []
123148 for k , v in self .__dict__ .items ():
124- if isinstance (v , ( ArrayCollector , BrainPyObject , Variable ) ):
149+ if isinstance (v , dts ):
125150 dynamic_names .append (k )
126151 dynamic_values .append (v )
127152 else :
@@ -531,3 +556,85 @@ def __repr__(self) -> str:
531556 node_string = ", \n " .join (nodes )
532557 return (f'{ name } (nodes=[{ node_string } ],\n ' +
533558 " " * (len (name ) + 1 ) + f'num_of_vars={ len (self .implicit_vars )} )' )
559+
560+
561+ class numerical_seq (list ):
562+ """A list to represent a dynamically changed numerical
563+ sequence in which its element can be changed during JIT compilation.
564+
565+ .. note::
566+ The element must be numerical, like ``bool``, ``int``, ``float``,
567+ ``jax.Array``, ``numpy.ndarray``, ``brainpy.math.Array``.
568+ """
569+ def append (self , element ):
570+ if not isinstance (element , (bool , int , float , jax .Array , Array , np .ndarray )):
571+ raise TypeError (f'Each element should be a numerical value.' )
572+
573+ def extend (self , iterable ) -> None :
574+ for element in iterable :
575+ self .append (element )
576+
577+
578+ register_pytree_node (numerical_seq ,
579+ lambda x : (tuple (x ), ()),
580+ lambda _ , values : numerical_seq (values ))
581+
582+
583+ class object_seq (list ):
584+ """A list to represent a sequence of :py:class:`~.BrainPyObject`.
585+
586+ .. note::
587+ The element must be :py:class:`~.BrainPyObject`.
588+ """
589+ def append (self , element ):
590+ if not isinstance (element , BrainPyObject ):
591+ raise TypeError (f'Only support { BrainPyObject .__name__ } ' )
592+
593+ def extend (self , iterable ) -> None :
594+ for element in iterable :
595+ self .append (element )
596+
597+
598+ register_pytree_node (object_seq ,
599+ lambda x : (tuple (x ), ()),
600+ lambda _ , values : object_seq (values ))
601+
602+
603+ class numerical_dict (dict ):
604+ """A dict to represent a dynamically changed numerical
605+ dictionary in which its element can be changed during JIT compilation.
606+
607+ .. note::
608+ Each key must be a string, and each value must be numerical, including
609+ ``bool``, ``int``, ``float``, ``jax.Array``, ``numpy.ndarray``,
610+ ``brainpy.math.Array``.
611+ """
612+ def update (self , * args , ** kwargs ) -> 'numerical_dict' :
613+ super ().update (* args , ** kwargs )
614+ return self
615+
616+
617+ register_pytree_node (numerical_dict ,
618+ lambda x : (tuple (x .values ()), tuple (x .keys ())),
619+ lambda keys , values : numerical_dict (safe_zip (keys , values )))
620+
621+
622+ class object_dict (dict ):
623+ """A dict to represent a dictionary of :py:class:`~.BrainPyObject`.
624+
625+ .. note::
626+ Each key must be a string, and each value must be :py:class:`~.BrainPyObject`.
627+ """
628+ def update (self , * args , ** kwargs ) -> 'object_dict' :
629+ super ().update (* args , ** kwargs )
630+ return self
631+
632+
633+ register_pytree_node (object_dict ,
634+ lambda x : (tuple (x .values ()), tuple (x .keys ())),
635+ lambda keys , values : object_dict (safe_zip (keys , values )))
636+
637+ dynamical_types = [Variable ,
638+ numerical_seq , numerical_dict ,
639+ object_seq , object_dict ]
640+
0 commit comments