
Commit 59f996f

OpDataSet supports batch size
1 parent f516cd0 commit 59f996f

2 files changed: +91 -1 lines changed

deepxde/data/op_dataset.py

Lines changed: 11 additions & 1 deletion
@@ -6,6 +6,7 @@
 from sklearn import preprocessing
 
 from .data import Data
+from .sampler import BatchSampler
 
 
 class OpDataSet(Data):
@@ -32,6 +33,7 @@ def __init__(
             self.train_x, self.train_y = X_train, y_train
             self.test_x, self.test_y = X_test, y_test
         elif fname_train is not None:
+            # Bug to be fixed: x should be a tuple
             train_data = np.loadtxt(fname_train)
             self.train_x = train_data[:, col_x]
             self.train_y = train_data[:, col_y]
@@ -44,11 +46,19 @@ def __init__(
         if standardize:
             self._standardize()
 
+        self.train_sampler = BatchSampler(len(self.train_y), shuffle=True)
+
     def losses(self, targets, outputs, loss, model):
         return [loss(targets, outputs)]
 
     def train_next_batch(self, batch_size=None):
-        return self.train_x, self.train_y
+        if batch_size is None:
+            return self.train_x, self.train_y
+        indices = self.train_sampler.get_next(batch_size)
+        return (
+            (self.train_x[0][indices], self.train_x[1][indices]),
+            self.train_y[indices],
+        )
 
     def test(self):
         return self.test_x, self.test_y
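Note on usage (not part of the commit): with this change, ``train_next_batch`` assumes ``self.train_x`` is a tuple of two aligned arrays, so the same mini-batch indices are applied to both components and to ``self.train_y``; the file-loading branch does not yet produce a tuple, hence the "Bug to be fixed" comment above. A minimal sketch of the intended call pattern, assuming the ``X_train``/``y_train``/``X_test``/``y_test`` keyword arguments suggested by the variable names in the hunk and using made-up shapes:

import numpy as np

from deepxde.data.op_dataset import OpDataSet

# Hypothetical data: 100 samples, two input components and one target array.
x_branch = np.random.rand(100, 50)
x_trunk = np.random.rand(100, 2)
y = np.random.rand(100, 1)

data = OpDataSet(
    X_train=(x_branch, x_trunk), y_train=y,
    X_test=(x_branch, x_trunk), y_test=y,
)

# No batch size: the full training set is returned, as before.
x_all, y_all = data.train_next_batch()

# With a batch size: a mini-batch is drawn through the new BatchSampler.
(x0_batch, x1_batch), y_batch = data.train_next_batch(batch_size=16)
assert x0_batch.shape[0] == x1_batch.shape[0] == y_batch.shape[0] == 16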

deepxde/data/sampler.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+class BatchSampler(object):
+    """Samples a mini-batch of indices.
+
+    The indices are repeated indefinitely. Has the same effect as:
+
+    .. code-block:: python
+
+        indices = tf.data.Dataset.range(num_samples)
+        indices = indices.repeat().shuffle(num_samples).batch(batch_size)
+        iterator = iter(indices)
+        batch_indices = iterator.get_next()
+
+    However, ``tf.data.Dataset.__iter__()`` is only supported inside of ``tf.function`` or when eager execution is
+    enabled. ``tf.data.Dataset.make_one_shot_iterator()`` supports graph mode, but is too slow.
+
+    This class is not implemented as a Python Iterator, so that it can support dynamic batch size.
+
+    Args:
+        num_samples (int): The number of samples.
+        shuffle (bool): Set to ``True`` to have the indices reshuffled at every epoch.
+    """
+
+    def __init__(self, num_samples, shuffle=True):
+        self.num_samples = num_samples
+        self.shuffle = shuffle
+
+        self._indices = np.arange(self.num_samples)
+        self._epochs_completed = 0
+        self._index_in_epoch = 0
+
+        # Shuffle for the first epoch
+        if shuffle:
+            np.random.shuffle(self._indices)
+
+    @property
+    def epochs_completed(self):
+        return self._epochs_completed
+
+    def get_next(self, batch_size):
+        """Returns the indices of the next batch.
+
+        Args:
+            batch_size (int): The number of elements to combine in a single batch.
+        """
+        if batch_size > self.num_samples:
+            raise ValueError(
+                "batch_size={} is larger than num_samples={}.".format(
+                    batch_size, self.num_samples
+                )
+            )
+
+        start = self._index_in_epoch
+        if start + batch_size <= self.num_samples:
+            self._index_in_epoch += batch_size
+            end = self._index_in_epoch
+            return self._indices[start:end]
+        else:
+            # Finished epoch
+            self._epochs_completed += 1
+            # Get the rest examples in this epoch
+            rest_num_samples = self.num_samples - start
+            indices_rest_part = np.copy(
+                self._indices[start : self.num_samples]
+            )  # self._indices will be shuffled below.
+            # Shuffle the indices
+            if self.shuffle:
+                np.random.shuffle(self._indices)
+            # Start next epoch
+            start = 0
+            self._index_in_epoch = batch_size - rest_num_samples
+            end = self._index_in_epoch
+            indices_new_part = self._indices[start:end]
+            return np.hstack((indices_rest_part, indices_new_part))
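A short, self-contained usage sketch of ``BatchSampler`` (illustration only, not part of the commit). It shows the epoch wrap-around: when a requested batch crosses the end of an epoch, the leftover indices are combined with freshly reshuffled ones and ``epochs_completed`` is incremented; the batch size may also change between calls:

from deepxde.data.sampler import BatchSampler

sampler = BatchSampler(10, shuffle=True)

# Three batches of 4 from 10 samples: the third call wraps around the epoch,
# mixing the 2 leftover indices with 2 indices from the reshuffled epoch.
for step in range(3):
    indices = sampler.get_next(4)
    print(step, indices, "epochs completed:", sampler.epochs_completed)

# Dynamic batch size: a different size on the next call is fine.
print(sampler.get_next(6))

# Requesting more indices than there are samples raises ValueError.
try:
    sampler.get_next(11)
except ValueError as err:
    print(err)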
