Preserve formatting in concatenated IterableDataset #7522

Open
wants to merge 7 commits into
base: main
12 changes: 11 additions & 1 deletion src/datasets/iterable_dataset.py
@@ -3408,6 +3408,10 @@ def _concatenate_iterable_datasets(
else:
_check_column_names([col_name for dset in dsets for col_name in dset.features])

# Check format is consistent; if so, will set format for concatenated dataset
format_type_set = {dset._formatting.format_type for dset in dsets}

Member

dset._formatting doesn't always exist

So if dset._formatting is None for all the datasets, the formatting of the output dataset should be None as well

Author

I've added a check in the set comprehension so that it skips elements where dset._formatting is None. This means that we will try to set the output format to that of the inputs that do have dset._formatting. Is that reasonable, or should we set the output format to None in this case?

Member

any discrepancy should result in formatting=None, if we want to be consistent with the Dataset API

Author

Sounds good! I've reworked the logic to account for that.
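
For readers following the thread, the rule agreed on above (any discrepancy, including a missing `_formatting`, yields no formatting) can be sketched roughly as follows. This is not the code from the PR: `FormattingConfig` and `FakeIterableDataset` here are hypothetical stand-ins that model only the `_formatting` attribute, and `resolve_formatting` is an illustrative helper name.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FormattingConfig:
    # Hypothetical stand-in for the library's FormattingConfig.
    format_type: Optional[str]


@dataclass
class FakeIterableDataset:
    # Hypothetical stand-in exposing only the attribute under discussion.
    _formatting: Optional[FormattingConfig] = None


def resolve_formatting(dsets: Sequence[FakeIterableDataset]) -> Optional[FormattingConfig]:
    """Return the formatting shared by all datasets, or None on any discrepancy."""
    # A dataset without _formatting contributes None instead of being skipped,
    # so mixing formatted and unformatted inputs counts as a discrepancy.
    format_types = {
        dset._formatting.format_type if dset._formatting is not None else None
        for dset in dsets
    }
    if len(format_types) != 1:
        return None
    (format_type,) = format_types
    # Assumption: an all-None input should yield formatting=None,
    # not FormattingConfig(format_type=None).
    return FormattingConfig(format_type) if format_type is not None else None
```

With this shape, two "torch" datasets keep the "torch" format, while a "torch" dataset concatenated with an unformatted one ends up with no formatting.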

format_type = format_type_set.pop() if len(format_type_set) == 1 else None

Member

if the datasets have disparate formats, maybe you can add this logging info (same as for concatenating Dataset objects):

logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.")
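
As a rough illustration of where that log call could sit, here is a minimal sketch using only the standard `logging` module. The helper name `pick_format_type` is hypothetical and the function works on plain format-type strings rather than on real datasets; only the log message is taken from the suggestion above.

```python
import logging
from typing import Iterable, Optional

logger = logging.getLogger(__name__)


def pick_format_type(format_types: Iterable[Optional[str]]) -> Optional[str]:
    """Return the single shared format type, or None (with a log line) if they differ."""
    unique = set(format_types)
    if len(unique) == 1:
        return unique.pop()
    if len(unique) > 1:
        # Same message as the one suggested for concatenating Dataset objects.
        logger.info(
            "Some of the datasets have disparate format. "
            "Resetting the format of the concatenated dataset."
        )
    return None
```

The message is logged only when the inputs actually disagree; an empty input returns None silently.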


# TODO: improve this to account for a mix of ClassLabel and Value for example
# right now it would keep the type of the first dataset in the list
features = Features(
@@ -3429,7 +3433,13 @@ def _concatenate_iterable_datasets(
# Get all the auth tokens per repository - in case the datasets come from different private repositories
token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()}
# Return new daset
-return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
+return IterableDataset(
+    ex_iterable=ex_iterable,
+    info=info,
+    split=split,
+    token_per_repo_id=token_per_repo_id,
+    formatting=FormattingConfig(format_type=format_type),
+)


def _interleave_iterable_datasets(
12 changes: 12 additions & 0 deletions tests/test_iterable_dataset.py
@@ -2126,6 +2126,18 @@ def test_concatenate_datasets_axis_1_with_different_lengths():
assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(extended_dataset2_list, dataset1)]


@require_torch
@require_tf
@require_jax
@pytest.mark.parametrize(
"format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax", "arrow", "pd", "pandas"]
)
def test_concatenate_datasets_with_format(dataset: IterableDataset, format_type):
formatted_dataset = dataset.with_format(format_type)
concatenated_dataset = concatenate_datasets([formatted_dataset])
assert concatenated_dataset._formatting.format_type == get_format_type_from_alias(format_type)


@pytest.mark.parametrize(
"probas, seed, expected_length, stopping_strategy",
[