External dataset providers can't access consumed_samples #1208

@pilot7747

Describe the bug

No information from train_state is passed to the dataset provider when dataloader_type = "external".

Steps/Code to reproduce bug

Try to create a dataset provider with dataloader_type = "external" whose state can be restored (even if that state is just the number of consumed samples).
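To make the reproduction concrete, here is a minimal sketch of what "state that can be restored" means for an external loader. The class name ResumableBatches and its parameters are hypothetical and not part of the library; the point is only that the loader needs consumed_samples from somewhere in order to resume.

```python
from typing import Iterator, List


class ResumableBatches:
    """Hypothetical external loader that yields batches of sample indices.

    Resuming is only possible if consumed_samples is handed in from outside.
    """

    def __init__(self, total_samples: int, batch_size: int, consumed_samples: int = 0):
        self.total_samples = total_samples
        self.batch_size = batch_size
        self.consumed_samples = consumed_samples

    def __iter__(self) -> Iterator[List[int]]:
        # Start from the first sample that has not been consumed yet and
        # drop a trailing partial batch.
        start = self.consumed_samples
        while start + self.batch_size <= self.total_samples:
            yield list(range(start, start + self.batch_size))
            start += self.batch_size


fresh = list(ResumableBatches(total_samples=8, batch_size=2))
resumed = list(ResumableBatches(total_samples=8, batch_size=2, consumed_samples=4))
```

With dataloader_type = "external", nothing supplies the consumed_samples argument, so a loader like this always restarts from zero after a checkpoint restore.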

Expected behavior

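The dataset provider should be able to access train_state (or at least consumed_samples) so that an external dataloader can resume from the correct position.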

Additional context

In samplers.py:

if dataloader_type == "single":
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=micro_batch_size,
        data_parallel_rank=data_parallel_rank,
        data_parallel_size=data_parallel_size,
        drop_last=drop_last,
    )
elif dataloader_type == "cyclic":
    batch_sampler = MegatronPretrainingRandomSampler(
        dataset,
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=micro_batch_size,
        data_parallel_rank=data_parallel_rank,
        data_parallel_size=data_parallel_size,
        data_sharding=data_sharding,
    )
elif dataloader_type == "external":
    # External dataloaders are passed through. User is expected to provide a
    # torch-compatible dataloader and define samplers, if needed.
    return dataset

For "single" and "cyclic", consumed_samples is passed explicitly into the sampler. For "external", the dataset is returned as-is, so it is expected to know the resume position itself, which is unrealistic because that value is never handed to the provider.

A possible solution is to pass train_state into build_train_valid_test_datasets, which would then pass it down into build_train_valid_test_datasets_provider.
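A rough sketch of that plumbing, under stated assumptions: the names TrainStateSketch, build_datasets_sketch and external_provider_sketch are hypothetical stand-ins, and the real signatures in the library may differ. It only illustrates forwarding the state object from the builder to the provider.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, Optional


@dataclass
class TrainStateSketch:
    """Hypothetical stand-in for train_state; only consumed_samples is modelled."""
    consumed_samples: int = 0


def external_provider_sketch(
    total_samples: int, train_state: Optional[TrainStateSketch] = None
) -> Iterator[int]:
    """Hypothetical provider: yields sample indices, skipping consumed ones."""
    start = train_state.consumed_samples if train_state is not None else 0
    return iter(range(start, total_samples))


def build_datasets_sketch(
    provider: Callable[..., Iterator[int]],
    total_samples: int,
    train_state: Optional[TrainStateSketch] = None,
) -> Iterator[int]:
    """Hypothetical builder that forwards train_state to the provider unchanged."""
    return provider(total_samples, train_state=train_state)


state = TrainStateSketch(consumed_samples=4)
remaining = list(build_datasets_sketch(external_provider_sketch, 6, state))
```

Making train_state an optional keyword argument would let existing providers that do not need it keep working, at the cost of providers having to handle the None case.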
