Skip to content
40 changes: 39 additions & 1 deletion python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import sys
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, List, Optional

Expand Down Expand Up @@ -131,14 +133,26 @@ def __init__(
):
self._desc = truncate_operator_name(name, self.MAX_NAME_LENGTH)
self._progress = 0
self._total = total
# Prepend a space to the unit for better formatting.
if unit[0] != " ":
unit = " " + unit

if enabled is None:
if not sys.stdout.isatty():
enabled = False
if log_once("progress_bar_disabled"):
logger.info(
"Progress bar disabled because stdout is a non-interactive terminal."
)
elif enabled is None:
# When enabled is None (not explicitly set by the user),
# check DataContext setting
from ray.data.context import DataContext

enabled = DataContext.get_current().enable_progress_bars

self._use_logging = not sys.stdout.isatty()

if not enabled:
self._bar = None
elif tqdm:
Expand All @@ -163,6 +177,13 @@ def __init__(
needs_warning = False
self._bar = None

# For logging progress in non-interactive terminals
self._last_logged_time = 0
# Log interval in seconds
from ray.data.context import DataContext

self._log_interval = DataContext.get_current().progress_bar_log_interval

def set_description(self, name: str) -> None:
name = truncate_operator_name(name, self.MAX_NAME_LENGTH)
if self._bar and name != self._desc:
Expand All @@ -185,6 +206,23 @@ def update(self, increment: int = 0, total: Optional[int] = None) -> None:
# If the progress goes over 100%, update the total.
self._bar.total = self._progress
self._bar.update(increment)
elif self._use_logging and (increment != 0):
self._progress += increment
if total is not None:
self._total = total

# Log progress periodically
current_time = time.time()
time_diff = current_time - self._last_logged_time
should_log = (self._last_logged_time == 0) or (
time_diff >= self._log_interval
)

if should_log:
logger.info(
f"Progress ({self._desc}): {self._progress}/{self._total or 'unknown'}"
)
self._last_logged_time = current_time
Comment on lines +209 to +225
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should probably also into def refresh as well


def close(self):
if self._bar:
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class ShuffleStrategy(str, enum.Enum):
"RAY_DATA_ENABLE_PROGRESS_BAR_NAME_TRUNCATION", True
)

# Progress bar log interval in seconds
DEFAULT_PROGRESS_BAR_LOG_INTERVAL = env_integer("RAY_DATA_PROGRESS_LOG_INTERVAL", 5)

# Globally enable or disable experimental rich progress bars. This is a new
# interface to replace the old tqdm progress bar implementation.
DEFAULT_ENABLE_RICH_PROGRESS_BARS = bool(
Expand Down Expand Up @@ -395,6 +398,8 @@ class DataContext:
`ProgressBar.MAX_NAME_LENGTH`. Otherwise, the full operator name is shown.
enable_rich_progress_bars: Whether to use the new rich progress bars instead
of the tqdm TUI.
progress_bar_log_interval: The interval in seconds for logging progress bar
updates in non-interactive terminals.
enable_get_object_locations_for_metrics: Whether to enable
``get_object_locations`` for metrics.
write_file_retry_on_errors: A list of substrings of error messages that should
Expand Down Expand Up @@ -565,6 +570,7 @@ class DataContext:
DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION
)
enable_rich_progress_bars: bool = DEFAULT_ENABLE_RICH_PROGRESS_BARS
progress_bar_log_interval: int = DEFAULT_PROGRESS_BAR_LOG_INTERVAL
enable_get_object_locations_for_metrics: bool = (
DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS
)
Expand Down
166 changes: 119 additions & 47 deletions python/ray/data/tests/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,51 +36,55 @@ def wrapped_close():
bar.close = wrapped_close

# Test basic usage
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
for _ in range(total):
pb.update(1)
pb.close()
with patch("sys.stdout.isatty", return_value=True):
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
for _ in range(total):
pb.update(1)
pb.close()

assert pb._progress == total
assert total_at_close == total
assert pb._progress == total
assert total_at_close == total

# Test if update() exceeds the original total, the total will be updated.
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._progress == new_total
assert total_at_close == new_total
with patch("sys.stdout.isatty", return_value=True):
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._progress == new_total
assert total_at_close == new_total

