2828 MathError )
2929from brainpy .nn .algorithms .offline import OfflineAlgorithm
3030from brainpy .nn .algorithms .online import OnlineAlgorithm
31- from brainpy .nn .constants import (PASS_SEQUENCE ,
32- DATA_PASS_FUNC ,
33- DATA_PASS_TYPES )
31+ from brainpy .nn .datatypes import (DataType , SingleData , MultipleData )
3432from brainpy .nn .graph_flow import (find_senders_and_receivers ,
3533 find_entries_and_exits ,
3634 detect_cycle ,
@@ -83,13 +81,13 @@ def feedback(self):
8381class Node (Base ):
8482 """Basic Node class for neural network building in BrainPy."""
8583
86- '''Support multiple types of data pass, including "PASS_SEQUENCE " (by default),
87- "PASS_NAME_DICT ", "PASS_NODE_DICT" and user-customized type which registered
88- by ``brainpy.nn.register_data_pass_type()`` function .
84+ '''Support multiple types of data pass, including "PassOnlyOne " (by default),
85+ "PassSequence ", "PassNameDict", etc. and user-customized type which inherits
86+ from basic "SingleData" or "MultipleData" .
8987
9088 This setting will change the feedforward/feedback input data which pass into
9189 the "call()" function and the sizes of the feedforward/feedback input data.'''
92- data_pass_type = PASS_SEQUENCE
90+ data_pass = SingleData ()
9391
9492 '''Offline fitting method.'''
9593 offline_fit_by : Union [Callable , OfflineAlgorithm ]
@@ -115,11 +113,10 @@ def __init__(
115113 self ._trainable = trainable
116114 self ._state = None # the state of the current node
117115 self ._fb_output = None # the feedback output of the current node
118- # data pass function
119- if self .data_pass_type not in DATA_PASS_FUNC :
120- raise ValueError (f'Unsupported data pass type { self .data_pass_type } . '
121- f'Only support { DATA_PASS_TYPES } ' )
122- self .data_pass_func = DATA_PASS_FUNC [self .data_pass_type ]
116+ # data pass
117+ if not isinstance (self .data_pass , DataType ):
118+ raise ValueError (f'Unsupported data pass type { type (self .data_pass )} . '
119+ f'Only support { DataType .__class__ } ' )
123120
124121 # super initialization
125122 super (Node , self ).__init__ (name = name )
@@ -129,11 +126,10 @@ def __init__(
129126 self ._feedforward_shapes = {self .name : (None ,) + tools .to_size (input_shape )}
130127
131128 def __repr__ (self ):
132- name = type (self ).__name__
133- prefix = ' ' * (len (name ) + 1 )
134- line1 = f"{ name } (name={ self .name } , forwards={ self .feedforward_shapes } , \n "
135- line2 = f"{ prefix } feedbacks={ self .feedback_shapes } , output={ self .output_shape } )"
136- return line1 + line2
129+ return (f"{ type (self ).__name__ } (name={ self .name } , "
130+ f"forwards={ self .feedforward_shapes } , "
131+ f"feedbacks={ self .feedback_shapes } , "
132+ f"output={ self .output_shape } )" )
137133
138134 def __call__ (self , * args , ** kwargs ) -> Tensor :
139135 """The main computation function of a Node.
@@ -298,7 +294,7 @@ def trainable(self, value: bool):
298294 @property
299295 def feedforward_shapes (self ):
300296 """Input data size."""
301- return self .data_pass_func (self ._feedforward_shapes )
297+ return self .data_pass . filter (self ._feedforward_shapes )
302298
303299 @feedforward_shapes .setter
304300 def feedforward_shapes (self , size ):
@@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
324320 @property
325321 def feedback_shapes (self ):
326322 """Output data size."""
327- return self .data_pass_func (self ._feedback_shapes )
323+ return self .data_pass . filter (self ._feedback_shapes )
328324
329325 @feedback_shapes .setter
330326 def feedback_shapes (self , size ):
@@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
530526 f'batch size by ".initialize(num_batch)", or change the data '
531527 f'consistent with the data batch size { self .state .shape [0 ]} .' )
532528 # data
533- ff = self .data_pass_func (ff )
534- fb = self .data_pass_func (fb )
529+ ff = self .data_pass . filter (ff )
530+ fb = self .data_pass . filter (fb )
535531 return ff , fb
536532
537533 def _call (self ,
@@ -747,6 +743,8 @@ def set_state(self, state):
747743class Network (Node ):
748744 """Basic Network class for neural network building in BrainPy."""
749745
746+ data_pass = MultipleData ('sequence' )
747+
750748 def __init__ (self ,
751749 nodes : Optional [Sequence [Node ]] = None ,
752750 ff_edges : Optional [Sequence [Tuple [Node ]]] = None ,
@@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
11451143 check_shape_except_batch (size , fb [k ].shape )
11461144
11471145 # data transformation
1148- ff = self .data_pass_func (ff )
1149- fb = self .data_pass_func (fb )
1146+ ff = self .data_pass . filter (ff )
1147+ fb = self .data_pass . filter (fb )
11501148 return ff , fb
11511149
11521150 def _call (self ,
@@ -1208,12 +1206,12 @@ def _call(self,
12081206 def _call_a_node (self , node , ff , fb , monitors , forced_states ,
12091207 parent_outputs , children_queue , ff_senders ,
12101208 ** shared_kwargs ):
1211- ff = node .data_pass_func (ff )
1209+ ff = node .data_pass . filter (ff )
12121210 if f'{ node .name } .inputs' in monitors :
12131211 monitors [f'{ node .name } .inputs' ] = ff
12141212 # get the output results
12151213 if len (fb ):
1216- fb = node .data_pass_func (fb )
1214+ fb = node .data_pass . filter (fb )
12171215 if f'{ node .name } .feedbacks' in monitors :
12181216 monitors [f'{ node .name } .feedbacks' ] = fb
12191217 parent_outputs [node ] = node .forward (ff , fb , ** shared_kwargs )
@@ -1440,7 +1438,7 @@ def plot_node_graph(self,
14401438 if len (nodes_untrainable ):
14411439 proxie .append (Line2D ([], [], color = 'white' , marker = 'o' ,
14421440 markerfacecolor = untrainable_color ))
1443- labels .append ('Untrainable ' )
1441+ labels .append ('Nontrainable ' )
14441442 if len (ff_edges ):
14451443 proxie .append (Line2D ([], [], color = ff_color , linewidth = 2 ))
14461444 labels .append ('Feedforward' )
0 commit comments