Commit 3c4f2cc

Re-enable one-step-ahead `device_put` of data from datasets on JAX. (#22353)
The `_prefetch_numpy_iterator` feature in the JAX trainer had been disabled on the grounds that many dataset implementations already support prefetching. However, that decision came from a misunderstanding of what `_prefetch_numpy_iterator` does, probably because of the misleading name. `_prefetch_numpy_iterator` triggers `device_put` for all the arrays of a batch of data to transfer them from the CPU to the accelerator (GPU or TPU). `device_put` is asynchronous, so by triggering it one step ahead, we parallelize the transfer of the next batch of data with the running of the model on the current batch. Without this, each batch is used immediately after `device_put` is called, which causes a wait until the transfer is complete. Dataset-level prefetching happens only on the CPU and is independent of this mechanism.

- Removed the `builtin_prefetch` property, which no longer serves any purpose.
- Renamed `_prefetch_numpy_iterator` to `_one_batch_ahead_iterator`, which more accurately describes its purpose.
1 parent: 549b476
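The commit message hinges on `jax.device_put` being asynchronous: it returns as soon as the host-to-device copy is enqueued. A minimal sketch to observe this, assuming a GPU or TPU backend (the array size and the measurement itself are illustrative, not part of the commit):

```python
import time

import jax
import numpy as np

# `jax.device_put` enqueues the host-to-device copy and returns immediately;
# `block_until_ready()` is what actually waits for the transfer to finish.
x = np.random.rand(4096, 4096).astype("float32")

t0 = time.perf_counter()
y = jax.device_put(x)  # returns almost at once on an accelerator
t1 = time.perf_counter()
y.block_until_ready()  # explicitly wait for the copy to complete
t2 = time.perf_counter()

print(f"device_put returned after {t1 - t0:.6f}s")
print(f"transfer completed after {t2 - t0:.6f}s")
```

On a CPU-only backend the two timings are close together, since there is little or no copy to overlap.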

8 files changed (+19, -81 lines changed)

keras/src/backend/jax/trainer.py (19 additions, 26 deletions)

@@ -1,4 +1,3 @@
-import collections
 import itertools
 import warnings
 from functools import partial
@@ -1088,10 +1087,8 @@ def _get_iterator(self):
         distribution = distribution_lib.distribution()
         if distribution is not None:
             return self._get_distributed_iterator(distribution)
-        if self.data_adapter.builtin_prefetch:
-            return self.data_adapter.get_jax_iterator()
         else:
-            return self._prefetch_numpy_iterator(
+            return self._one_batch_ahead_iterator(
                 self.data_adapter.get_jax_iterator()
             )
@@ -1108,27 +1105,23 @@ def _get_distributed_iterator(self, distribution):
             )
             yield _distribute_data(data, layouts)

-    def _prefetch_numpy_iterator(self, numpy_iterator):
-        """Shard and prefetch batches on device.
+    def _one_batch_ahead_iterator(self, numpy_iterator):
+        """Initiate transfers to the device one batch ahead.

-        Most of the implementation has been borrowed from
-        `flax.jax_utils.prefetch_to_device`
-
-        This utility takes an iterator and returns a new iterator which fills an
-        on device prefetch buffer. Eager prefetching can improve the performance
-        of training loops significantly by overlapping compute and data
-        transfer.
+        This utility takes an iterator and returns a new iterator which
+        initiates the transfer to device one step ahead. This can improve the
+        performance of training loops significantly by overlapping compute and
+        data transfer.
         """
-        queue = collections.deque()
-
-        # If you're training on GPUs, 2 is generally the best choice because
-        # this guarantees that you can overlap a training step on GPU with a
-        # data prefetch step on CPU.
-        def enqueue(n=2):
-            for data in itertools.islice(numpy_iterator, n):
-                queue.append(_distribute_data(data))
-
-        enqueue(n=2)  # TODO: should we make `n` configurable?
-        while queue:
-            yield queue.popleft()
-            enqueue(1)
+        next_batch = None
+        for batch in numpy_iterator:
+            batch = _distribute_data(batch)
+            if next_batch is None:
+                next_batch = batch
+            else:
+                current_batch = next_batch
+                next_batch = batch
+                yield current_batch
+
+        if next_batch is not None:
+            yield next_batch

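To see where the overlap comes from, the new generator can be sketched standalone; `train_step`, the batch generator, and the use of `jax.device_put` in place of `_distribute_data` are illustrative stand-ins, not the trainer's actual wiring:

```python
import jax
import jax.numpy as jnp
import numpy as np


def one_batch_ahead(numpy_iterator):
    # Mirror of the trainer's generator: the transfer for batch N+1 is
    # enqueued before batch N is yielded, so the copy runs while the
    # caller computes on batch N.
    next_batch = None
    for batch in numpy_iterator:
        batch = jax.device_put(batch)  # async: only enqueues the copy
        if next_batch is None:
            next_batch = batch
        else:
            current_batch, next_batch = next_batch, batch
            yield current_batch  # the next batch's transfer is in flight
    if next_batch is not None:
        yield next_batch


@jax.jit
def train_step(x):
    # Hypothetical stand-in for a real training step.
    return jnp.tanh(x @ x.T).mean()


batches = (np.random.rand(256, 256).astype("float32") for _ in range(8))
for x in one_batch_ahead(batches):
    loss = train_step(x)
print(float(loss))
```

Compared with the deleted deque-based version, which kept up to two batches buffered on device, this keeps exactly one transfer in flight, which is enough to hide the copy behind the current step.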
keras/src/trainers/data_adapters/data_adapter.py (0 additions, 15 deletions)

@@ -46,21 +46,6 @@ def get_torch_dataloader(self):
         """
         raise NotImplementedError

-    @property
-    def builtin_prefetch(self):
-        """Whether the DataAdapter has built-in prefetching capabilities.
-
-        Prefetching is an optimization technique where data is loaded and
-        prepared in advance while the model is processing the current batch,
-        reducing training time by overlapping data loading with computation.
-
-        Returns:
-            bool: True if the DataAdapter implements its own prefetching
-            mechanism and handles data loading asynchronously. False if the
-            caller should implement prefetching externally.
-        """
-        return False
-
     @property
     def num_batches(self):
         """Return the size (number of batches) for the dataset created.

keras/src/trainers/data_adapters/grain_dataset_adapter.py (0 additions, 4 deletions)

@@ -201,10 +201,6 @@ def __iter__(self):
             ConverterIterableDataset(self._dataset), batch_size=None
         )

-    @property
-    def builtin_prefetch(self):
-        return True
-
     @property
     def num_batches(self):
         return None

keras/src/trainers/data_adapters/grain_dataset_adapter_test.py (0 additions, 5 deletions)

@@ -185,11 +185,6 @@ def test_multiple_calling_on_iterators(self):
             bx, by = batch
             self.assertEqual(bx.dtype, by.dtype)

-    def test_builtin_prefetch(self):
-        dataset = grain.MapDataset.source(Range2DSource(0, 42))
-        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)
-        self.assertTrue(adapter.builtin_prefetch)
-
     def test_num_batches(self):
         dataset = grain.MapDataset.source(Range2DSource(0, 42))
         adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)

keras/src/trainers/data_adapters/tf_dataset_adapter.py (0 additions, 4 deletions)

@@ -62,10 +62,6 @@ def get_tf_dataset(self):
     def get_torch_dataloader(self):
         return data_adapter_utils.get_torch_dataloader(self._dataset)

-    @property
-    def builtin_prefetch(self):
-        return True
-
     @property
     def num_batches(self):
         cardinality = self._dataset.cardinality

keras/src/trainers/data_adapters/tf_dataset_adapter_test.py (0 additions, 5 deletions)

@@ -84,11 +84,6 @@ def test_class_weights_int_targets(self):
     def test_class_weights_categorical_targets(self):
         self._test_class_weights(target_encoding="categorical")

-    def test_builtin_prefetch(self):
-        dataset = tf.data.Dataset.range(42)
-        adapter = tf_dataset_adapter.TFDatasetAdapter(dataset)
-        self.assertTrue(adapter.builtin_prefetch)
-
     def test_num_batches(self):
         dataset = tf.data.Dataset.range(42)
         cardinality = int(dataset.cardinality())

keras/src/trainers/data_adapters/torch_data_loader_adapter.py (0 additions, 8 deletions)

@@ -65,14 +65,6 @@ def get_tf_dataset(self):
     def get_torch_dataloader(self):
         return self._dataloader

-    @property
-    def builtin_prefetch(self):
-        prefetch_factor = self._dataloader.prefetch_factor
-        if prefetch_factor is not None and prefetch_factor > 0:
-            return True
-        else:
-            return False
-
     @property
     def num_batches(self):
         return self._num_batches

keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py (0 additions, 14 deletions)

@@ -171,17 +171,3 @@ def test_with_different_shapes(self):
         else:
             self.assertEqual(bx.shape, (2, 6))
             self.assertEqual(by.shape, (2, 2))
-
-    @parameterized.named_parameters(named_product(num_workers=[0, 2]))
-    def test_builtin_prefetch(self, num_workers):
-        x = torch.normal(2, 3, size=(34, 4))
-        y = torch.normal(1, 3, size=(34, 2))
-        ds = torch.utils.data.TensorDataset(x, y)
-        dataloader = torch.utils.data.DataLoader(
-            ds, batch_size=16, num_workers=num_workers
-        )
-        adapter = TorchDataLoaderAdapter(dataloader)
-        if num_workers > 0:
-            self.assertTrue(adapter.builtin_prefetch)
-        else:
-            self.assertFalse(adapter.builtin_prefetch)
