Skip to content

Commit 87d63a6

Browse files
authored
TripleCartesianProd and QuadrupleCartesianProd support mini-batch for both branch and trunk nets (#977)
1 parent 3101bcb commit 87d63a6

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

deepxde/data/quadruple.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,29 @@ def __init__(self, X_train, y_train, X_test, y_test):
7272
self.train_x, self.train_y = X_train, y_train
7373
self.test_x, self.test_y = X_test, y_test
7474

75-
self.train_sampler = BatchSampler(len(X_train[0]), shuffle=True)
75+
self.branch_sampler = BatchSampler(len(X_train[0]), shuffle=True)
76+
self.trunk_sampler = BatchSampler(len(X_train[2]), shuffle=True)
7677

7778
def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
7879
return loss_fn(targets, outputs)
7980

8081
def train_next_batch(self, batch_size=None):
8182
if batch_size is None:
8283
return self.train_x, self.train_y
83-
indices = self.train_sampler.get_next(batch_size)
84+
if not isinstance(batch_size, (tuple, list)):
85+
indices = self.branch_sampler.get_next(batch_size)
86+
return (
87+
self.train_x[0][indices],
88+
self.train_x[1][indices],
89+
self.train_x[2],
90+
), self.train_y[indices]
91+
indices_branch = self.branch_sampler.get_next(batch_size[0])
92+
indices_trunk = self.trunk_sampler.get_next(batch_size[1])
8493
return (
85-
self.train_x[0][indices],
86-
self.train_x[1][indices],
87-
self.train_x[2],
88-
), self.train_y[indices]
94+
self.train_x[0][indices_branch],
95+
self.train_x[1][indices_branch],
96+
self.train_x[2][indices_trunk],
97+
), self.train_y[indices_branch, indices_trunk]
8998

9099
def test(self):
91100
return self.test_x, self.test_y

deepxde/data/triple.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ class TripleCartesianProd(Data):
5454
5555
Args:
5656
X_train: A tuple of two NumPy arrays. The first element has the shape (`N1`,
57-
`dim1`), and the second element has the shape (`N2`, `dim2`). The mini-batch
58-
is only applied to `N1`.
57+
`dim1`), and the second element has the shape (`N2`, `dim2`).
5958
y_train: A NumPy array of shape (`N1`, `N2`).
6059
"""
6160

@@ -71,16 +70,24 @@ def __init__(self, X_train, y_train, X_test, y_test):
7170
self.train_x, self.train_y = X_train, y_train
7271
self.test_x, self.test_y = X_test, y_test
7372

74-
self.train_sampler = BatchSampler(len(X_train[0]), shuffle=True)
73+
self.branch_sampler = BatchSampler(len(X_train[0]), shuffle=True)
74+
self.trunk_sampler = BatchSampler(len(X_train[1]), shuffle=True)
7575

7676
def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
7777
return loss_fn(targets, outputs)
7878

7979
def train_next_batch(self, batch_size=None):
8080
if batch_size is None:
8181
return self.train_x, self.train_y
82-
indices = self.train_sampler.get_next(batch_size)
83-
return (self.train_x[0][indices], self.train_x[1]), self.train_y[indices]
82+
if not isinstance(batch_size, (tuple, list)):
83+
indices = self.branch_sampler.get_next(batch_size)
84+
return (self.train_x[0][indices], self.train_x[1]), self.train_y[indices]
85+
indices_branch = self.branch_sampler.get_next(batch_size[0])
86+
indices_trunk = self.trunk_sampler.get_next(batch_size[1])
87+
return (
88+
self.train_x[0][indices_branch],
89+
self.train_x[1][indices_trunk],
90+
), self.train_y[indices_branch, indices_trunk]
8491

8592
def test(self):
8693
return self.test_x, self.test_y

deepxde/model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,16 @@ def train(
533533
Args:
534534
iterations (Integer): Number of iterations to train the model, i.e., number
535535
of times the network weights are updated.
536-
batch_size: Integer or ``None``. If you solve PDEs via ``dde.data.PDE`` or
537-
``dde.data.TimePDE``, do not use `batch_size`, and instead use
538-
`dde.callbacks.PDEResidualResampler
539-
<https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEResidualResampler>`_,
540-
see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
536+
batch_size: Integer, tuple, or ``None``.
537+
538+
- If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
539+
`dde.callbacks.PDEResidualResampler
540+
<https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEResidualResampler>`_,
541+
see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/diffusion_1d_resample.py>`_.
542+
- For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
543+
then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
544+
set `batch_size` as a tuple, where the first number is the batch size for the branch net input
545+
and the second number is the batch size for the trunk net input.
541546
display_every (Integer): Print the loss and metrics every this steps.
542547
disregard_previous_best: If ``True``, disregard the previous saved best
543548
model.

0 commit comments

Comments
 (0)