Skip to content

Commit cba095a

Browse files
committed
add mini-batch for PointSetOperatorBC
1 parent a60cd74 commit cba095a

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

deepxde/icbc/boundary_conditions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,37 @@ class PointSetOperatorBC:
248248
and outputs a tensor of size `N x 1`, where `N` is the length of
249249
`inputs`. `inputs` and `outputs` are the network input and output
250250
tensors, respectively; `X` are the NumPy array of the `inputs`.
251+
batch_size: The number of points per minibatch, or `None` to return all points.
252+
This is only supported for the PyTorch and PaddlePaddle backends.
253+
Note: if you want to use a batch size here, you should also register the callback
254+
`dde.callbacks.PDEPointResampler(bc_points=True)` during training.
255+
shuffle: Randomize the order on each pass through the data when batching.
251256
"""
252257

253-
def __init__(self, points, values, func, batch_size=None, shuffle=None):
    """Set up an operator boundary condition on a fixed set of points.

    Args:
        points: Array-like of shape ``(N, dim)``; the points the condition is
            enforced on.
        values: Scalar, or array of shape ``(N, 1)``; target values of the
            operator at ``points``.
        func: Callable ``func(inputs, outputs, X)`` returning an ``N x 1``
            tensor, where ``inputs``/``outputs`` are the network input/output
            tensors and ``X`` is the NumPy array of the inputs.
        batch_size: Number of points per mini-batch, or ``None`` to use all
            points every iteration. Only supported for the pytorch and paddle
            backends.
        shuffle: Whether to randomize the point order on each pass through the
            data when batching.

    Raises:
        RuntimeError: If ``values`` is not 1D column-shaped, or if
            ``batch_size`` is given on an unsupported backend.
    """
    self.points = np.array(points, dtype=config.real(np))
    if not isinstance(values, numbers.Number) and values.shape[1] != 1:
        raise RuntimeError("PointSetOperatorBC should output 1D values")
    self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
    self.func = func
    # BUG FIX: collocation_points() and error() branch on self.batch_size,
    # but the original commit never stored it -- every call would raise
    # AttributeError. Store it (and shuffle) unconditionally.
    self.batch_size = batch_size
    self.shuffle = shuffle
    if batch_size is not None:  # batch iterator and state
        if backend_name not in ["pytorch", "paddle"]:
            raise RuntimeError(
                "batch_size only implemented for pytorch and paddle backend"
            )
        # NOTE(review): assumes len(self) gives the number of points --
        # confirm this class defines __len__ (not visible in this diff).
        self.batch_sampler = data.sampler.BatchSampler(len(self), shuffle=shuffle)
        self.batch_indices = None
272+
260273
def collocation_points(self, X):
    """Return the points this condition is enforced on.

    Args:
        X: Unused here; kept for interface compatibility with the other
            boundary-condition classes.

    Returns:
        The full ``self.points`` array, or the current mini-batch of points
        when ``self.batch_size`` is set. Side effect when batching: the drawn
        indices are stored in ``self.batch_indices`` so that ``error`` can
        select the matching target values.
    """
    if self.batch_size is None:
        return self.points
    # Draw the next mini-batch of row indices and remember them for error().
    indices = self.batch_sampler.get_next(self.batch_size)
    self.batch_indices = indices
    return self.points[indices]
262278

263279
def error(self, X, inputs, outputs, beg, end, aux_var=None):
    """Residual of the condition: ``func(...)`` minus the target values.

    Args:
        X: NumPy array of the network inputs, forwarded to ``self.func``.
        inputs: Network input tensor, forwarded to ``self.func``.
        outputs: Network output tensor, forwarded to ``self.func``.
        beg: Start index of this condition's slice of the operator output.
        end: End index of this condition's slice of the operator output.
        aux_var: Unused; kept for interface compatibility.

    Returns:
        The elementwise difference between the operator values on
        ``[beg:end]`` and the targets (restricted to the current mini-batch
        indices when batching is enabled).
    """
    predicted = self.func(inputs, outputs, X)[beg:end]
    if self.batch_size is None:
        targets = self.values
    else:
        # Match the targets to the mini-batch drawn in collocation_points().
        targets = self.values[self.batch_indices]
    return predicted - targets
265283

266284

0 commit comments

Comments
 (0)