@dataclass(frozen=True)
class ObjectStoreUsage:
    """Per-op object store accounting.

    Attributes:
        internal: Bytes held by this op's currently-running tasks
            (outputs not yet yielded to the object store).
        outputs: Bytes this op has produced that are still live in
            the object store — its internal output queue, its
            ``OpState`` external output queue, and the input queues /
            pending task inputs of every op in its
            ``output_dependencies``.
    """

    internal: int
    outputs: int

    @property
    def total(self) -> int:
        """Combined footprint of this op: ``internal + outputs``."""
        return self.internal + self.outputs


# NOTE: this is a method of the operator class in ``physical_operator.py``;
# it is shown at module level only because the enclosing class header is
# outside this excerpt.
def estimate_object_store_usage(self, state: "OpState") -> ObjectStoreUsage:
    """Return the bytes this operator contributes to the global object
    store budget.

    Subclasses may override this when their object store footprint
    doesn't match the generic model.

    Args:
        state: The ``OpState`` paired with this operator. Only its
            ``output_queue_bytes()`` is read here.

    Returns:
        An ``ObjectStoreUsage`` whose ``internal`` and ``outputs`` fields
        are both converted with ``int()``.
    """
    metrics = self.metrics

    # Operator's internal Object Store usage. A falsy value (e.g. ``None``)
    # is counted as 0.
    internal_bytes = metrics.obj_store_mem_pending_task_outputs or 0

    # Operator's outputs' Object Store usage: internal output queue plus
    # the external output queue held by ``state``.
    queued_output_bytes = (
        metrics.obj_store_mem_internal_outqueue + state.output_queue_bytes()
    )

    # TODO fix ineligible ops: this needs to include usage of all of OS
    # for ineligible ops
    #
    # Outputs of this operator used downstream: for each op in
    # ``output_dependencies``, the input queue fed by this operator plus
    # that op's pending task inputs.
    downstream_bytes = 0
    for downstream_op in self.output_dependencies:
        downstream_metrics = downstream_op.metrics
        input_index = downstream_op.input_dependencies.index(self)
        downstream_bytes += (
            downstream_metrics.obj_store_mem_internal_inqueue_for_input(
                input_index
            )
            + downstream_metrics.obj_store_mem_pending_task_inputs
        )

    return ObjectStoreUsage(
        internal=int(internal_bytes),
        outputs=int(queued_output_bytes + downstream_bytes),
    )
diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index 1d78b50e6fb..68296421211 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -173,34 +173,9 @@ def _estimate_object_store_memory_usage( return self._external_consumer_bytes return 0 - # Operator's internal Object Store usage - mem_op_internal = op.metrics.obj_store_mem_pending_task_outputs or 0 - - # Operator's outputs' Object Store usage - op_outputs_bytes = ( - # Internal output queue - op.metrics.obj_store_mem_internal_outqueue - + - # External output queue - state.output_queue_bytes() - ) - - # TODO fix ineligible ops: this needs to include usage of all of OS - # for ineligible ops - # - # Outputs of this operator used downstream - used_op_outputs_bytes = sum( - ( - downstream_op.metrics.obj_store_mem_internal_inqueue_for_input( - downstream_op.input_dependencies.index(op) - ) - + downstream_op.metrics.obj_store_mem_pending_task_inputs - ) - for downstream_op in op.output_dependencies - ) - - self._mem_op_internal[op] = mem_op_internal - self._mem_op_outputs[op] = op_outputs_bytes + used_op_outputs_bytes + usage = op.estimate_object_store_usage(state) + self._mem_op_internal[op] = usage.internal + self._mem_op_outputs[op] = usage.outputs # Attribute iterator / streaming_split prefetch to the executor sink only. 
def test_object_store_accounting_delegates_to_op(self, restore_data_context):
    """``ResourceManager`` must dispatch to ``op.estimate_object_store_usage``
    so subclasses can override the accounting."""
    # Real upstream so the override op has a valid input dependency.
    input_bundle = make_ref_bundles([[x] for x in range(1)])[0]
    upstream = InputDataBuffer(DataContext.get_current(), [input_bundle])

    # Stand-in for a subclass override: replace the method on this one
    # instance so it returns hard-coded values, bypassing the generic
    # metrics+state computation.
    expected = ObjectStoreUsage(internal=42, outputs=100)
    override = mock_map_op(upstream)
    override.estimate_object_store_usage = lambda state: expected

    topo = build_streaming_topology(override, ExecutionOptions())
    resource_manager = ResourceManager(
        topo,
        ExecutionOptions(),
        MagicMock(return_value=ExecutionResources.zero()),
        DataContext.get_current(),
    )

    resource_manager.update_usages()

    # The override's hard-coded values flow through unchanged into
    # both the per-component dicts and the aggregated op usage.
    assert resource_manager.get_mem_op_internal(override) == expected.internal
    assert resource_manager.get_mem_op_outputs(override) == expected.outputs
    assert (
        resource_manager.get_op_usage(override).object_store_memory
        == expected.internal + expected.outputs
    )