@@ -153,3 +153,102 @@ def forward(self, input):
            return self.OutputRescalingLayer(input, out)
        else:
            return out
+
+class DeepSet(nn.Module):
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int,
+        L: int,
+        phi_layers: int,
+        rho_layers: int,
+        phi_hidden_dim: int = 128,
+        rho_hidden_dim: int = 128,
+        phi_activator: Optional[nn.Module] = lazy_instance(nn.ReLU),
+        phi_hidden_bias: bool = True,
+        phi_last_activator: Optional[nn.Module] = lazy_instance(nn.Identity),
+        phi_last_bias=True,
+        rho_activator: Optional[nn.Module] = lazy_instance(nn.ReLU),
+        rho_hidden_bias: bool = True,
+        rho_last_activator: Optional[nn.Module] = lazy_instance(nn.Identity),
+        rho_last_bias=True,
+        OutputRescalingLayer: Optional[nn.Module] = None,
+        InputRescalingLayer: Optional[nn.Module] = None,
+    ):
+        """
+        Init method.
+        """
+        assert n_in == 1  # only supporting univariate states for now
+        super().__init__()  # init the base class
+        self.rho = FlexibleSequential(
+            L,
+            n_out,
+            rho_layers,
+            rho_hidden_dim,
+            rho_activator,
+            rho_hidden_bias,
+            rho_last_activator,
+            rho_last_bias,
+            OutputRescalingLayer=OutputRescalingLayer,
+        )
+
+        self.phi = FlexibleSequential(
+            n_in,
+            L,
+            phi_layers,
+            phi_hidden_dim,
+            phi_activator,
+            phi_hidden_bias,
+            phi_last_activator,
+            phi_last_bias,
+            InputRescalingLayer=InputRescalingLayer,
+        )
+
+    def forward(self, X):
+        num_batches, N = X.shape
+        phi_X = torch.stack(
+            [torch.mean(self.phi(X[i, :].reshape([N, 1])), 0) for i in range(num_batches)]
+        )
+        return self.rho(phi_X)
+
+
+
+class DeepSetMoments(nn.Module):
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int,
+        L: int,
+        rho_layers: int,
+        rho_hidden_dim: int = 128,
+        rho_activator: Optional[nn.Module] = lazy_instance(nn.ReLU),
+        rho_hidden_bias: bool = True,
+        rho_last_activator: Optional[nn.Module] = lazy_instance(nn.Identity),
+        rho_last_bias=True,
+        OutputRescalingLayer: Optional[nn.Module] = None,
+    ):
+        """
+        Init method.
+        """
+        assert n_in == 1  # only supporting univariate states for now
+        super().__init__()  # init the base class
+        self.rho = FlexibleSequential(
+            L,
+            n_out,
+            rho_layers,
+            rho_hidden_dim,
+            rho_activator,
+            rho_hidden_bias,
+            rho_last_activator,
+            rho_last_bias,
+            OutputRescalingLayer=OutputRescalingLayer,
+        )
+
+        self.phi = Moments(L)
+
+    def forward(self, X):
+        num_batches, N = X.shape
+        phi_X = torch.stack(
+            [torch.mean(self.phi(X[i, :].reshape([N, 1])), 0) for i in range(num_batches)]
+        )
+        return self.rho(phi_X)
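
A minimal usage sketch (not part of the commit above; it assumes DeepSet is importable from this module and that FlexibleSequential accepts the constructor arguments as passed): phi embeds each scalar set element, the embeddings are mean-pooled per set in forward(), and rho maps the pooled vector to the output.

# Hypothetical usage sketch, not part of this commit: a batch of 4 sets,
# each with 10 scalar elements, mapped to a 3-dimensional output.
import torch

net = DeepSet(n_in=1, n_out=3, L=16, phi_layers=2, rho_layers=2)
X = torch.randn(4, 10)   # forward() expects shape (num_batches, N)
y = net(X)               # phi is mean-pooled over each set, then rho is applied
print(y.shape)           # torch.Size([4, 3])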