Skip to content

Commit ece85e9

Browse files
Internal
PiperOrigin-RevId: 938107452
1 parent 1a709b1 commit ece85e9

4 files changed

Lines changed: 66 additions & 4 deletions

File tree

grain/_src/python/dataset/transformations/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ py_test(
268268
srcs = ["interleave_test.py"],
269269
srcs_version = "PY3",
270270
deps = [
271+
":zip",
271272
"//grain/_src/python:options",
272273
"//grain/_src/python/dataset",
273274
"//grain/_src/python/dataset:base",

grain/_src/python/dataset/transformations/interleave_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import threading
1516
from typing import cast
17+
1618
from absl.testing import absltest
1719
from absl.testing import flagsaver
1820
from absl.testing import parameterized
@@ -22,6 +24,8 @@
2224
from grain._src.python.dataset import dataset
2325
from grain._src.python.dataset.transformations import interleave
2426
from grain._src.python.dataset.transformations import prefetch
27+
from grain._src.python.dataset.transformations import repeat
28+
from grain._src.python.dataset.transformations import zip as zip_dataset
2529
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
2630
import numpy as np
2731

@@ -515,6 +519,53 @@ def test_setting_shard_state_with_exhausted_states(self):
515519
if isinstance(self, InterleaveIterDatasetTest):
516520
self.assertEqual(state["exhausted"], [0, 1])
517521

522+
def test_options_propagated_with_interleaved_interleaves(self):
523+
ds = (
524+
dataset.MapDataset.range(0, 1500)
525+
.to_iter_dataset()
526+
.filter(lambda x: False)
527+
)
528+
interleave_ds = self._create_dataset([ds], cycle_length=1)
529+
interleave_ds_2 = self._create_dataset([interleave_ds], cycle_length=1)
530+
531+
filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
532+
ds_with_options = dataset.WithOptionsIterDataset(
533+
interleave_ds_2, filter_options
534+
)
535+
with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
536+
list(ds_with_options)
537+
538+
def test_options_propagated_with_zipped_interleaves(self):
539+
no_filter_ds = dataset.MapDataset.range(
540+
1200, 1500
541+
).to_iter_dataset() # 300 elements
542+
543+
filter_ds = (
544+
dataset.MapDataset.range(0, 1500)
545+
.to_iter_dataset()
546+
.filter(lambda x: x >= 1200)
547+
)
548+
interleave_ds1 = self._create_dataset([filter_ds], cycle_length=1)
549+
interleave_ds2 = self._create_dataset([no_filter_ds], cycle_length=1)
550+
zipped_ds = zip_dataset.ZipIterDataset([interleave_ds1, interleave_ds2])
551+
zipped_ds2 = zip_dataset.ZipIterDataset([interleave_ds2, interleave_ds1])
552+
553+
filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
554+
ds_with_options1 = dataset.WithOptionsIterDataset(zipped_ds, filter_options)
555+
ds_with_options2 = dataset.WithOptionsIterDataset(
556+
zipped_ds2, filter_options
557+
)
558+
559+
with self.assertRaisesRegex(
560+
ValueError, r"FilterDatasetIterator.*skipped 100\.00 %"
561+
):
562+
list(ds_with_options1)
563+
564+
with self.assertRaisesRegex(
565+
ValueError, r"FilterDatasetIterator.*skipped 100\.00 %"
566+
):
567+
list(ds_with_options2)
568+
518569

519570
if __name__ == "__main__":
520571
absltest.main()

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,7 @@ def __init__(
555555
assert target_prefetch_buffer_size >= 0, target_prefetch_buffer_size
556556
self._target_prefetch_buffer_size = target_prefetch_buffer_size
557557
self.autotune_buffer_size = autotune_buffer_size
558-
self._step_zero_state: StateT = parent.get_state()
559-
self._state: StateT | None = self._step_zero_state
558+
self._state: StateT | None = None
560559
self._next_index: int | None = 0
561560

562561
self._prefetch_thread: threading.Thread | None = None
@@ -628,6 +627,9 @@ def start_prefetch(self):
628627
)
629628
def __next__(self):
630629

630+
if self._state is None:
631+
self._state = self._maybe_nonnative_parent.get_state()
632+
631633
timer = dataset_stats.Timer()
632634
with timer:
633635
if self._target_prefetch_buffer_size > 0:

grain/_src/python/dataset/transformations/repeat.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def __init__(
9595
super().__init__(parent)
9696
self._num_epochs = num_epochs
9797
self._epoch = 0
98-
self._parent_starting_state = self._parent.get_state()
98+
self._parent_starting_state = None
99+
100+
def _ensure_keep_alive_flags(self):
99101
# Check for ProcessPrefetchDatasetIterator and InterleaveDatasetIterator and
100102
# ensure processes/iterators are not reset on StopIteration. This is needed
101103
# to avoid recreating the worker processes on each epoch.
@@ -107,13 +109,16 @@ def __init__(
107109
if isinstance(node, interleave.InterleaveDatasetIterator):
108110
node.set_keep_iterators_after_stop_iteration(True)
109111
to_visit.extend(n for n in node._iterators_in_use if n is not None) # pylint: disable=protected-access
110-
to_visit.extend(n for n in node._parents)
112+
to_visit.extend(n for n in node._parents) # pylint: disable=protected-access
111113

112114
@stats.record_next_duration_if_output
113115
def __next__(self):
114116
timer = stats.Timer()
115117
if self._epoch == self._num_epochs:
116118
raise StopIteration
119+
if self._parent_starting_state is None:
120+
self._parent_starting_state = self._parent.get_state()
121+
self._ensure_keep_alive_flags()
117122
while True:
118123
try:
119124
elem = next(self._parent)
@@ -131,6 +136,9 @@ def get_state(self):
131136
return {"parent": self._parent.get_state(), "epoch": self._epoch}
132137

133138
def set_state(self, state):
139+
if self._parent_starting_state is None:
140+
self._parent_starting_state = self._parent.get_state()
141+
self._ensure_keep_alive_flags()
134142
self._epoch = state["epoch"]
135143
self._parent.set_state(state["parent"])
136144

0 commit comments

Comments
 (0)