Preserve formatting in concatenated IterableDataset #7522

Open
wants to merge 7 commits into
base: main

francescorubbo

Fixes #7515

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

lhoestq left a comment

Cool! I took the liberty of updating the branch with the changes from main, and added some comments:

@@ -3408,6 +3408,10 @@ def _concatenate_iterable_datasets(
    else:
        _check_column_names([col_name for dset in dsets for col_name in dset.features])

    # Check format is consistent; if so, will set format for concatenated dataset
    format_type_set = {dset._formatting.format_type for dset in dsets}
Member

dset._formatting doesn't always exist

So if dset._formatting is None for all the datasets, the formatting of the output dataset should be None as well

Author

I've added a check in the set comprehension so that it skips elements where dset._formatting is None. This means that we will try to set the output format to that of the inputs that do have dset._formatting. Is that reasonable, or should we set the output format to None in this case?

Member

Any discrepancy should result in formatting=None if we want to be consistent with the Dataset API.

Author

Sounds good! I've reworked the logic to account for that.
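
For readers following the thread, the rule agreed on above (a missing dset._formatting counts as format None, and any discrepancy yields None) could be sketched roughly as below. This is not the code merged in the PR: FakeFormatting, FakeDataset and resolve_format_type are hypothetical stand-ins that model only the _formatting attribute and its format_type field.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FakeFormatting:
    """Hypothetical stand-in for the object stored in `dset._formatting`."""

    format_type: Optional[str] = None


@dataclass
class FakeDataset:
    """Hypothetical stand-in for an IterableDataset; only `_formatting` is modeled."""

    _formatting: Optional[FakeFormatting] = None


def resolve_format_type(dsets: Sequence[FakeDataset]) -> Optional[str]:
    """Return the format type shared by all datasets, or None on any discrepancy."""
    # A missing `_formatting` is mapped to None rather than skipped, so a mix of
    # formatted and unformatted datasets shows up as two distinct values.
    format_types = {
        dset._formatting.format_type if dset._formatting is not None else None
        for dset in dsets
    }
    # Exactly one distinct value means every input agrees (possibly on None).
    return format_types.pop() if len(format_types) == 1 else None
```

Mapping a missing _formatting to None instead of filtering it out is what separates this from the earlier "skip" approach: a torch-formatted dataset concatenated with an unformatted one resolves to None here.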

@@ -3408,6 +3408,10 @@ def _concatenate_iterable_datasets(
    else:
        _check_column_names([col_name for dset in dsets for col_name in dset.features])

    # Check format is consistent; if so, will set format for concatenated dataset
    format_type_set = {dset._formatting.format_type for dset in dsets}
    format_type = format_type_set.pop() if len(format_type_set) == 1 else None
Member

If the datasets have disparate formats, maybe you can add this logging message (the same one used when concatenating Dataset objects):

logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.")
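
One way the suggested message could be wired into the format check is sketched below. It assumes a standard-library logger and a hypothetical helper, pick_format_type, that receives the already extracted format types; the library's own logger setup and the final PR code may differ.

```python
import logging
from typing import Optional, Sequence

# Assumption: a plain stdlib logger stands in for the library's own logger.
logger = logging.getLogger(__name__)


def pick_format_type(format_types: Sequence[Optional[str]]) -> Optional[str]:
    """Return the common format type, logging when the inputs disagree."""
    unique = set(format_types)
    if len(unique) == 1:
        return unique.pop()
    if len(unique) > 1:
        # Message text taken from the review comment above.
        logger.info(
            "Some of the datasets have disparate format. "
            "Resetting the format of the concatenated dataset."
        )
    return None
```

Logging at info level keeps the reset visible to anyone who raises verbosity without turning a legitimate mixed-format concatenation into a warning.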

Successfully merging this pull request may close these issues.

concatenate_datasets does not preserve Pytorch format for IterableDataset
3 participants