Skip to content

Commit 332a0e9

Browse files
committed
Move away from constants to instance attributes
Signed-off-by: Deepak Narayanan <[email protected]>
1 parent 84de5b2 commit 332a0e9

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

megatron/core/datasets/indexed_dataset.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,14 @@ class _FileBinReader(_BinReader):
435435
bin_path (str): The path to the data (.bin) file.
436436
"""
437437

438-
def __init__(self, bin_path: str) -> None:
438+
def __init__(
439+
self, bin_path: str, num_max_retries: int = 3, sleep_duration_start: int = 10
440+
) -> None:
439441
self._bin_path = bin_path
442+
# Retry-specific parameters. With default arguments, sleep for 10, 20, 40 seconds
443+
# between retries.
444+
self.num_max_retries = num_max_retries
445+
self.sleep_duration_start = sleep_duration_start
440446

441447
def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
442448
"""Read bytes into a numpy array.
@@ -467,25 +473,23 @@ def _read():
467473
bin_buffer_file.readinto(sequence)
468474
return sequence
469475

470-
NUM_MAX_RETRIES = 3
471-
SLEEP_DURATION_START = 10 # Sleep for 10, 20, 40 seconds between retries.
472-
sleep_duration = SLEEP_DURATION_START
473-
for i in range(NUM_MAX_RETRIES + 1):
476+
sleep_duration = self.sleep_duration_start
477+
for i in range(self.num_max_retries + 1):
474478
try:
475479
return _read()
476480
except Exception as e:
477481
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
478-
if i == NUM_MAX_RETRIES:
482+
if i == self.num_max_retries:
479483
logger.warning(
480-
f"[{time_str}] {NUM_MAX_RETRIES+1} total tries to read data item failed; "
481-
f"going to abort and re-raise exception \"{e}\"..."
484+
f"[{time_str}] {self.num_max_retries+1} total tries to read data item "
485+
f"failed; going to abort and re-raise exception \"{e}\"..."
482486
)
483487
# Re-raise exception if in last iteration of for loop.
484488
raise e
485489
logger.warning(
486-
f"[{time_str}] Attempt {i+1}/{NUM_MAX_RETRIES+1} to read data item failed "
487-
f"with exception \"{e}\"; going to sleep for {sleep_duration} seconds and "
488-
"then re-try..."
490+
f"[{time_str}] Attempt {i+1}/{self.num_max_retries+1} to read data item "
491+
f"failed with exception \"{e}\"; going to sleep for {sleep_duration} "
492+
"seconds and then re-try..."
489493
)
490494
time.sleep(sleep_duration)
491495
sleep_duration = sleep_duration * 2

0 commit comments

Comments
 (0)