megatron/core/datasets/indexed_dataset.py (57 changes: 45 additions & 12 deletions)
@@ -13,6 +13,7 @@
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable
from datetime import datetime
from enum import Enum
from functools import lru_cache
from itertools import accumulate
@@ -434,8 +435,14 @@ class _FileBinReader(_BinReader):
bin_path (str): The path to the data (.bin) file.
"""

def __init__(self, bin_path: str) -> None:
def __init__(
self, bin_path: str, num_max_retries: int = 3, sleep_duration_start: int = 10
) -> None:
self._bin_path = bin_path
# Retry-specific parameters. With default arguments, sleep for 10, 20, 40 seconds
# between retries.
self.num_max_retries = num_max_retries
self.sleep_duration_start = sleep_duration_start

def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
"""Read bytes into a numpy array.
@@ -451,17 +458,43 @@ def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
numpy.ndarray: An array with `count` items and data-type `dtype` constructed from
reading bytes from the data file starting at `offset`.
"""
sequence = numpy.empty(count, dtype=dtype)
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
with msc.open(self._bin_path, mode="rb", buffering=0) as bin_buffer_file:
bin_buffer_file.seek(offset)
bin_buffer_file.readinto(sequence)
else:
with open(self._bin_path, mode="rb", buffering=0) as bin_buffer_file:
bin_buffer_file.seek(offset)
bin_buffer_file.readinto(sequence)
return sequence

def _read():
"""Helper method to read `count` bytes from self._bin_path at provided offset."""
sequence = numpy.empty(count, dtype=dtype)
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
with msc.open(self._bin_path, mode="rb", buffering=0) as bin_buffer_file:
bin_buffer_file.seek(offset)
bin_buffer_file.readinto(sequence)
else:
with open(self._bin_path, mode="rb", buffering=0) as bin_buffer_file:
bin_buffer_file.seek(offset)
bin_buffer_file.readinto(sequence)
return sequence

sleep_duration = self.sleep_duration_start
for i in range(self.num_max_retries + 1):
try:
return _read()
except Exception as e:
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
if i == self.num_max_retries:
logger.warning(
f"[{time_str}] {self.num_max_retries+1} total tries to read data item "
f"failed; going to abort and re-raise exception \"{e}\"..."
)
# Re-raise exception if in last iteration of for loop.
raise e
logger.warning(
f"[{time_str}] Attempt {i+1}/{self.num_max_retries+1} to read data item "
f"failed with exception \"{e}\"; going to sleep for {sleep_duration} "
"seconds and then re-try..."
)
time.sleep(sleep_duration)
sleep_duration = sleep_duration * 2

raise RuntimeError("Should not reach here!")


class _S3BinReader(_BinReader):
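
For context, the logic added above is a standard retry-with-exponential-backoff pattern: attempt the read, and on any exception sleep for a doubling interval (10, 20, then 40 seconds with the defaults) before trying again, re-raising the last exception once all attempts are exhausted. Below is a minimal, self-contained sketch of that pattern; the name `retry_with_backoff` is illustrative and not part of this PR.

import logging
import time
from typing import Callable, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T], num_max_retries: int = 3, sleep_duration_start: int = 10
) -> T:
    """Call `fn`, retrying on any exception with exponentially growing sleeps.

    With the defaults this makes up to 4 attempts, sleeping 10, 20, and 40
    seconds between them, mirroring the behavior of the new _FileBinReader.read.
    """
    sleep_duration = sleep_duration_start
    for attempt in range(num_max_retries + 1):
        try:
            return fn()
        except Exception:
            # Out of attempts: propagate the original exception to the caller.
            if attempt == num_max_retries:
                raise
            logger.warning(
                "Attempt %d/%d failed; sleeping %d s before retrying...",
                attempt + 1,
                num_max_retries + 1,
                sleep_duration,
            )
            time.sleep(sleep_duration)
            sleep_duration *= 2
    raise RuntimeError("Should not reach here!")

With the change in place, callers are unaffected: something like `_FileBinReader("/path/to/data.bin").read(dtype=numpy.int32, count=1024, offset=0)` (path and arguments hypothetical) behaves as before on success and now transparently retries transient I/O failures.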