
[FEA] Add compute_with_dask utility for downstream Merlin libraries #70

Open
@rjzamora


Proposal

I propose that we add a single compute_dask_object (or compute_with_dask) utility to merlin.core.utils, and use that utility for all Dask computation within Merlin. The purpose of this utility would be to query the global_dask_client() utility and use the appropriate default client/scheduler for Merlin (rather than the default for dask/distributed).

More Background

While starting to look into merlin-models#339, I noticed that there is at least one place where Merlin-models uses a bare compute() statement to compute a Dask collection. I suspect that this is also done in several other places across the Merlin ecosystem.

When there is no global Dask client in the current Python context, calling compute() will typically result in execution with Dask's "multi-threaded" scheduler. This may be fine for CPU-backed data, but it will result in many Python threads thrashing the same GPU (device 0) when the data is GPU-backed.
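The difference between Dask's two local schedulers can be observed without a GPU. The minimal sketch below (the task name `thread_id` is hypothetical, chosen only for illustration) records which thread executes each task: with `scheduler="synchronous"` every task runs in the calling thread, while `scheduler="threads"` hands tasks to a thread pool. With GPU-backed partitions, those pool threads would all be issuing work to the same device.

```python
import threading

import dask


@dask.delayed
def thread_id(i):
    # Return the identifier of the thread that executes this task
    return threading.get_ident()


tasks = [thread_id(i) for i in range(8)]

# "synchronous": every task executes in the calling thread, one at a time
sync_ids = set(dask.compute(*tasks, scheduler="synchronous"))

# "threads": tasks are submitted to a thread pool and may run concurrently
threaded_ids = set(dask.compute(*tasks, scheduler="threads"))

print("synchronous:", len(sync_ids), "thread(s)")
print("threads:", len(threaded_ids), "thread(s)")
```

The exact number of pool threads reported for `"threads"` depends on the machine and on timing, so only the synchronous result is deterministic.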

For compute operations in NVTabular (which only operate on Delayed Dask objects), the merlin.core.utils.global_dask_client utility is used to query the current Dask client. If this function returns None, the convention is to use the "synchronous" scheduler (compute(scheduler="synchronous")); otherwise, the distributed client is used. I propose that this same convention be used everywhere in Merlin (except for special cases where scheduler="synchronous" can be hard-coded).
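The convention above can be sketched in isolation. In this hypothetical helper, `compute_with_client` takes the client as an explicit argument (standing in for the return value of global_dask_client()) so that it runs without Merlin installed. It also relies on `dask.compute` traversing lists, which lets a single collection and a list of collections share one code path.

```python
import dask


def compute_with_client(obj, client=None):
    """Hypothetical sketch of Merlin's scheduler convention.

    `client` stands in for whatever global_dask_client() returns; it is
    passed explicitly here so the sketch is self-contained.
    """
    # Use the distributed client if one is active; otherwise fall back to
    # the "synchronous" scheduler (never Dask's multi-threaded default)
    scheduler = client.get if client is not None else "synchronous"
    # dask.compute returns a tuple with one result per positional argument;
    # lists inside the argument are traversed and rebuilt as lists
    (result,) = dask.compute(obj, scheduler=scheduler)
    return result


x = dask.delayed(sum)([1, 2, 3])
y = dask.delayed(max)([4, 5])
print(compute_with_client(x))
print(compute_with_client([x, y]))
```

Whether a single `dask.compute` call is preferable to separate branches for collections and lists (as in the proposed implementation) is a design choice; this version is shown only as one possible alternative.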

Note that these changes are also required for Merlin's Serial and Distributed context managers to work correctly.

Proposed Implementation

import dask
from dask.base import is_dask_collection
from dask.delayed import Delayed

from merlin.core.utils import global_dask_client


def compute_dask_object(dask_obj):
    """Compute a Dask collection using Merlin's dask-client settings"""

    # Check for a global Merlin Dask client (None if none is set)
    dask_client = global_dask_client()

    if is_dask_collection(dask_obj) or isinstance(dask_obj, Delayed):
        # Compute simple Dask collections
        # (Use distributed client, or fall back to "synchronous" scheduler)
        scheduler = dask_client.get if dask_client else "synchronous"
        return dask_obj.compute(scheduler=scheduler)
    elif isinstance(dask_obj, list):
        # Maybe check that all elements of list are collections or Delayed?
        # Compute entire list at once:
        if dask_client:
            # Client.compute returns one future per list element
            return [r.result() for r in dask_client.compute(dask_obj)]
        else:
            # dask.compute returns a 1-tuple holding the computed list
            return dask.compute(dask_obj, scheduler="synchronous")[0]
    else:
        raise ValueError(
            f"Expected a Dask collection or a list of them, got {type(dask_obj)}"
        )
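The "Maybe check that all elements of list are collections or Delayed?" comment in the proposal could be addressed with a small validation step. The helper below (`check_dask_list` is a hypothetical name) is a minimal sketch: since Delayed objects already satisfy `is_dask_collection`, a single check covers both cases, which also means the extra `isinstance(dask_obj, Delayed)` test in the proposal is redundant but harmless.

```python
import dask
from dask.base import is_dask_collection


def check_dask_list(objs):
    """Hypothetical helper: raise TypeError if any element is not a Dask collection.

    Delayed objects count as Dask collections, so they pass this check.
    """
    bad = [
        f"{i}: {type(obj).__name__}"
        for i, obj in enumerate(objs)
        if not is_dask_collection(obj)
    ]
    if bad:
        raise TypeError(f"Non-Dask elements in list: {', '.join(bad)}")
    return objs


ok = [dask.delayed(1), dask.delayed(sum)([1, 2])]
check_dask_list(ok)  # passes: both elements are Delayed
```

Raising early keeps a stray pandas or cuDF object from being silently passed through `dask.compute` unchanged.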
