@@ -248,19 +248,41 @@ class PointSetOperatorBC:
248248 and outputs a tensor of size `N x 1`, where `N` is the length of
249249 `inputs`. `inputs` and `outputs` are the network input and output
250250 tensors, respectively; `X` are the NumPy array of the `inputs`.
251+ batch_size: The number of points per minibatch, or `None` to return all points.
252+ This is only supported for the backend PyTorch and PaddlePaddle.
253+ Note, If you want to use batch size here, you should also set callback
254+ 'dde.callbacks.PDEPointResampler(bc_points=True)' in training.
255+ shuffle: Randomize the order on each pass through the data when batching.
251256 """
252257
253- def __init__ (self , points , values , func ):
258+ def __init__ (self , points , values , func , batch_size = None , shuffle = True ):
254259 self .points = np .array (points , dtype = config .real (np ))
255260 if not isinstance (values , numbers .Number ) and values .shape [1 ] != 1 :
256261 raise RuntimeError ("PointSetOperatorBC should output 1D values" )
257262 self .values = bkd .as_tensor (values , dtype = config .real (bkd .lib ))
258263 self .func = func
264+ self .batch_size = batch_size
265+
266+ if batch_size is not None : # batch iterator and state
267+ if backend_name not in ["pytorch" , "paddle" ]:
268+ raise RuntimeError (
269+ "batch_size only implemented for pytorch and paddle backend"
270+ )
271+ self .batch_sampler = data .sampler .BatchSampler (len (self ), shuffle = shuffle )
272+ self .batch_indices = None
273+
274+ def __len__ (self ):
275+ return self .points .shape [0 ]
259276
260277 def collocation_points (self , X ):
278+ if self .batch_size is not None :
279+ self .batch_indices = self .batch_sampler .get_next (self .batch_size )
280+ return self .points [self .batch_indices ]
261281 return self .points
262282
263283 def error (self , X , inputs , outputs , beg , end , aux_var = None ):
284+ if self .batch_size is not None :
285+ return self .func (inputs , outputs , X )[beg :end ] - self .values [self .batch_indices ]
264286 return self .func (inputs , outputs , X )[beg :end ] - self .values
265287
266288
0 commit comments