@@ -539,7 +539,15 @@ def train_step(
539
539
training_loss = []
540
540
validation_loss = []
541
541
l_validation = np .nan
542
- xs_train , ys_train = next (training_dataset )
542
+
543
+ train_dataset_batched = (
544
+ training_dataset .batch_size != training_dataset .n_episodes
545
+ )
546
+ if train_dataset_batched :
547
+ xs_train , ys_train = next (training_dataset )
548
+ else :
549
+ xs_train , ys_train = training_dataset .get_all ()
550
+
543
551
if validation_dataset is not None :
544
552
xs_eval , ys_eval = validation_dataset .get_all ()
545
553
else :
@@ -550,13 +558,8 @@ def train_step(
550
558
random_key , subkey_train , subkey_validation = jax .random .split (
551
559
random_key , 3
552
560
)
553
- # If the training dataset is batched, get a new batch of data
554
- # TODO(kevinjmiller): Implement prefetching for batched datasets as well
555
- if training_dataset .batch_size != training_dataset .n_episodes :
556
- warnings .warn (
557
- 'Training dataset is batched, but prefetching is not implemented.'
558
- ' This may slow down training.'
559
- )
561
+ # If the training dataset is batched, get a new batch of data.
562
+ if train_dataset_batched :
560
563
xs_train , ys_train = next (training_dataset )
561
564
562
565
loss , params , opt_state = train_step (
0 commit comments