@@ -248,19 +248,37 @@ class PointSetOperatorBC:
248248 and outputs a tensor of size `N x 1`, where `N` is the length of
249249 `inputs`. `inputs` and `outputs` are the network input and output
250250 tensors, respectively; `X` are the NumPy array of the `inputs`.
251+ batch_size: The number of points per minibatch, or `None` to return all points.
252+ This is only supported for the backend PyTorch and PaddlePaddle.
253+ Note, If you want to use batch size here, you should also set callback
254+ 'dde.callbacks.PDEPointResampler(bc_points=True)' in training.
255+ shuffle: Randomize the order on each pass through the data when batching.
251256 """
252257
253- def __init__ (self , points , values , func ):
258+ def __init__ (self , points , values , func , batch_size = None , shuffle = None ):
254259 self .points = np .array (points , dtype = config .real (np ))
255260 if not isinstance (values , numbers .Number ) and values .shape [1 ] != 1 :
256261 raise RuntimeError ("PointSetOperatorBC should output 1D values" )
257262 self .values = bkd .as_tensor (values , dtype = config .real (bkd .lib ))
258263 self .func = func
259264
265+ if batch_size is not None : # batch iterator and state
266+ if backend_name not in ["pytorch" , "paddle" ]:
267+ raise RuntimeError (
268+ "batch_size only implemented for pytorch and paddle backend"
269+ )
270+ self .batch_sampler = data .sampler .BatchSampler (len (self ), shuffle = shuffle )
271+ self .batch_indices = None
272+
260273 def collocation_points (self , X ):
274+ if self .batch_size is not None :
275+ self .batch_indices = self .batch_sampler .get_next (self .batch_size )
276+ return self .points [self .batch_indices ]
261277 return self .points
262278
263279 def error (self , X , inputs , outputs , beg , end , aux_var = None ):
280+ if self .batch_size is not None :
281+ return self .func (inputs , outputs , X )[beg :end ] - self .values [self .batch_indices ]
264282 return self .func (inputs , outputs , X )[beg :end ] - self .values
265283
266284
0 commit comments