14 changes: 14 additions & 0 deletions python/ray/data/BUILD
@@ -590,6 +590,20 @@ py_test(
],
)

py_test(
name = "test_task_pool_map_operator",
size = "small",
srcs = ["tests/test_task_pool_map_operator.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_tensor",
size = "small",
@@ -1,10 +1,11 @@
import os
from typing import Dict, List, Optional, Union

from .common import NodeIdStr
from ray.data._internal.execution.util import memory_string
from ray.util.annotations import DeveloperAPI

from .common import NodeIdStr


class ExecutionResources:
"""Specifies resources usage or resource limits for execution.
@@ -136,6 +137,11 @@ def zero(cls) -> "ExecutionResources":
"""Returns an ExecutionResources object with zero resources."""
return ExecutionResources(0.0, 0.0, 0.0, 0.0)

@classmethod
def inf(cls) -> "ExecutionResources":
"""Returns an ExecutionResources object with infinite resources."""
return ExecutionResources.for_limits()

def is_zero(self) -> bool:
"""Returns True if all resources are zero."""
return (
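A quick sketch of how the new inf() helper above is expected to behave, assuming (as the docstring implies) that for_limits() fills unspecified fields with infinity; the import path matches the one used in the tests later in this PR:

    from ray.data._internal.execution.interfaces import ExecutionResources

    unlimited = ExecutionResources.inf()
    # Any finite usage should satisfy an unbounded limit.
    usage = ExecutionResources(cpu=128, gpu=8, object_store_memory=10 * 1024**3)
    assert usage.satisfies_limit(unlimited)
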
@@ -1,10 +1,9 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import uuid

import ray
from .ref_bundle import RefBundle
from ray._raylet import ObjectRefGenerator
from ray.data._internal.execution.autoscaler.autoscaling_actor_pool import (
AutoscalingActorPool,
@@ -15,10 +14,11 @@
)
from ray.data._internal.execution.interfaces.op_runtime_metrics import OpRuntimeMetrics
from ray.data._internal.logical.interfaces import LogicalOperator, Operator
from ray.data._internal.output_buffer import OutputBlockSizeOption
from ray.data._internal.stats import StatsDict, Timer
from ray.data.context import DataContext
from ray.data._internal.output_buffer import OutputBlockSizeOption

from .ref_bundle import RefBundle

logger = logging.getLogger(__name__)

@@ -542,13 +542,18 @@ def pending_processor_usage(self) -> ExecutionResources:
"""
return ExecutionResources(0, 0, 0)

def base_resource_usage(self) -> ExecutionResources:
"""Returns the minimum amount of resources required for execution.
def min_max_resource_requirements(
self,
) -> Tuple[ExecutionResources, ExecutionResources]:
Review thread on this signature:

Contributor: Should be optional.

Member Author: How would we handle None in the resource manager? Would it be equivalent to the default implementation?

Contributor: Yeah, it'd default to not knowing resource reqs.

Contributor: OK, I see that you're defaulting to [0, inf); that's fine too.

"""Returns the min and max resources to start the operator and make progress.

Review comment from Contributor: Let's expand on the fact that these are derived from the operator's concurrency configuration multiplied by single task/actor resource requirements.

For example, an operator that creates an actor pool requiring 8 GPUs could
return ExecutionResources(gpu=8) as its base usage.
return ExecutionResources(gpu=8) as its minimum usage.

This method is used by the resource manager to reserve minimum resources and to
ensure that it doesn't over-provision resources.
"""
return ExecutionResources()
return ExecutionResources.zero(), ExecutionResources.inf()

def incremental_resource_usage(self) -> ExecutionResources:
"""Returns the incremental resources required for processing another input.
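To make the review suggestion above concrete, here is an illustrative sketch (not actual Ray code; the class and attribute names are made up) of how a fixed-size worker-pool operator could derive both bounds from its concurrency configuration multiplied by per-worker requirements. The real ActorPoolMapOperator override appears in the next file.

    from typing import Tuple

    from ray.data._internal.execution.interfaces import ExecutionResources

    class FixedPoolOperatorSketch:
        def __init__(self, num_workers: int, ray_remote_args: dict):
            self._num_workers = num_workers
            self._ray_remote_args = ray_remote_args

        def min_max_resource_requirements(
            self,
        ) -> Tuple[ExecutionResources, ExecutionResources]:
            cpu = self._ray_remote_args.get("num_cpus", 0)
            gpu = self._ray_remote_args.get("num_gpus", 0)
            # Lower bound: enough resources to start the whole fixed-size pool.
            min_usage = ExecutionResources(
                cpu=cpu * self._num_workers, gpu=gpu * self._num_workers
            )
            # Upper bound: unbounded, since there is no hard concurrency cap here.
            return min_usage, ExecutionResources.inf()
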
@@ -302,13 +302,28 @@ def progress_str(self) -> str:
)
return "[locality off]"

def base_resource_usage(self) -> ExecutionResources:
min_workers = self._actor_pool.min_size()
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * min_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * min_workers,
def min_max_resource_requirements(
self,
) -> Tuple[ExecutionResources, ExecutionResources]:
min_actors = self._actor_pool.min_size()
assert min_actors is not None, min_actors

num_cpus_per_actor = self._ray_remote_args.get("num_cpus", 0)
num_gpus_per_actor = self._ray_remote_args.get("num_gpus", 0)
memory_per_actor = self._ray_remote_args.get("memory", 0)

min_resource_usage = ExecutionResources(
cpu=num_cpus_per_actor * min_actors,
gpu=num_gpus_per_actor * min_actors,
memory=memory_per_actor * min_actors,
# To ensure that all actors are utilized, reserve enough resource budget
# to launch one task for each worker.
object_store_memory=self._metrics.obj_store_mem_max_pending_output_per_task
* min_actors,
)

return min_resource_usage, ExecutionResources.for_limits()
Contributor (suggested change): replace
    return min_resource_usage, ExecutionResources.for_limits()
with
    return min_resource_usage, ExecutionResources.inf()


def current_processor_usage(self) -> ExecutionResources:
# Both pending and running actors count towards our current resource usage.
num_active_workers = self._actor_pool.current_size()
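A worked example of the lower bound computed above, with made-up numbers: a pool with min_size=2, num_gpus=1 per actor, no CPU or heap memory request, and an obj_store_mem_max_pending_output_per_task estimate of 128 MiB reserves:

    from ray.data._internal.execution.interfaces import ExecutionResources

    min_resource_usage = ExecutionResources(
        cpu=0 * 2,                              # no num_cpus requested per actor
        gpu=1 * 2,                              # one GPU per actor, two actors
        memory=0 * 2,
        object_store_memory=128 * 1024**2 * 2,  # one pending task output per actor
    )
    # i.e. 2 GPUs plus 256 MiB of object store memory; the upper bound stays infinite.
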
12 changes: 8 additions & 4 deletions python/ray/data/_internal/execution/operators/map_operator.py
@@ -44,14 +44,14 @@
ApplyAdditionalSplitToOutputBlocks,
MapTransformer,
)
from ray.data._internal.util import MemoryProfiler
from ray.data._internal.execution.util import memory_string
from ray.data._internal.stats import StatsDict
from ray.data._internal.util import MemoryProfiler
from ray.data.block import (
Block,
BlockAccessor,
BlockMetadata,
BlockExecStats,
BlockMetadata,
BlockStats,
to_stats,
)
@@ -489,8 +489,10 @@ def pending_processor_usage(self) -> ExecutionResources:
raise NotImplementedError

@abstractmethod
def base_resource_usage(self) -> ExecutionResources:
raise NotImplementedError
def min_max_resource_requirements(
self,
) -> Tuple[ExecutionResources, ExecutionResources]:
...

@abstractmethod
def incremental_resource_usage(self) -> ExecutionResources:
@@ -739,6 +741,8 @@ def _canonicalize_ray_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str,
"""
ray_remote_args = ray_remote_args.copy()

# TODO: Might be better to log this warning at composition-time rather than at
# execution. Validating inputs early is a good practice.
if ray_remote_args.get("num_cpus") and ray_remote_args.get("num_gpus"):
logger.warning(
"Specifying both num_cpus and num_gpus for map tasks is experimental, "
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple

from ray.data._internal.execution.interfaces import (
ExecutionResources,
@@ -110,8 +110,10 @@ def _add_bundled_input(self, bundle: RefBundle):
def progress_str(self) -> str:
return ""

def base_resource_usage(self) -> ExecutionResources:
return ExecutionResources()
def min_max_resource_requirements(
self,
) -> Tuple[ExecutionResources, ExecutionResources]:
return self.incremental_resource_usage(), ExecutionResources.for_limits()

def current_processor_usage(self) -> ExecutionResources:
num_active_workers = self.num_active_tasks()
@@ -127,6 +129,7 @@ def incremental_resource_usage(self) -> ExecutionResources:
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0),
gpu=self._ray_remote_args.get("num_gpus", 0),
memory=self._ray_remote_args.get("memory", 0),
object_store_memory=self._metrics.obj_store_mem_max_pending_output_per_task
or 0,
)
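For the task-pool case this means the lower bound is just one task's worth of resources. For example, with hypothetical ray_remote_args of {"num_cpus": 1, "memory": 2 * 1024**3} and a pending-output estimate of 128 MiB per task, the bounds would work out to roughly:

    from ray.data._internal.execution.interfaces import ExecutionResources

    min_usage = ExecutionResources(
        cpu=1,                              # num_cpus requested per task
        gpu=0,
        memory=2 * 1024**3,                 # heap memory requested per task
        object_store_memory=128 * 1024**2,  # estimated pending output of one task
    )
    max_usage = ExecutionResources.for_limits()  # no hard cap on concurrent tasks
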
71 changes: 22 additions & 49 deletions python/ray/data/_internal/execution/resource_manager.py
@@ -20,8 +20,7 @@
from ray.data.context import DataContext

if TYPE_CHECKING:
from ray.data._internal.execution.streaming_executor_state import OpState
from ray.data._internal.execution.streaming_executor_state import Topology
from ray.data._internal.execution.streaming_executor_state import OpState, Topology


logger = logging.getLogger(__name__)
@@ -419,9 +418,6 @@ def __init__(self, resource_manager: ResourceManager, reservation_ratio: float):
# See `test_no_deadlock_on_small_cluster_resources` as an example.
self._reserved_min_resources: Dict[PhysicalOperator, bool] = {}

self._cached_global_limits = ExecutionResources.zero()
self._cached_num_eligible_ops = 0

self._idle_detector = self.IdleDetector()

def _is_op_eligible(self, op: PhysicalOperator) -> bool:
@@ -442,14 +438,6 @@ def _update_reservation(self):
global_limits = self._resource_manager.get_global_limits()
eligible_ops = self._get_eligible_ops()

if (
global_limits == self._cached_global_limits
and len(eligible_ops) == self._cached_num_eligible_ops
):
return
self._cached_global_limits = global_limits
self._cached_num_eligible_ops = len(eligible_ops)

self._op_reserved.clear()
self._reserved_for_op_outputs.clear()
self._reserved_min_resources.clear()
@@ -467,38 +455,23 @@ def _update_reservation(self):
# Reserve at least half of the default reserved resources for the outputs.
# This makes sure that we will have enough budget to pull blocks from the
# op.
self._reserved_for_op_outputs[op] = max(
default_reserved.object_store_memory / 2, 1.0
reserved_for_outputs = ExecutionResources(
0, 0, max(default_reserved.object_store_memory / 2, 1)
)
# Calculate the minimum amount of resources to reserve.
# 1. Make sure the reserved resources are at least to allow one task.
min_reserved = op.incremental_resource_usage().copy()
# 2. To ensure that all GPUs are utilized, reserve enough resource budget
# to launch one task for each worker.
if op.base_resource_usage().gpu > 0:
min_workers = sum(
pool.min_size() for pool in op.get_autoscaling_actor_pools()
)
min_reserved.object_store_memory *= min_workers
# Also include `reserved_for_op_outputs`.
min_reserved.object_store_memory += self._reserved_for_op_outputs[op]
# Total resources we want to reserve for this operator.
op_total_reserved = default_reserved.max(min_reserved)

# Check if the remaining resources are enough for op_total_reserved.
# Note, we only consider CPU and GPU, but not object_store_memory,
# because object_store_memory can be oversubscribed, but CPU/GPU cannot.
if op_total_reserved.satisfies_limit(

min_resource_usage, max_resource_usage = op.min_max_resource_requirements()
reserved_for_tasks = default_reserved.subtract(reserved_for_outputs)
reserved_for_tasks = reserved_for_tasks.max(min_resource_usage)
reserved_for_tasks = reserved_for_tasks.min(max_resource_usage)

# Check if the remaining resources are enough for both reserved_for_tasks
# and reserved_for_outputs. Note, we only consider CPU and GPU, but not
# object_store_memory, because object_store_memory can be oversubscribed,
# but CPU/GPU cannot.
if reserved_for_tasks.add(reserved_for_outputs).satisfies_limit(
remaining, ignore_object_store_memory=True
):
# If the remaining resources are enough to reserve `op_total_reserved`,
# subtract it from the remaining and reserve it for this op.
self._reserved_min_resources[op] = True
remaining = remaining.subtract(op_total_reserved)
self._op_reserved[op] = op_total_reserved
self._op_reserved[
op
].object_store_memory -= self._reserved_for_op_outputs[op]
else:
# If the remaining resources are not enough to reserve the minimum
# resources for this operator, we'll only reserve the minimum object
@@ -508,14 +481,8 @@ def _update_reservation(self):
# ops. It's fine that downstream ops don't get the minimum reservation,
# because they can wait for upstream ops to finish and release resources.
self._reserved_min_resources[op] = False
self._op_reserved[op] = ExecutionResources(
0,
0,
min_reserved.object_store_memory
- self._reserved_for_op_outputs[op],
)
remaining = remaining.subtract(
ExecutionResources(0, 0, min_reserved.object_store_memory)
reserved_for_tasks = ExecutionResources(
0, 0, min_resource_usage.object_store_memory
)
if index == 0:
# Log a warning if even the first operator cannot reserve
Expand All @@ -525,7 +492,13 @@ def _update_reservation(self):
" The job may hang forever unless the cluster scales up."
)

self._op_reserved[op] = reserved_for_tasks
self._reserved_for_op_outputs[op] = reserved_for_outputs.object_store_memory

op_total_reserved = reserved_for_tasks.add(reserved_for_outputs)
remaining = remaining.subtract(op_total_reserved)
remaining = remaining.max(ExecutionResources.zero())

self._total_shared = remaining

def can_submit_new_task(self, op: PhysicalOperator) -> bool:
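A hedged walk-through of the new reservation math above, with made-up numbers: suppose this operator's default reservation share is 1 CPU and 2 GiB of object store memory, and it reports min_max_resource_requirements() of (ExecutionResources(cpu=2), ExecutionResources.inf()):

    from ray.data._internal.execution.interfaces import ExecutionResources

    default_reserved = ExecutionResources(cpu=1, gpu=0, object_store_memory=2 * 1024**3)
    min_usage = ExecutionResources(cpu=2, gpu=0, object_store_memory=0)
    max_usage = ExecutionResources.inf()

    # Half of the object store share is set aside for the operator's outputs.
    reserved_for_outputs = ExecutionResources(0, 0, 1 * 1024**3)
    # The rest goes to tasks, clamped into [min_usage, max_usage].
    reserved_for_tasks = default_reserved.subtract(reserved_for_outputs)  # 1 CPU, 1 GiB
    reserved_for_tasks = reserved_for_tasks.max(min_usage)                # 2 CPUs, 1 GiB
    reserved_for_tasks = reserved_for_tasks.min(max_usage)                # unchanged; max is infinite
    # The full reservation only succeeds if tasks + outputs fit into the remaining
    # budget, checked on CPU/GPU only because object store memory can be oversubscribed.
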
7 changes: 4 additions & 3 deletions python/ray/data/_internal/execution/streaming_executor.py
@@ -31,10 +31,10 @@
register_dataset_logger,
unregister_dataset_logger,
)
from ray.data._internal.metadata_exporter import Topology as TopologyMetadata
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.stats import DatasetStats, StatsManager, DatasetState, Timer
from ray.data._internal.stats import DatasetState, DatasetStats, StatsManager, Timer
from ray.data.context import OK_PREFIX, WARN_PREFIX, DataContext
from ray.data._internal.metadata_exporter import Topology as TopologyMetadata

logger = logging.getLogger(__name__)

@@ -491,7 +491,8 @@ def walk(op):

base_usage = ExecutionResources(cpu=1)
for op in walk(dag):
base_usage = base_usage.add(op.base_resource_usage())
min_resource_usage, _ = op.min_max_resource_requirements()
base_usage = base_usage.add(min_resource_usage)

if not base_usage.satisfies_limit(limits):
error_message = (
37 changes: 35 additions & 2 deletions python/ray/data/tests/test_actor_pool_map_operator.py
@@ -3,6 +3,7 @@
import threading
import unittest
from typing import Any, Optional, Tuple
from unittest.mock import MagicMock

import pytest

@@ -11,7 +12,11 @@
from ray.actor import ActorHandle
from ray.data._internal.compute import ActorPoolStrategy
from ray.data._internal.execution.interfaces import ExecutionResources
from ray.data._internal.execution.operators.actor_pool_map_operator import _ActorPool
from ray.data._internal.execution.operators.actor_pool_map_operator import (
ActorPoolMapOperator,
_ActorPool,
)
from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.util import make_ref_bundles
from ray.tests.conftest import * # noqa
from ray.types import ObjectRef
@@ -450,7 +455,33 @@ def test_locality_manager_busyness_ranking(self):
assert res3 is None


def test_start_actor_timeout(ray_start_regular, restore_data_context):
def test_min_max_resource_requirements(restore_data_context):
data_context = ray.data.DataContext.get_current()
op = ActorPoolMapOperator(
map_transformer=MagicMock(),
input_op=InputDataBuffer(data_context, input_data=MagicMock()),
data_context=data_context,
target_max_block_size=None,
compute_strategy=ray.data.ActorPoolStrategy(
min_size=1,
max_size=2,
),
ray_remote_args={"num_cpus": 1},
)
op._metrics = MagicMock(obj_store_mem_max_pending_output_per_task=3)

(
min_resource_usage_bound,
max_resource_usage_bound,
) = op.min_max_resource_requirements()

assert (
min_resource_usage_bound == ExecutionResources(cpu=1, object_store_memory=3)
and max_resource_usage_bound == ExecutionResources.for_limits()
)


def test_start_actor_timeout(ray_start_regular_shared, restore_data_context):
"""Tests that ActorPoolMapOperator raises an exception on
timeout while waiting for actors."""

@@ -482,6 +513,8 @@ def __call__(self, x):
def test_actor_pool_fault_tolerance_e2e(ray_start_cluster, restore_data_context):
"""Test that a dataset with actor pools can finish, when
all nodes in the cluster are removed and added back."""
ray.shutdown()

cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
ray.init()