Skip to content

Commit 64e22da

Browse files
committed
First untested version
1 parent 9d248e8 commit 64e22da

3 files changed

Lines changed: 127 additions & 2 deletions

File tree

megatron/core/datasets/blended_megatron_dataset_builder.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import logging
44
import math
55
from concurrent.futures import ThreadPoolExecutor
6+
import os
67
from typing import Any, Callable, Iterable, List, Optional, Type, Union
78

9+
from megatron.core.datasets.gpt_dataset import GPTDatasetFolder
810
import numpy
911
import torch
1012

@@ -450,6 +452,29 @@ def _build_megatron_dataset_splits(
450452
torch.distributed.barrier()
451453
return [None] * len(Split)
452454

455+
# TODO(MaxiBoether): it's a bit suboptimal that we need to handle this explicitly currently
456+
# however, I don't see a straightforward way to integrate the codepath fully.
457+
if self.cls == GPTDatasetFolder and os.path.isdir(dataset_path):
458+
mid_level_datasets = []
459+
for i, _split in enumerate(Split):
460+
if split[i] is None:
461+
mid_level_datasets.append(None)
462+
else:
463+
mid_level_datasets.append(
464+
self.build_generic_dataset(
465+
self.cls,
466+
self.is_built_on_rank,
467+
synchronize_ranks,
468+
None, # indexed_dataset (unused)
469+
dataset_path, # folder_path
470+
None, # indexed_indices (unused)
471+
sizes[i],
472+
_split,
473+
self.config,
474+
)
475+
)
476+
return mid_level_datasets
477+
453478
# Build the low level dataset
454479
low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
455480

megatron/core/datasets/gpt_dataset.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
22

3+
import copy
4+
import glob
35
import logging
46
import os
57
import time
68
from dataclasses import dataclass
79
from typing import Dict, Optional, Tuple
810

11+
from megatron.core.datasets.blended_dataset import BlendedDataset
12+
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
913
import numpy
1014
import torch
1115

@@ -16,6 +20,7 @@
1620
from megatron.core.datasets.utils import Split
1721
from megatron.core.datasets.utils_s3 import S3Config, is_s3_path
1822
from megatron.core.utils import log_single_rank
23+
from megatron.core import mpu
1924

2025
logger = logging.getLogger(__name__)
2126

@@ -587,6 +592,94 @@ def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
587592
num_tokens += num_tokens_per_epoch
588593
return num_epochs
589594

595+
def is_dataset_built_on_rank():
    """Tell whether the current rank should build the dataset.

    Returns:
        A truthy value when this rank is in the first or the last pipeline stage
        and its tensor-model-parallel rank is 0; a falsy value otherwise.
    """
    # Evaluated first so the tensor-parallel rank is only queried on edge stages.
    on_edge_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    return on_edge_stage and mpu.get_tensor_model_parallel_rank() == 0
599+
600+
class GPTDatasetFolder(MegatronDataset):
    """Dataset representing a folder of bin files.

    In a nutshell, this is a wrapper around the dataset returned by a
    BlendedMegatronDatasetBuilder: every ``*.bin`` file found (recursively) in a
    directory becomes one blend prefix, the blend weights are left as None, and all
    item access is delegated to the resulting internal dataset.

    Args:
        indexed_dataset: Unused; kept for API compatibility.
        folder_path (str): Directory that is searched recursively for ``*.bin`` files.
        indexed_indices: Unused; kept for API compatibility.
        num_samples (Optional[int]): Number of samples requested for ``index_split``.
        index_split (Split): The split served by this instance.
        config (GPTDatasetConfig): Base configuration. It is deep-copied before being
            adjusted for the internal builder; the passed object is not modified.

    Raises:
        ValueError: If no ``.bin`` file is found under ``folder_path``.
    """

    def __init__(
        self,
        indexed_dataset,  # unused but kept for API compatibility
        folder_path: str,
        indexed_indices,  # unused but kept for API compatibility
        num_samples: Optional[int],
        index_split: Split,
        config: GPTDatasetConfig,
    ) -> None:
        # NOTE(review): MegatronDataset.__init__ is not invoked here, so whatever
        # attributes the base class normally sets are absent on this object --
        # confirm that no consumer relies on them.
        self.folder_path = folder_path
        self.num_samples = num_samples
        self.index_split = index_split
        self.config = config
        self.built_anew_on_cache_miss = False
        del indexed_dataset
        del indexed_indices

        # Find all bin files in the directory and strip the ".bin" extension to get
        # the prefixes; sorting keeps the order deterministic.
        bin_files = glob.glob(os.path.join(folder_path, "**/*.bin"), recursive=True)
        self.bin_prefixes = sorted(os.path.splitext(path)[0] for path in bin_files)

        if not self.bin_prefixes:
            raise ValueError(f"No .bin files found in directory: {folder_path}")

        log_single_rank(
            logger,
            logging.INFO,
            f"Building GPTDatasetFolder from {folder_path} with {len(self.bin_prefixes)} bin files"
        )

        self.internal_dataset = self._build_internal_dataset()

    def _build_internal_dataset(self):
        """Build the dataset spanning all bin prefixes for this instance's split.

        Also sets ``self.built_anew_on_cache_miss`` to True when the returned dataset,
        or any dataset listed in its ``datasets`` attribute, reports a cache miss.

        Returns:
            The dataset the internal builder produced for ``self.index_split``, or
            None when the builder returned no dataset for that split.
        """
        folder_config = copy.deepcopy(self.config)
        folder_config.blend = (self.bin_prefixes, None)  # natural weights within bin files

        num_splits = len(Split)

        # TODO(MaxiBoether): validate this
        # Only our own split is populated, and it covers the entire data.
        split_matrix = [None] * num_splits
        split_matrix[self.index_split.value] = (0.0, 1.0)
        folder_config.split_matrix = split_matrix

        # Set up sizes for just this split
        sizes = [None] * num_splits
        sizes[self.index_split.value] = self.num_samples

        # NOTE(review): blended_megatron_dataset_builder imports GPTDatasetFolder from
        # this module while this module imports BlendedMegatronDatasetBuilder at the
        # top level -- a circular import. Consider importing the builder lazily here
        # and dropping the module-level import.
        builder = BlendedMegatronDatasetBuilder(
            GPTDataset,
            sizes,
            is_dataset_built_on_rank,  # TODO(MaxiBoether): validate dp + how to handle this function??
            folder_config,
        )

        internal_dataset = builder.build()[self.index_split.value]
        if internal_dataset is None:
            # Nothing was built for this split; there is nothing to inspect.
            return None

        # The builder may hand back a dataset without a ``datasets`` attribute
        # (presumably the case for a single prefix -- verify against the builder),
        # so do not assume the attribute exists.
        constituents = getattr(internal_dataset, "datasets", ())
        if getattr(internal_dataset, "built_anew_on_cache_miss", False) or any(
            getattr(dataset, "built_anew_on_cache_miss", False) for dataset in constituents
        ):
            self.built_anew_on_cache_miss = True

        return internal_dataset

    @staticmethod
    def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> None:
        """No-op: this class has no low-level dataset of its own; always returns None."""
        return None  # No-op

    @staticmethod
    def numel_low_level_dataset(low_level_dataset) -> int:
        """No-op counterpart of ``build_low_level_dataset``; always returns 0."""
        return 0  # No-op

    def __len__(self) -> int:
        """Return the length of the internal dataset, or 0 when none was built."""
        if self.internal_dataset is None:
            return 0
        return len(self.internal_dataset)

    def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
        """Return the item at ``idx`` from the internal dataset."""
        return self.internal_dataset[idx]
590683

591684
def _build_document_index(
592685
documents: numpy.ndarray,

pretrain_gpt.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from megatron.core import mpu
1616
from megatron.core.enums import ModelType
1717
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
18-
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
18+
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, GPTDatasetFolder
1919
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
2020
from megatron.core.rerun_state_machine import get_rerun_state_machine
2121
import megatron.legacy.model
@@ -278,7 +278,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
278278
if args.mock_data:
279279
dataset_type = MockGPTDataset
280280
else:
281-
dataset_type = GPTDataset
281+
example_path = config.blend[0][0]
282+
if os.path.isdir(example_path):
283+
print_rank_0(f"> Using directory-based sampling.")
284+
dataset_type = GPTDatasetFolder
285+
else:
286+
print_rank_0(f"> Using file-based sampling.")
287+
dataset_type = GPTDataset
288+
assert os.path.isfile(example_path) or os.path.isfile(f"{example_path}.bin")
282289

283290
print_rank_0("> building train, validation, and test datasets for GPT ...")
284291

0 commit comments

Comments
 (0)