diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index 88681dd16..4482e1eeb 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -445,7 +445,11 @@ def __array__(self, dtype=None, **kwargs):
 
     def persist(self, fuse=True, **kwargs):
         out = self.optimize(fuse=fuse)
-        return DaskMethodsMixin.persist(out, **kwargs)
+        return DaskMethodsMixin.persist(
+            out,
+            task_resources=out.expr.collect_task_resources(),
+            **kwargs,
+        )
 
     def compute(self, fuse=True, **kwargs):
         """Compute this DataFrame.
@@ -474,7 +478,11 @@ def compute(self, fuse=True, **kwargs):
         if not isinstance(out, Scalar):
             out = out.repartition(npartitions=1)
         out = out.optimize(fuse=fuse)
-        return DaskMethodsMixin.compute(out, **kwargs)
+        return DaskMethodsMixin.compute(
+            out,
+            task_resources=out.expr.collect_task_resources(),
+            **kwargs,
+        )
 
     def analyze(self, filename: str | None = None, format: str | None = None) -> None:
         """Outputs statistics about every node in the expression.
@@ -2494,6 +2502,30 @@ def to_delayed(self, optimize_graph=True):
         """
         return self.to_legacy_dataframe().to_delayed(optimize_graph=optimize_graph)
 
+    def resource_barrier(self, resources):
+        """Define a resource-constraint barrier
+
+        Parameters
+        ----------
+        resources : dict
+            Resource constraint (e.g. ``{"GPU": 1}``).
+
+        Notes
+        -----
+        1. This resource constraint will be applied to all tasks
+           generated by operations after this point (until the
+           `resource_barrier` API is used again).
+        2. This resource constraint will supersede any other
+           resource constraints defined with global annotations.
+        3. Creating a resource barrier will not block optimizations
+           like column projection or predicate pushdown. We assume
+           both projection and filtering are resource agnostic.
+        4. Resource constraints only apply to distributed execution.
+        5. The scheduler will only try to satisfy resource constraints
+           when relevant worker resources exist.
+        """
+        return new_collection(expr.ElemwiseResourceBarrier(self.expr, resources))
+
     def to_backend(self, backend: str | None = None, **kwargs):
         """Move to a new DataFrame backend
 
@@ -5192,6 +5224,7 @@ def read_parquet(
     filesystem="fsspec",
     engine=None,
     arrow_to_pandas=None,
+    resources=None,
     **kwargs,
 ):
     """
@@ -5371,6 +5404,10 @@
     arrow_to_pandas: dict, default None
         Dictionary of options to use when converting from ``pyarrow.Table`` to
         a pandas ``DataFrame`` object. Only used by the "arrow" engine.
+    resources: dict, default None
+        Resource constraint to apply to the generated IO tasks and all
+        subsequent operations. The `resource_barrier` API can be used to
+        modify the resource constraint after the collection is created.
     **kwargs: dict (of dicts)
         Options to pass through to ``engine.read_partitions`` as stand-alone
         key-word arguments. Note that these options will be ignored by the
@@ -5386,6 +5423,7 @@
     to_parquet
     pyarrow.parquet.ParquetDataset
     """
+    from dask_expr.io.io import IOResourceBarrier
    from dask_expr.io.parquet import (
        ReadParquetFSSpec,
        ReadParquetPyarrowFS,
@@ -5405,6 +5443,9 @@
             if op == "in" and not isinstance(val, (set, list, tuple)):
                 raise TypeError("Value of 'in' filter must be a list, set or tuple.")
 
+    if resources is not None:
+        resources = IOResourceBarrier(resources)
+
     if (
         isinstance(filesystem, pa_fs.FileSystem)
         or isinstance(filesystem, str)
@@ -5454,6 +5495,7 @@
                 pyarrow_strings_enabled=pyarrow_strings_enabled(),
                 kwargs=kwargs,
                 _series=isinstance(columns, str),
+                resource_requirement=resources,
             )
         )
 
@@ -5476,6 +5518,7 @@
             engine=_set_parquet_engine(engine),
             kwargs=kwargs,
             _series=isinstance(columns, str),
+            resource_requirement=resources,
         )
     )
 
diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 8d474a44f..75a717ddd 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -738,6 +738,30 @@ def walk(self) -> Generator[Expr]:
 
             yield node
 
+    def collect_task_resources(self) -> dict:
+        resources_annotation = {}
+        stack = [self]
+        seen = set()
+        while stack:
+            node = stack.pop()
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            resources = node._resources
+            if resources is not None:
+                resources_annotation.update(
+                    {
+                        k: (resources(k) if callable(resources) else resources)
+                        for k in node._layer().keys()
+                    }
+                )
+
+            for dep in node.dependencies():
+                stack.append(dep)
+
+        return resources_annotation
+
     def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
         """Search the expression graph for a specific operation type
 
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index c0366432a..c832fe3a6 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -52,7 +52,7 @@
     random_state_data,
 )
 from pandas.errors import PerformanceWarning
-from tlz import merge_sorted, partition, unique
+from tlz import merge, merge_sorted, partition, unique
 
 from dask_expr import _core as core
 from dask_expr._util import (
@@ -87,6 +87,11 @@ def ndim(self):
         except AttributeError:
             return 0
 
+    @functools.cached_property
+    def _resources(self):
+        dep_resources = merge(dep._resources or {} for dep in self.dependencies())
+        return dep_resources or None
+
     def __dask_keys__(self):
         return [(self._name, i) for i in range(self.npartitions)]
 
@@ -1303,6 +1308,36 @@ def operation(df):
         return df.copy(deep=True)
 
 
+class ResourceBarrier(Expr):
+    @property
+    def _resources(self):
+        raise NotImplementedError()
+
+    def __str__(self):
+        return f"{type(self).__name__}({self._resources})"
+
+
+class ElemwiseResourceBarrier(Elemwise, ResourceBarrier):
+    _parameters = ["frame", "resource_spec"]
+    _projection_passthrough = True
+    _filter_passthrough = True
+    _preserves_partitioning_information = True
+
+    @property
+    def _resources(self):
+        return self.resource_spec
+
+    def _task(self, index: int):
+        return (self.frame._name, index)
+
+    @property
+    def _meta(self):
+        return self.frame._meta
+
+    def _divisions(self):
+        return self.frame.divisions
+
+
 class RenameSeries(Elemwise):
     _parameters = ["frame", "index", "sorted_index"]
     _defaults = {"sorted_index": False}
@@ -3128,7 +3163,7 @@ def are_co_aligned(*exprs):
 
 def is_valid_blockwise_op(expr):
     return isinstance(expr, Blockwise) and not isinstance(
-        expr, (FromPandas, FromArray, FromDelayed)
+        expr, (FromPandas, FromArray, FromDelayed, ResourceBarrier)
     )
 
 
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index a28b8762e..87261bbbf 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -19,6 +19,7 @@
     Literal,
     PartitionsFiltered,
     Projection,
+    ResourceBarrier,
     determine_column_projection,
     no_default,
 )
@@ -31,6 +32,17 @@ def __str__(self):
         return f"{type(self).__name__}({self._name[-7:]})"
 
 
+class IOResourceBarrier(ResourceBarrier):
+    _parameters = ["resource_spec"]
+
+    @property
+    def _resources(self):
+        return self.resource_spec
+
+    def _layer(self):
+        return {}
+
+
 class FromGraph(IO):
     """A DataFrame created from an opaque Dask task graph
 
@@ -149,7 +161,8 @@ def _tune_up(self, parent):
 
 
 class FusedParquetIO(FusedIO):
-    _parameters = ["_expr"]
+    _parameters = ["_expr", "resource_requirement"]
+    _defaults = {"resource_requirement": None}
 
     @functools.cached_property
     def _name(self):
@@ -159,6 +172,10 @@ def _name(self):
             + _tokenize_deterministic(*self.operands)
         )
 
+    def dependencies(self):
+        dep = self.resource_requirement
+        return [] if dep is None else [dep]
+
     @staticmethod
     def _load_multiple_files(
         frag_filters,
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index dc75a47a0..9fd5290e1 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -828,6 +828,7 @@ class ReadParquetPyarrowFS(ReadParquet):
         "arrow_to_pandas",
         "pyarrow_strings_enabled",
         "kwargs",
+        "resource_requirement",
         "_partitions",
         "_series",
         "_dataset_info_cache",
@@ -844,6 +845,7 @@
         "arrow_to_pandas": None,
         "pyarrow_strings_enabled": True,
         "kwargs": None,
+        "resource_requirement": None,
         "_partitions": None,
         "_series": False,
         "_dataset_info_cache": None,
@@ -1098,7 +1100,7 @@ def _tune_up(self, parent):
             return
         if isinstance(parent, FusedParquetIO):
             return
-        return parent.substitute(self, FusedParquetIO(self))
+        return parent.substitute(self, FusedParquetIO(self, self.resource_requirement))
 
     @cached_property
     def fragments(self):
@@ -1253,6 +1255,7 @@ class ReadParquetFSSpec(ReadParquet):
         "filesystem",
         "engine",
         "kwargs",
+        "resource_requirement",
         "_partitions",
         "_series",
         "_dataset_info_cache",
@@ -1273,6 +1276,7 @@
         "filesystem": "fsspec",
         "engine": "pyarrow",
         "kwargs": None,
+        "resource_requirement": None,
         "_partitions": None,
         "_series": False,
         "_dataset_info_cache": None,
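
Usage sketch (reviewer note, not part of the patch): a minimal example of the API added above, assuming a distributed cluster whose workers declare a matching "GPU" resource (e.g. started with ``dask worker <scheduler-address> --resources "GPU=1"``). The scheduler address, dataset path, and the "MEM" resource name are hypothetical; the exact semantics of switching constraints follow the `resource_barrier` docstring in this patch.

    from distributed import Client

    import dask_expr as dx

    client = Client("tcp://scheduler:8786")  # assumed scheduler address

    # IO tasks, and every task derived from them, carry the {"GPU": 1} constraint
    df = dx.read_parquet("s3://bucket/dataset/", resources={"GPU": 1})
    df = df[df.value > 0]  # still constrained to GPU workers

    # Switch the constraint for everything built after this point
    df = df.resource_barrier({"MEM": 1})  # hypothetical worker resource
    out = df.groupby("key").value.mean().compute()

The constraints are gathered at graph-materialization time: `compute`/`persist` call `collect_task_resources()` on the optimized expression and forward the per-key mapping as `task_resources` to the scheduler.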