-
Notifications
You must be signed in to change notification settings - Fork 7.2k
spilling peak and average #60809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
spilling peak and average #60809
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,8 @@ | ||||||||||||||||||
| import logging | ||||||||||||||||||
| import threading | ||||||||||||||||||
| import time | ||||||||||||||||||
| from abc import abstractmethod | ||||||||||||||||||
| from typing import Any, Dict, Optional | ||||||||||||||||||
| from typing import Any, Dict, List, Optional | ||||||||||||||||||
|
|
||||||||||||||||||
| import ray | ||||||||||||||||||
| import ray.train | ||||||||||||||||||
|
|
@@ -14,11 +16,88 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||
|
|
||||||||||||||||||
| SPILL_MONITOR_ACTOR_NAME = "spill_metrics_monitor" | ||||||||||||||||||
| SPILL_MONITOR_ACTOR_NAMESPACE = "_spill_metrics_monitor" | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @ray.remote(num_cpus=0) | ||||||||||||||||||
| class SpillMetricsMonitor: | ||||||||||||||||||
| """Actor that periodically polls object store spill metrics | ||||||||||||||||||
| to compute peak and average spilling rates (GB/s). | ||||||||||||||||||
|
|
||||||||||||||||||
| A single instance is shared across all workers via a named actor. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, poll_interval_s: float = 10.0): | ||||||||||||||||||
| self._poll_interval_s = poll_interval_s | ||||||||||||||||||
| self._stop_event = threading.Event() | ||||||||||||||||||
| self._spill_rates_gb_s: List[float] = [] | ||||||||||||||||||
| self._lock = threading.Lock() | ||||||||||||||||||
|
|
||||||||||||||||||
| self._thread = threading.Thread(target=self._poll_loop, daemon=True) | ||||||||||||||||||
| self._thread.start() | ||||||||||||||||||
|
|
||||||||||||||||||
| def _get_spilled_bytes(self) -> int: | ||||||||||||||||||
| memory_info = get_memory_info_reply( | ||||||||||||||||||
| get_state_from_address(ray.get_runtime_context().gcs_address) | ||||||||||||||||||
| ) | ||||||||||||||||||
| return memory_info.store_stats.spilled_bytes_total | ||||||||||||||||||
|
|
||||||||||||||||||
| def _poll_loop(self): | ||||||||||||||||||
| try: | ||||||||||||||||||
| prev_spilled_bytes = self._get_spilled_bytes() | ||||||||||||||||||
| prev_time = time.monotonic() | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning(f"SpillMetricsMonitor: failed initial poll: {e}") | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| while not self._stop_event.wait(self._poll_interval_s): | ||||||||||||||||||
| try: | ||||||||||||||||||
| current_bytes = self._get_spilled_bytes() | ||||||||||||||||||
| current_time = time.monotonic() | ||||||||||||||||||
|
|
||||||||||||||||||
| delta_bytes = current_bytes - prev_spilled_bytes | ||||||||||||||||||
| delta_time = current_time - prev_time | ||||||||||||||||||
|
|
||||||||||||||||||
| if delta_time > 0: | ||||||||||||||||||
| rate_gb_s = (delta_bytes / (1024**3)) / delta_time | ||||||||||||||||||
| with self._lock: | ||||||||||||||||||
| self._spill_rates_gb_s.append(rate_gb_s) | ||||||||||||||||||
|
Comment on lines
+62
to
+65
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| prev_spilled_bytes = current_bytes | ||||||||||||||||||
| prev_time = current_time | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning(f"SpillMetricsMonitor: poll failed: {e}") | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_metrics(self) -> Dict[str, float]: | ||||||||||||||||||
| with self._lock: | ||||||||||||||||||
| rates = list(self._spill_rates_gb_s) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not rates: | ||||||||||||||||||
| return {} | ||||||||||||||||||
|
|
||||||||||||||||||
| return { | ||||||||||||||||||
| "object_store_spilling_peak_gb_s": round(max(rates), 4), | ||||||||||||||||||
| "object_store_spilling_avg_gb_s": round(sum(rates) / len(rates), 4), | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_or_create_spill_metrics_monitor( | ||||||||||||||||||
| poll_interval_s: float = 10.0, | ||||||||||||||||||
| ) -> ray.actor.ActorHandle: | ||||||||||||||||||
| return SpillMetricsMonitor.options( | ||||||||||||||||||
| name=SPILL_MONITOR_ACTOR_NAME, | ||||||||||||||||||
| namespace=SPILL_MONITOR_ACTOR_NAMESPACE, | ||||||||||||||||||
| get_if_exists=True, | ||||||||||||||||||
| lifetime="detached", | ||||||||||||||||||
| ).remote(poll_interval_s=poll_interval_s) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class RayDataLoaderFactory(BaseDataLoaderFactory): | ||||||||||||||||||
| def __init__(self, benchmark_config: BenchmarkConfig) -> None: | ||||||||||||||||||
| super().__init__(benchmark_config) | ||||||||||||||||||
| self._ray_ds_iterators = {} | ||||||||||||||||||
| self._spill_monitor: Optional[ray.actor.ActorHandle] = None | ||||||||||||||||||
|
|
||||||||||||||||||
| dataloader_config = self.get_dataloader_config() | ||||||||||||||||||
| assert isinstance(dataloader_config, RayDataConfig), type(dataloader_config) | ||||||||||||||||||
|
|
@@ -60,6 +139,10 @@ def get_train_dataloader(self): | |||||||||||||||||
| Returns: | ||||||||||||||||||
| Iterator of training batches | ||||||||||||||||||
| """ | ||||||||||||||||||
| # Get or create the shared spill monitor actor on first call. | ||||||||||||||||||
| if self._spill_monitor is None: | ||||||||||||||||||
| self._spill_monitor = get_or_create_spill_metrics_monitor() | ||||||||||||||||||
|
|
||||||||||||||||||
| ds_iterator = ray.train.get_dataset_shard(DatasetKey.TRAIN) | ||||||||||||||||||
| self._ray_ds_iterators[DatasetKey.TRAIN] = ds_iterator | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -177,4 +260,12 @@ def get_metrics(self) -> Dict[str, Any]: | |||||||||||||||||
| f"Failed to collect object_store_spilled_total_gb metric: {e}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Collect peak and average spilling rate from the background monitor. | ||||||||||||||||||
| if self._spill_monitor is not None: | ||||||||||||||||||
| try: | ||||||||||||||||||
| spill_metrics = ray.get(self._spill_monitor.get_metrics.remote()) | ||||||||||||||||||
| metrics.update(spill_metrics) | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning(f"Failed to collect spill rate metrics: {e}") | ||||||||||||||||||
|
|
||||||||||||||||||
| return metrics | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_poll_loopmethod is missing a return type hint. For consistency with the rest of the codebase's type annotations, it should be specified. Since this method doesn't return a value, the hint should be-> None.