Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion deepxde/icbc/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,19 +248,41 @@ class PointSetOperatorBC:
and outputs a tensor of size `N x 1`, where `N` is the length of
`inputs`. `inputs` and `outputs` are the network input and output
tensors, respectively; `X` are the NumPy array of the `inputs`.
batch_size: The number of points per minibatch, or `None` to return all points.
This is only supported for the backend PyTorch and PaddlePaddle.
Note, If you want to use batch size here, you should also set callback
'dde.callbacks.PDEPointResampler(bc_points=True)' in training.
shuffle: Randomize the order on each pass through the data when batching.
"""

def __init__(self, points, values, func):
def __init__(self, points, values, func, batch_size=None, shuffle=True):
self.points = np.array(points, dtype=config.real(np))
if not isinstance(values, numbers.Number) and values.shape[1] != 1:
raise RuntimeError("PointSetOperatorBC should output 1D values")
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
self.func = func
self.batch_size = batch_size

if batch_size is not None: # batch iterator and state
if backend_name not in ["pytorch", "paddle"]:
raise RuntimeError(
"batch_size only implemented for pytorch and paddle backend"
)
self.batch_sampler = data.sampler.BatchSampler(len(self), shuffle=shuffle)
self.batch_indices = None

def __len__(self):
return self.points.shape[0]

def collocation_points(self, X):
if self.batch_size is not None:
self.batch_indices = self.batch_sampler.get_next(self.batch_size)
return self.points[self.batch_indices]
return self.points

def error(self, X, inputs, outputs, beg, end, aux_var=None):
if self.batch_size is not None:
return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
return self.func(inputs, outputs, X)[beg:end] - self.values


Expand Down