# Test that if the bar is not complete at close(), the total will be updated.
pb = ProgressBar("", total, "unit")
assert pb._bar is not None
patch_close(pb._bar)
new_total = total // 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._progress == new_total
assert total_at_close == new_total
with patch("sys.stdout.isatty", return_value=True):
pb = ProgressBar("", total, "unit")
assert pb._bar is not None
patch_close(pb._bar)
new_total = total // 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._progress == new_total
assert total_at_close == new_total

# Test updating the total
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
pb.update(0, new_total)
with patch("sys.stdout.isatty", return_value=True):
pb = ProgressBar("", total, "unit", enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
pb.update(0, new_total)

assert pb._bar.total == new_total
pb.update(total + 1, total)
assert pb._bar.total == total + 1
pb.close()
assert pb._bar.total == new_total
pb.update(total + 1, total)
assert pb._bar.total == total + 1
pb.close()


@pytest.mark.parametrize(
Expand All @@ -102,16 +106,84 @@ def test_progress_bar_truncates_chained_operators(
caplog,
propagate_logs,
):
with patch.object(ProgressBar, "MAX_NAME_LENGTH", max_line_length):
pb = ProgressBar(name, None, "unit")

assert pb.get_description() == expected_description
if should_emit_warning:
assert any(
record.levelno == logging.WARNING
and "Truncating long operator name" in record.message
for record in caplog.records
), caplog.records
with patch("sys.stdout.isatty", return_value=True):
with patch.object(ProgressBar, "MAX_NAME_LENGTH", max_line_length):
pb = ProgressBar(name, None, "unit")

assert pb.get_description() == expected_description
if should_emit_warning:
assert any(
record.levelno == logging.WARNING
and "Truncating long operator name" in record.message
for record in caplog.records
), caplog.records


def test_progress_bar_non_interactive_terminal():
"""Test that progress bars are disabled in non-interactive terminals."""
total = 100

# Mock non-interactive terminal
with patch("sys.stdout.isatty", return_value=False):
# Even with enabled=True, progress bar should be disabled in non-interactive terminal
pb = ProgressBar("test", total, "unit", enabled=True)
assert pb._bar is None

with patch("sys.stdout.isatty", return_value=False):
# Even with enabled=None, progress bar should be disabled in non-interactive terminal
pb = ProgressBar("test", total, "unit")
assert pb._bar is None

# Mock interactive terminal
with patch("sys.stdout.isatty", return_value=True):
# With enabled=True, progress bar should be enabled in interactive terminal
pb = ProgressBar("test", total, "unit", enabled=True)
assert pb._bar is not None


@patch("ray.data._internal.progress_bar.logger")
def test_progress_bar_logging_in_non_interactive_terminal_with_total(mock_logger):
"""Test that progress is logged in non-interactive terminals with known total."""
total = 10

# Mock time to ensure logging occurs
with patch("ray.data._internal.progress_bar.time.time", side_effect=[0, 10]), patch(
"sys.stdout.isatty", return_value=False
):
pb = ProgressBar("test", total, "unit")
assert pb._bar is None
assert pb._use_logging is True

# Reset mock to clear the "progress bar disabled" log call
mock_logger.info.reset_mock()

# Update progress - should log
pb.update(5)

# Verify logger.info was called with expected message
mock_logger.info.assert_called_once_with("Progress (test): 5/10")


@patch("ray.data._internal.progress_bar.logger")
def test_progress_bar_logging_in_non_interactive_terminal_without_total(mock_logger):
"""Test that progress is logged in non-interactive terminals with unknown total."""

# Mock time to ensure logging occurs
with patch("ray.data._internal.progress_bar.time.time", side_effect=[0, 10]), patch(
"sys.stdout.isatty", return_value=False
):
pb = ProgressBar("test2", None, "unit")
assert pb._bar is None
assert pb._use_logging is True

# Reset mock to clear the "progress bar disabled" log call
mock_logger.info.reset_mock()

# Update progress - should log
pb.update(3)

# Verify logger.info was called with expected message for unknown total
mock_logger.info.assert_called_once_with("Progress (test2): 3/unknown")


if __name__ == "__main__":
Expand Down