From 4a575b5c8bdd7fd91c41fda7e4f9d455c849054c Mon Sep 17 00:00:00 2001 From: Francesco Rubbo Date: Tue, 15 Apr 2025 19:33:29 -0700 Subject: [PATCH 1/6] Preserve formatting in concatenated iterable dataset when the inputs have consistent formatting --- src/datasets/iterable_dataset.py | 12 +++++++++++- tests/test_iterable_dataset.py | 30 +++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 423abeeb56d..da3631c81f5 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3408,6 +3408,10 @@ def _concatenate_iterable_datasets( else: _check_column_names([col_name for dset in dsets for col_name in dset.features]) + # Check format is consistent; if so, will set format for concatenated dataset + format_type_set = {dset._formatting.format_type for dset in dsets} + format_type = format_type_set.pop() if len(format_type_set) == 1 else None + # TODO: improve this to account for a mix of ClassLabel and Value for example # right now it would keep the type of the first dataset in the list features = Features( @@ -3429,7 +3433,13 @@ def _concatenate_iterable_datasets( # Get all the auth tokens per repository - in case the datasets come from different private repositories token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()} # Return new daset - return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + return IterableDataset( + ex_iterable=ex_iterable, + info=info, + split=split, + token_per_repo_id=token_per_repo_id, + formatting=FormattingConfig(format_type=format_type), + ) def _interleave_iterable_datasets( diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index fd316841667..3f3b6b19c02 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1204,9 +1204,9 @@ def test_skip_examples_iterable(): skip_ex_iterable = SkipExamplesIterable(base_ex_iterable, n=count) expected = list(generate_examples_fn(n=total))[count:] assert list(skip_ex_iterable) == expected - assert skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable, ( - "skip examples makes the shards order fixed" - ) + assert ( + skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable + ), "skip examples makes the shards order fixed" assert_load_state_dict_resumes_iteration(skip_ex_iterable) @@ -1216,9 +1216,9 @@ def test_take_examples_iterable(): take_ex_iterable = TakeExamplesIterable(base_ex_iterable, n=count) expected = list(generate_examples_fn(n=total))[:count] assert list(take_ex_iterable) == expected - assert take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable, ( - "skip examples makes the shards order fixed" - ) + assert ( + take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable + ), "skip examples makes the shards order fixed" assert_load_state_dict_resumes_iteration(take_ex_iterable) @@ -1288,9 +1288,9 @@ def test_horizontally_concatenated_examples_iterable(): concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) expected = [{**x, **y} for (_, x), (_, y) in zip(ex_iterable1, ex_iterable2)] assert [x for _, x in concatenated_ex_iterable] == expected - assert concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable, ( - "horizontally concatenated examples makes the shards order fixed" - ) + assert ( + concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable + ), "horizontally concatenated examples makes the shards order fixed" assert_load_state_dict_resumes_iteration(concatenated_ex_iterable) @@ -2126,6 +2126,18 @@ def test_concatenate_datasets_axis_1_with_different_lengths(): assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(extended_dataset2_list, dataset1)] +@require_torch +@require_tf +@require_jax +@pytest.mark.parametrize( + "format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax", "arrow", "pd", "pandas"] +) +def test_concatenate_datasets_with_format(dataset: IterableDataset, format_type): + formatted_dataset = dataset.with_format(format_type) + concatenated_dataset = concatenate_datasets([formatted_dataset]) + assert concatenated_dataset._formatting.format_type == get_format_type_from_alias(format_type) + + @pytest.mark.parametrize( "probas, seed, expected_length, stopping_strategy", [ From 8550975ff55f23aac4bc9d3ee7bf9d4b8bf70d21 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:43:10 +0200 Subject: [PATCH 2/6] style --- tests/test_iterable_dataset.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 3f3b6b19c02..7f1e57c5728 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1204,9 +1204,9 @@ def test_skip_examples_iterable(): skip_ex_iterable = SkipExamplesIterable(base_ex_iterable, n=count) expected = list(generate_examples_fn(n=total))[count:] assert list(skip_ex_iterable) == expected - assert ( - skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable - ), "skip examples makes the shards order fixed" + assert skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable, ( + "skip examples makes the shards order fixed" + ) assert_load_state_dict_resumes_iteration(skip_ex_iterable) @@ -1216,9 +1216,9 @@ def test_take_examples_iterable(): take_ex_iterable = TakeExamplesIterable(base_ex_iterable, n=count) expected = list(generate_examples_fn(n=total))[:count] assert list(take_ex_iterable) == expected - assert ( - take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable - ), "skip examples makes the shards order fixed" + assert take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable, ( + "skip examples makes the shards order fixed" + ) assert_load_state_dict_resumes_iteration(take_ex_iterable) @@ -1288,9 +1288,9 @@ def test_horizontally_concatenated_examples_iterable(): concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) expected = [{**x, **y} for (_, x), (_, y) in zip(ex_iterable1, ex_iterable2)] assert [x for _, x in concatenated_ex_iterable] == expected - assert ( - concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable - ), "horizontally concatenated examples makes the shards order fixed" + assert concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable, ( + "horizontally concatenated examples makes the shards order fixed" + ) assert_load_state_dict_resumes_iteration(concatenated_ex_iterable) From 721833f470d6b16a5d48fde8cf13043f7a4b6e64 Mon Sep 17 00:00:00 2001 From: Francesco Rubbo Date: Mon, 28 Apr 2025 07:21:31 -0700 Subject: [PATCH 3/6] If `dset._formatting` is None for any of the datasets, set the concatenated dataset format to None. Add log line for inputs with inconsistent format. --- src/datasets/iterable_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index da3631c81f5..7a3aee46e8b 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3409,7 +3409,9 @@ def _concatenate_iterable_datasets( _check_column_names([col_name for dset in dsets for col_name in dset.features]) # Check format is consistent; if so, will set format for concatenated dataset - format_type_set = {dset._formatting.format_type for dset in dsets} + format_type_set = {dset._formatting.format_type for dset in dsets if dset._formatting} + if len(format_type_set > 1): + logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.") format_type = format_type_set.pop() if len(format_type_set) == 1 else None # TODO: improve this to account for a mix of ClassLabel and Value for example From 1e55f37d9ecd1267246fc8719594d20213b29388 Mon Sep 17 00:00:00 2001 From: Francesco Rubbo Date: Mon, 28 Apr 2025 08:13:00 -0700 Subject: [PATCH 4/6] fix incorrect grouping --- src/datasets/iterable_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 7a3aee46e8b..dac146ad95b 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3410,7 +3410,7 @@ def _concatenate_iterable_datasets( # Check format is consistent; if so, will set format for concatenated dataset format_type_set = {dset._formatting.format_type for dset in dsets if dset._formatting} - if len(format_type_set > 1): + if len(format_type_set) > 1: logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.") format_type = format_type_set.pop() if len(format_type_set) == 1 else None From 70d86732ff2ddce32e936a46546a06e7fe890cb8 Mon Sep 17 00:00:00 2001 From: Francesco Rubbo Date: Mon, 28 Apr 2025 08:26:10 -0700 Subject: [PATCH 5/6] Reset output formatting if any of the inputs has formatting not set --- src/datasets/iterable_dataset.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index dac146ad95b..c1ed7d65dde 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3409,10 +3409,16 @@ def _concatenate_iterable_datasets( _check_column_names([col_name for dset in dsets for col_name in dset.features]) # Check format is consistent; if so, will set format for concatenated dataset - format_type_set = {dset._formatting.format_type for dset in dsets if dset._formatting} - if len(format_type_set) > 1: - logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.") - format_type = format_type_set.pop() if len(format_type_set) == 1 else None + if any(dset._formatting is None for dset in dsets): + formatting = None + else: + format_type_set = {dset._formatting.format_type for dset in dsets} + format_type = format_type_set.pop() if len(format_type_set) == 1 else None + formatting = FormattingConfig(format_type=format_type) + if formatting is None: + logger.info( + "Some of the datasets have disparate format or format not set. Resetting the format of the concatenated dataset." + ) # TODO: improve this to account for a mix of ClassLabel and Value for example # right now it would keep the type of the first dataset in the list @@ -3440,7 +3446,7 @@ def _concatenate_iterable_datasets( info=info, split=split, token_per_repo_id=token_per_repo_id, - formatting=FormattingConfig(format_type=format_type), + formatting=formatting, ) From a724b12809c264da83781792e7ad876241c4c8d5 Mon Sep 17 00:00:00 2001 From: Francesco Rubbo Date: Mon, 28 Apr 2025 08:30:05 -0700 Subject: [PATCH 6/6] log unset format also in case `formatting` is set, but `format_type` is None --- src/datasets/iterable_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index c1ed7d65dde..b072fa69c8f 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -3415,7 +3415,7 @@ def _concatenate_iterable_datasets( format_type_set = {dset._formatting.format_type for dset in dsets} format_type = format_type_set.pop() if len(format_type_set) == 1 else None formatting = FormattingConfig(format_type=format_type) - if formatting is None: + if formatting is None or formatting.format_type is None: logger.info( "Some of the datasets have disparate format or format not set. Resetting the format of the concatenated dataset." )