93 changes: 92 additions & 1 deletion release/train_tests/benchmark/ray_dataloader_factory.py
@@ -1,6 +1,8 @@
import logging
import threading
import time
from abc import abstractmethod
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

import ray
import ray.train
@@ -14,11 +16,88 @@

logger = logging.getLogger(__name__)

SPILL_MONITOR_ACTOR_NAME = "spill_metrics_monitor"
SPILL_MONITOR_ACTOR_NAMESPACE = "_spill_metrics_monitor"


@ray.remote(num_cpus=0)
class SpillMetricsMonitor:
"""Actor that periodically polls object store spill metrics
to compute peak and average spilling rates (GB/s).

A single instance is shared across all workers via a named actor.
"""

def __init__(self, poll_interval_s: float = 10.0):
self._poll_interval_s = poll_interval_s
self._stop_event = threading.Event()
self._spill_rates_gb_s: List[float] = []
self._lock = threading.Lock()

self._thread = threading.Thread(target=self._poll_loop, daemon=True)
self._thread.start()

def _get_spilled_bytes(self) -> int:
memory_info = get_memory_info_reply(
get_state_from_address(ray.get_runtime_context().gcs_address)
)
return memory_info.store_stats.spilled_bytes_total

def _poll_loop(self):
Reviewer comment (Contributor, severity: medium):
The _poll_loop method is missing a return type hint. For consistency with the rest of the codebase's type annotations, it should be specified. Since this method doesn't return a value, the hint should be -> None.

Suggested change:
-    def _poll_loop(self):
+    def _poll_loop(self) -> None:

        try:
            prev_spilled_bytes = self._get_spilled_bytes()
            prev_time = time.monotonic()
        except Exception as e:
            logger.warning(f"SpillMetricsMonitor: failed initial poll: {e}")
            return

        while not self._stop_event.wait(self._poll_interval_s):
            try:
                current_bytes = self._get_spilled_bytes()
                current_time = time.monotonic()

                delta_bytes = current_bytes - prev_spilled_bytes
                delta_time = current_time - prev_time

                if delta_time > 0:
                    rate_gb_s = (delta_bytes / (1024**3)) / delta_time
                    with self._lock:
                        self._spill_rates_gb_s.append(rate_gb_s)
Comment on lines +62 to +65
Reviewer comment (Contributor, severity: medium):
The spilled_bytes_total counter could theoretically reset (e.g., on GCS restart), which would cause delta_bytes to be negative. This would result in a negative spill rate being recorded, skewing the average calculation. It's safer to only calculate the rate for non-negative delta_bytes.

Suggested change:
-                if delta_time > 0:
-                    rate_gb_s = (delta_bytes / (1024**3)) / delta_time
-                    with self._lock:
-                        self._spill_rates_gb_s.append(rate_gb_s)
+                if delta_time > 0 and delta_bytes >= 0:
+                    rate_gb_s = (delta_bytes / (1024**3)) / delta_time
+                    with self._lock:
+                        self._spill_rates_gb_s.append(rate_gb_s)


                prev_spilled_bytes = current_bytes
                prev_time = current_time
            except Exception as e:
                logger.warning(f"SpillMetricsMonitor: poll failed: {e}")

    def get_metrics(self) -> Dict[str, float]:
        with self._lock:
            rates = list(self._spill_rates_gb_s)

        if not rates:
            return {}

        return {
            "object_store_spilling_peak_gb_s": round(max(rates), 4),
            "object_store_spilling_avg_gb_s": round(sum(rates) / len(rates), 4),
        }


def get_or_create_spill_metrics_monitor(
    poll_interval_s: float = 10.0,
) -> ray.actor.ActorHandle:
    return SpillMetricsMonitor.options(
        name=SPILL_MONITOR_ACTOR_NAME,
        namespace=SPILL_MONITOR_ACTOR_NAMESPACE,
        get_if_exists=True,
        lifetime="detached",
    ).remote(poll_interval_s=poll_interval_s)


class RayDataLoaderFactory(BaseDataLoaderFactory):
    def __init__(self, benchmark_config: BenchmarkConfig) -> None:
        super().__init__(benchmark_config)
        self._ray_ds_iterators = {}
        self._spill_monitor: Optional[ray.actor.ActorHandle] = None

        dataloader_config = self.get_dataloader_config()
        assert isinstance(dataloader_config, RayDataConfig), type(dataloader_config)
@@ -60,6 +139,10 @@ def get_train_dataloader(self):
        Returns:
            Iterator of training batches
        """
        # Get or create the shared spill monitor actor on first call.
        if self._spill_monitor is None:
            self._spill_monitor = get_or_create_spill_metrics_monitor()

        ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN)
        self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator

@@ -177,4 +260,12 @@ def get_metrics(self) -> Dict[str, Any]:
f"Failed to collect object_store_spilled_total_gb metric: {e}"
)

# Collect peak and average spilling rate from the background monitor.
if self._spill_monitor is not None:
try:
spill_metrics = ray.get(self._spill_monitor.get_metrics.remote())
metrics.update(spill_metrics)
except Exception as e:
logger.warning(f"Failed to collect spill rate metrics: {e}")

return metrics
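
Not part of the diff — a minimal usage sketch for reviewers of how the detached, named monitor introduced here can be created before a run and queried afterwards from a driver script. It assumes the module is importable as ray_dataloader_factory and that a Ray cluster is already running; the explicit ray.kill cleanup is an illustrative choice, since detached actors are not torn down automatically.

import ray

# Hypothetical import path for illustration; adjust to the benchmark's packaging.
from ray_dataloader_factory import (
    SPILL_MONITOR_ACTOR_NAME,
    SPILL_MONITOR_ACTOR_NAMESPACE,
    get_or_create_spill_metrics_monitor,
)

ray.init(address="auto")

# Create (or attach to) the shared monitor; it starts polling in the background.
monitor = get_or_create_spill_metrics_monitor(poll_interval_s=10.0)

# ... run the benchmark / training job ...

# Later, any process on the cluster can look the detached actor up by name.
monitor = ray.get_actor(
    SPILL_MONITOR_ACTOR_NAME, namespace=SPILL_MONITOR_ACTOR_NAMESPACE
)
print(ray.get(monitor.get_metrics.remote()))
# -> {"object_store_spilling_peak_gb_s": ..., "object_store_spilling_avg_gb_s": ...}
# (an empty dict is returned if no poll interval has completed yet)

# Detached actors outlive the driver, so clean up explicitly when finished.
ray.kill(monitor)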