Skip to content

Commit 594ca0e

Browse files
kevin-j-miller authored and copybara-github committed
Remove todo to add pre-fetching for batched datasets.
I have profiled this and there no longer seems to be a meaningful slowdown with batched datasets.

PiperOrigin-RevId: 756290065
1 parent a4d63a2 commit 594ca0e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

disentangled_rnns/library/rnn_utils.py

Lines changed: 11 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -539,7 +539,15 @@ def train_step(
539539
training_loss = []
540540
validation_loss = []
541541
l_validation = np.nan
542-
xs_train, ys_train = next(training_dataset)
542+
543+
train_dataset_batched = (
544+
training_dataset.batch_size != training_dataset.n_episodes
545+
)
546+
if train_dataset_batched:
547+
xs_train, ys_train = next(training_dataset)
548+
else:
549+
xs_train, ys_train = training_dataset.get_all()
550+
543551
if validation_dataset is not None:
544552
xs_eval, ys_eval = validation_dataset.get_all()
545553
else:
@@ -550,13 +558,8 @@ def train_step(
550558
random_key, subkey_train, subkey_validation = jax.random.split(
551559
random_key, 3
552560
)
553-
# If the training dataset is batched, get a new batch of data
554-
# TODO(kevinjmiller): Implement prefetching for batched datasets as well
555-
if training_dataset.batch_size != training_dataset.n_episodes:
556-
warnings.warn(
557-
'Training dataset is batched, but prefetching is not implemented.'
558-
' This may slow down training.'
559-
)
561+
# If the training dataset is batched, get a new batch of data.
562+
if train_dataset_batched:
560563
xs_train, ys_train = next(training_dataset)
561564

562565
loss, params, opt_state = train_step(

0 commit comments

Comments
 (0)