Describe the bug
No information about the train_state is passed into the dataset provider when dataloader_type = "external" is used.
Steps/Code to reproduce bug
Try to create a dataset provider with dataloader_type = "external" whose state can be restored (even if that state is just the number of consumed samples).
Expected behavior
The dataset provider needs access to the training state (at minimum, consumed_samples) so that an external dataloader can be restored on resume.
Additional context
In samplers.py:

    if dataloader_type == "single":
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
        )
    elif dataloader_type == "cyclic":
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            data_sharding=data_sharding,
        )
    elif dataloader_type == "external":
        # External dataloaders are passed through. User is expected to provide a
        # torch-compatible dataloader and define samplers, if needed.
        return dataset

For "single" and "cyclic", consumed_samples is explicitly passed into the sampler, but for "external" it is expected to be provided by the dataset itself, which is unrealistic.
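To illustrate what the "single"/"cyclic" paths get for free, here is a minimal sketch (not Megatron's actual implementation; the class name and simplified signature are hypothetical) of how a sampler resumes from consumed_samples by starting iteration past the already-consumed indices:

```python
class ResumablePretrainingSampler:
    """Hypothetical, simplified stand-in for MegatronPretrainingSampler."""

    def __init__(self, total_samples, consumed_samples, micro_batch_size):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size

    def __iter__(self):
        batch = []
        # Resume: skip indices that were consumed before the checkpoint.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                yield batch
                batch = []
```

With total_samples=8, consumed_samples=4, micro_batch_size=2, iteration yields [4, 5] then [6, 7]. An "external" dataloader gets no equivalent hook, because consumed_samples never reaches the provider.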
A possible solution is to pass the train_state into build_train_valid_test_datasets, which would then pass it down into build_train_valid_test_datasets_provider.
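The proposed plumbing could look roughly like the following sketch. It is illustrative only: the TrainState dataclass, the provider signature, and the my_external_provider example are hypothetical, and only the build_train_valid_test_datasets name comes from the issue.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class TrainState:
    # Minimal restorable state for an external dataloader (hypothetical).
    consumed_samples: int = 0


def build_train_valid_test_datasets(
    provider: Callable[[Optional[TrainState]], object],
    train_state: Optional[TrainState] = None,
):
    # Forward the training state to the user-supplied provider so an
    # "external" dataloader can resume from consumed_samples.
    return provider(train_state)


def my_external_provider(train_state: Optional[TrainState]):
    # The provider can now skip already-consumed samples when it
    # rebuilds its torch-compatible dataloader after a restart.
    start = train_state.consumed_samples if train_state else 0
    return list(range(start, start + 4))
```

For example, build_train_valid_test_datasets(my_external_provider, TrainState(consumed_samples=10)) would let the external provider resume at sample 10 instead of restarting from zero.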