Describe the bug
When using Virtual Pipeline Parallelism (--virtual-pipeline-model-parallel-size > 1), the training script redundantly initializes the dataset and data iterators once per virtual stage (i.e., len(model) times). This leads to multiple independent np.memmap calls on the same large data files (e.g., 16GB .bin files), causing excessive virtual address space consumption. On systems with TB-scale datasets, this results in:
```
OSError: [Errno 12] Cannot allocate memory
```
during mmap.mmap() even when physical memory and vm.max_map_count are sufficient.
This issue does not occur when VPP is disabled, confirming that the redundant initialization—not the data itself—is the root cause.
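As a minimal illustration of the mechanism (the file name and dtype below are hypothetical), each np.memmap call adds its own file-sized region to the process's virtual address space, so N virtual stages map the same file N times:

```python
# Illustration only: repeated np.memmap calls on the same file, as happens
# once per virtual stage. File name and dtype are hypothetical.
import numpy as np

mappings = []
for _ in range(4):  # e.g. --virtual-pipeline-model-parallel-size 4
    # Each call creates an independent mapping of the full file, adding
    # another file-sized region to the process's virtual address space.
    mappings.append(np.memmap("tokens.bin", dtype=np.uint16, mode="r"))
```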
Tagging @megatron-oncall for attention.
Steps/Code to reproduce bug
- Prepare a dataset consisting of one or more large binary files (e.g., 16GB each) loaded via np.memmap in a custom Dataset (a minimal sketch of such a dataset follows this list).
- Launch training with VPP enabled, e.g.:

  ```bash
  python pretrain_gpt.py \
      --virtual-pipeline-model-parallel-size 4 \
      --pipeline-model-parallel-size 2 \
      ...  # other PP/MP settings
  ```
- Observe that `build_train_valid_test_data_iterators` is called 4 times (once per virtual model chunk).
- Training fails during dataset initialization with OSError: [Errno 12] Cannot allocate memory.
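For reference, a minimal sketch of the kind of memmap-backed dataset described in step 1; the class name, file path, dtype, and sequence length are hypothetical and not part of Megatron-LM:

```python
# Minimal sketch of a memmap-backed token dataset (hypothetical names).
import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapTokenDataset(Dataset):
    def __init__(self, path, seq_len=2048, dtype=np.uint16):
        # One np.memmap per dataset instance; building the dataset once per
        # virtual stage therefore multiplies these mappings.
        self.data = np.memmap(path, dtype=dtype, mode="r")
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        chunk = np.asarray(self.data[start:start + self.seq_len], dtype=np.int64)
        return torch.from_numpy(chunk)
```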
Minimal code pattern causing the issue (from pretrain_*.py):
```python
if args.virtual_pipeline_model_parallel_size is not None:
    train_data_iterator = []
    for i in range(len(model)):
        # ❌ Called N times (N = VPP size)
        iterators = build_train_valid_test_data_iterators(provider)
        train_data_iterator.append(iterators[0])
```
Expected behavior
Data iterators should be initialized only once per rank, regardless of VPP configuration. All virtual pipeline stages should share the same data iterator instance, as they consume micro-batches sequentially from a single global data stream.
The correct pattern should be:
```python
train_iter, valid_iter, test_iter = build_train_valid_test_data_iterators(provider)
if args.virtual_pipeline_model_parallel_size is not None:
    # Share the same iterator across all virtual stages
    train_data_iterator = [train_iter] * len(model)
    # ... similarly for valid/test
```
This avoids redundant dataset initialization and prevents virtual memory exhaustion.
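For completeness, a fuller sketch of the proposed pattern with the valid/test iterators shared the same way; the surrounding names (args, model, provider, the *_data_iterator variables) are assumed from the snippet above rather than taken from a specific Megatron-LM version:

```python
# Sketch of the proposed fix (names assumed from the snippet above, not an
# exact Megatron-LM patch): build the iterators once, then share them.
train_iter, valid_iter, test_iter = build_train_valid_test_data_iterators(provider)

if args.virtual_pipeline_model_parallel_size is not None:
    # Every virtual stage reuses the same iterator objects, so the dataset
    # (and its np.memmap calls) is constructed exactly once per rank.
    train_data_iterator = [train_iter] * len(model)
    valid_data_iterator = [valid_iter] * len(model)
    test_data_iterator = [test_iter] * len(model)
else:
    train_data_iterator = train_iter
    valid_data_iterator = valid_iter
    test_data_iterator = test_iter
```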
Additional context
- Observed in Megatron-LM versions 0.5 through 0.15.
- The problem becomes severe with large mmap-backed datasets (common in LLM pretraining).
- vm.max_map_count is already set to a very high value (2147483642), ruling out mmap region limits.
- The fix is backward-compatible and aligns with how other frameworks (e.g., DeepSpeed, ColossalAI) handle VPP data loading.
- This is a resource-usage bug, not a correctness bug: training works if memory permits, but unnecessarily fails under realistic large-data conditions.