|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import threading |
15 | 16 | from typing import cast |
| 17 | + |
16 | 18 | from absl.testing import absltest |
17 | 19 | from absl.testing import flagsaver |
18 | 20 | from absl.testing import parameterized |
|
22 | 24 | from grain._src.python.dataset import dataset |
23 | 25 | from grain._src.python.dataset.transformations import interleave |
24 | 26 | from grain._src.python.dataset.transformations import prefetch |
| 27 | +from grain._src.python.dataset.transformations import repeat |
| 28 | +from grain._src.python.dataset.transformations import zip as zip_dataset |
25 | 29 | from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint |
26 | 30 | import numpy as np |
27 | 31 |
|
@@ -515,6 +519,53 @@ def test_setting_shard_state_with_exhausted_states(self): |
515 | 519 | if isinstance(self, InterleaveIterDatasetTest): |
516 | 520 | self.assertEqual(state["exhausted"], [0, 1]) |
517 | 521 |
|
| 522 | + def test_options_propagated_with_interleaved_interleaves(self): |
| 523 | + ds = ( |
| 524 | + dataset.MapDataset.range(0, 1500) |
| 525 | + .to_iter_dataset() |
| 526 | + .filter(lambda x: False) |
| 527 | + ) |
| 528 | + interleave_ds = self._create_dataset([ds], cycle_length=1) |
| 529 | + interleave_ds_2 = self._create_dataset([interleave_ds], cycle_length=1) |
| 530 | + |
| 531 | + filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) |
| 532 | + ds_with_options = dataset.WithOptionsIterDataset( |
| 533 | + interleave_ds_2, filter_options |
| 534 | + ) |
| 535 | + with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"): |
| 536 | + list(ds_with_options) |
| 537 | + |
| 538 | + def test_options_propagated_with_zipped_interleaves(self): |
| 539 | + no_filter_ds = dataset.MapDataset.range( |
| 540 | + 1200, 1500 |
| 541 | + ).to_iter_dataset() # 300 elements |
| 542 | + |
| 543 | + filter_ds = ( |
| 544 | + dataset.MapDataset.range(0, 1500) |
| 545 | + .to_iter_dataset() |
| 546 | + .filter(lambda x: x >= 1200) |
| 547 | + ) |
| 548 | + interleave_ds1 = self._create_dataset([filter_ds], cycle_length=1) |
| 549 | + interleave_ds2 = self._create_dataset([no_filter_ds], cycle_length=1) |
| 550 | + zipped_ds = zip_dataset.ZipIterDataset([interleave_ds1, interleave_ds2]) |
| 551 | + zipped_ds2 = zip_dataset.ZipIterDataset([interleave_ds2, interleave_ds1]) |
| 552 | + |
| 553 | + filter_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) |
| 554 | + ds_with_options1 = dataset.WithOptionsIterDataset(zipped_ds, filter_options) |
| 555 | + ds_with_options2 = dataset.WithOptionsIterDataset( |
| 556 | + zipped_ds2, filter_options |
| 557 | + ) |
| 558 | + |
| 559 | + with self.assertRaisesRegex( |
| 560 | + ValueError, r"FilterDatasetIterator.*skipped 100\.00 %" |
| 561 | + ): |
| 562 | + list(ds_with_options1) |
| 563 | + |
| 564 | + with self.assertRaisesRegex( |
| 565 | + ValueError, r"FilterDatasetIterator.*skipped 100\.00 %" |
| 566 | + ): |
| 567 | + list(ds_with_options2) |
| 568 | + |
518 | 569 |
|
519 | 570 | if __name__ == "__main__": |
520 | 571 | absltest.main() |
0 commit comments