Skip to content

Commit e4145ea

Browse files
committed
fix missing batch_size in PointSetOperator, add len() function
1 parent dae9d0c commit e4145ea

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

deepxde/icbc/boundary_conditions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def __init__(self, points, values, func, batch_size=None, shuffle=True):
261261
raise RuntimeError("PointSetOperatorBC should output 1D values")
262262
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
263263
self.func = func
264+
self.batch_size = batch_size
264265

265266
if batch_size is not None: # batch iterator and state
266267
if backend_name not in ["pytorch", "paddle"]:
@@ -270,6 +271,9 @@ def __init__(self, points, values, func, batch_size=None, shuffle=True):
270271
self.batch_sampler = data.sampler.BatchSampler(len(self), shuffle=shuffle)
271272
self.batch_indices = None
272273

274+
def __len__(self):
275+
return self.points.shape[0]
276+
273277
def collocation_points(self, X):
274278
if self.batch_size is not None:
275279
self.batch_indices = self.batch_sampler.get_next(self.batch_size)

0 commit comments

Comments
 (0)