|
1 | 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. |
2 | 2 |
|
| 3 | +import copy |
| 4 | +import glob |
3 | 5 | import logging |
4 | 6 | import os |
5 | 7 | import time |
6 | 8 | from dataclasses import dataclass |
7 | 9 | from typing import Dict, Optional, Tuple |
8 | 10 |
|
| 11 | +from megatron.core.datasets.blended_dataset import BlendedDataset |
| 12 | +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder |
9 | 13 | import numpy |
10 | 14 | import torch |
11 | 15 |
|
|
16 | 20 | from megatron.core.datasets.utils import Split |
17 | 21 | from megatron.core.datasets.utils_s3 import S3Config, is_s3_path |
18 | 22 | from megatron.core.utils import log_single_rank |
| 23 | +from megatron.core import mpu |
19 | 24 |
|
20 | 25 | logger = logging.getLogger(__name__) |
21 | 26 |
|
@@ -587,6 +592,94 @@ def _get_num_epochs(self, num_tokens_per_epoch: int) -> int: |
587 | 592 | num_tokens += num_tokens_per_epoch |
588 | 593 | return num_epochs |
589 | 594 |
|
def is_dataset_built_on_rank():
    """Tell whether the calling rank satisfies the dataset-building predicate.

    The predicate holds when the rank is on the first or the last pipeline
    stage and its tensor-model-parallel rank is 0. The tensor-parallel rank is
    only queried when the pipeline-stage condition holds.

    Returns:
        The truth value of the predicate described above.
    """
    on_boundary_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    if not on_boundary_stage:
        # Hand back the falsy stage result itself, without touching the TP rank.
        return on_boundary_stage
    return mpu.get_tensor_model_parallel_rank() == 0
| 599 | + |
class GPTDatasetFolder(MegatronDataset):
    """Dataset representing a folder of bin files.

    In a nutshell, this is a wrapper around a blended dataset: every ``.bin``
    file found (recursively) under ``folder_path`` contributes one prefix, the
    prefixes are handed to a ``BlendedMegatronDatasetBuilder`` without explicit
    weights, and all indexing is delegated to the dataset the builder returns
    for ``index_split``.

    NOTE(review): ``MegatronDataset.__init__`` is intentionally not invoked
    here; only the attributes assigned in ``__init__`` below are set by this
    class. Confirm no caller relies on attributes the base initializer would
    have created.

    Args:
        indexed_dataset: Unused, kept for API compatibility.
        folder_path (str): Directory searched recursively for ``*.bin`` files.
        indexed_indices: Unused, kept for API compatibility.
        num_samples (Optional[int]): Requested sample count for ``index_split``;
            forwarded to the builder as the size of that split.
        index_split (Split): The split this dataset serves; its ``value`` is
            used as the index into the per-split lists given to the builder.
        config (GPTDatasetConfig): Base configuration. It is deep-copied before
            ``blend`` and ``split_matrix`` are overridden, so the caller's
            object is left untouched.

    Raises:
        ValueError: If no ``.bin`` file exists under ``folder_path``.
    """

    def __init__(
        self,
        indexed_dataset,  # unused but kept for API compatibility
        folder_path: str,
        indexed_indices,  # unused but kept for API compatibility
        num_samples: Optional[int],
        index_split: Split,
        config: GPTDatasetConfig,
    ) -> None:
        self.folder_path = folder_path
        self.num_samples = num_samples
        self.index_split = index_split
        self.config = config
        # Flipped to True by _build_internal_dataset when the wrapped dataset
        # (or any of its members) reports a cache miss.
        self.built_anew_on_cache_miss = False
        del indexed_dataset
        del indexed_indices

        # Find all bin files in the directory and strip the ".bin" extension to
        # obtain dataset prefixes; sorting keeps the order deterministic.
        bin_files = glob.glob(os.path.join(folder_path, "**/*.bin"), recursive=True)
        self.bin_prefixes = sorted(path[: -len(".bin")] for path in bin_files)

        if not self.bin_prefixes:
            raise ValueError(f"No .bin files found in directory: {folder_path}")

        log_single_rank(
            logger,
            logging.INFO,
            f"Building GPTDatasetFolder from {folder_path} with {len(self.bin_prefixes)} bin files",
        )

        self.internal_dataset = self._build_internal_dataset()

    def _build_internal_dataset(self):
        """Build the dataset that backs this folder for ``self.index_split``.

        Returns:
            The builder's entry for ``self.index_split``. This may be ``None``
            when the builder produced no dataset for the split (``__len__``
            reports 0 in that case).
        """
        folder_config = copy.deepcopy(self.config)
        folder_config.blend = (self.bin_prefixes, None)  # natural weights within bin files

        split_index = self.index_split.value

        # TODO(MaxiBoether): validate this
        split_matrix = [None, None, None]  # [train, valid, test]
        split_matrix[split_index] = (0.0, 1.0)  # Use entire dataset for our split
        folder_config.split_matrix = split_matrix

        # Set up sizes for just this split
        sizes = [None, None, None]  # [train, valid, test]
        sizes[split_index] = self.num_samples

        builder = BlendedMegatronDatasetBuilder(
            GPTDataset,
            sizes,
            is_dataset_built_on_rank,  # TODO(MaxiBoether): validate dp + how to handle this function??
            folder_config,
        )

        internal_dataset = builder.build()[split_index]

        # Nothing was built for this split: there are no cache-miss flags to
        # propagate, and dereferencing attributes on None would crash.
        if internal_dataset is None:
            return None

        # Not every dataset the builder may hand back necessarily exposes a
        # ``datasets`` member list, so read it defensively.
        members = getattr(internal_dataset, "datasets", None) or ()
        if getattr(internal_dataset, "built_anew_on_cache_miss", False) or any(
            getattr(member, "built_anew_on_cache_miss", False) for member in members
        ):
            self.built_anew_on_cache_miss = True

        return internal_dataset

    @staticmethod
    def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> None:
        """No-op: this class has no low-level dataset of its own."""
        return None  # No-op

    @staticmethod
    def numel_low_level_dataset(low_level_dataset) -> int:
        """No-op counterpart of ``build_low_level_dataset``; always returns 0."""
        return 0  # No-op

    def __len__(self) -> int:
        """Return the length of the wrapped dataset, or 0 when none was built."""
        if self.internal_dataset is None:
            return 0
        return len(self.internal_dataset)

    def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
        """Return the sample at ``idx`` from the wrapped dataset."""
        return self.internal_dataset[idx]
590 | 683 |
|
591 | 684 | def _build_document_index( |
592 | 685 | documents: numpy.ndarray, |
|
0 commit comments