Skip to content

Commit 9906ccc

Browse files
authored
Add mini-batch for PointSetOperatorBC (#1997)
1 parent a20ad97 commit 9906ccc

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

deepxde/icbc/boundary_conditions.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,41 @@ class PointSetOperatorBC:
248248
and outputs a tensor of size `N x 1`, where `N` is the length of
249249
`inputs`. `inputs` and `outputs` are the network input and output
250250
tensors, respectively; `X` are the NumPy array of the `inputs`.
251+
batch_size: The number of points per minibatch, or `None` to return all points.
252+
This is only supported for the PyTorch and PaddlePaddle backends.
253+
Note: if you want to use a batch size here, you should also set the callback
254+
'dde.callbacks.PDEPointResampler(bc_points=True)' in training.
255+
shuffle: Randomize the order on each pass through the data when batching.
251256
"""
252257

253-
def __init__(self, points, values, func):
258+
def __init__(self, points, values, func, batch_size=None, shuffle=True):
254259
self.points = np.array(points, dtype=config.real(np))
255260
if not isinstance(values, numbers.Number) and values.shape[1] != 1:
256261
raise RuntimeError("PointSetOperatorBC should output 1D values")
257262
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
258263
self.func = func
264+
self.batch_size = batch_size
265+
266+
if batch_size is not None: # batch iterator and state
267+
if backend_name not in ["pytorch", "paddle"]:
268+
raise RuntimeError(
269+
"batch_size only implemented for pytorch and paddle backend"
270+
)
271+
self.batch_sampler = data.sampler.BatchSampler(len(self), shuffle=shuffle)
272+
self.batch_indices = None
273+
274+
def __len__(self):
275+
return self.points.shape[0]
259276

260277
def collocation_points(self, X):
278+
if self.batch_size is not None:
279+
self.batch_indices = self.batch_sampler.get_next(self.batch_size)
280+
return self.points[self.batch_indices]
261281
return self.points
262282

263283
def error(self, X, inputs, outputs, beg, end, aux_var=None):
284+
if self.batch_size is not None:
285+
return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
264286
return self.func(inputs, outputs, X)[beg:end] - self.values
265287

266288

0 commit comments

Comments
 (0)