
Fail loudly for NeMo Curator Dask-CUDA cluster creation CUDA context issues #675

Open · wants to merge 9 commits into base: main
8 changes: 4 additions & 4 deletions nemo_curator/modules/__init__.py
@@ -43,17 +43,17 @@
)
FuzzyDuplicates = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.fuzzyduplicates", "FuzzyDuplicates")

# PyTorch-related imports must come after all imports that require cuGraph
# because of context cleanup issues between PyTorch and cuGraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718

EmbeddingCreator = gpu_only_import_from("nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator")
ClusteringModel = gpu_only_import_from("nemo_curator.modules.semantic_dedup.clusteringmodel", "ClusteringModel")
SemanticClusterLevelDedup = gpu_only_import_from(
"nemo_curator.modules.semantic_dedup.semanticclusterleveldedup",
"SemanticClusterLevelDedup",
)
SemDedup = gpu_only_import_from("nemo_curator.modules.semantic_dedup.semdedup", "SemDedup")

# PyTorch-related imports must come after all imports that require cuGraph
# because of context cleanup issues between PyTorch and cuGraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718
from .filter import Filter, ParallelScoreFilter, Score, ScoreFilter

__all__ = [
89 changes: 85 additions & 4 deletions nemo_curator/utils/distributed_utils.py
@@ -16,7 +16,10 @@
import ast
import os
import shutil
import socket
import subprocess
import uuid
from collections import Counter, defaultdict

import dask

@@ -35,11 +38,13 @@
import pandas as pd
import psutil
from dask.distributed import Client, LocalCluster, get_worker, performance_report
from distributed.diagnostics.nvml import has_cuda_context

from nemo_curator.utils.gpu_utils import is_cudf_type
from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from

cudf = gpu_only_import("cudf")
cp = gpu_only_import("cupy")
LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster")
get_device_total_memory = gpu_only_import_from("dask_cuda.utils", "get_device_total_memory")
if TYPE_CHECKING:
@@ -77,6 +82,81 @@ def get_filepath_without_extension(path: str) -> str:
    return filename


def _worker_gpu_tuple() -> tuple[str, int]:
    """
    Runs on a Dask-CUDA worker.
    Returns (hostname, gpu_index) where `gpu_index` is the index shown by `nvidia-smi`.
    """
    # Touch the GPU so a context is created (idempotent if one already exists)
    cp.cuda.runtime.getDevice()

Review comment on the `cp.cuda.runtime.getDevice()` call above:

This does not create a CUDA context; you can verify that by watching nvidia-smi while you launch a process that calls this. We canonically use numba.cuda.current_context() in Dask; I would suggest the same because that is already tested and known to work without any side effects.
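
A minimal sketch of the reviewer's suggestion, assuming numba is installed on the workers; per the comment above, numba.cuda.current_context() is what Dask already relies on, and it creates the device's primary context if one does not exist yet:

from numba import cuda

# Create (or fetch) the primary CUDA context on this worker's device.
cuda.current_context()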

    ctx = has_cuda_context()

    # Robust hostname lookup
    try:
        hostname = socket.gethostname()
    except Exception as exc:  # noqa: BLE001 (broad on purpose)
        placeholder = f"unknown-{uuid.uuid4().hex[:8]}"
        warnings.warn(
            f"socket.gethostname() failed: {exc!r}. Using placeholder host name '{placeholder}'.",
            stacklevel=2,
        )
        hostname = placeholder

    if ctx.has_context and ctx.device_info is not None:
        device_index = ctx.device_info.device_index
    else:
        # Fallback - context not yet created or NVML unavailable
        warnings.warn(
            f"Unable to retrieve valid GPU index for host '{hostname}'. "
            "GPU context may not be initialized or NVML may be unavailable. Returning -1.",
            stacklevel=2,
        )
        device_index = -1

    return hostname, device_index


def _assert_unique_gpu_per_host(client: Client) -> None:
    """
    Verifies that each Dask worker on a given host is bound to a unique GPU.

    Raises
    ------
    RuntimeError
        If two or more workers on the same host are bound to the same GPU.
        The error message details:
          • host name
          • GPU index with duplicates
          • number of workers bound to that GPU
          • total workers detected on the host
    """
    # Returns a dictionary of worker addresses to (hostname, gpu_index)
    info = client.run(_worker_gpu_tuple)

    # Group GPU indices by host
    per_host: dict[str, list[int]] = defaultdict(list)
    for host, gpu in info.values():
        per_host[host].append(gpu)

    # Build a human-readable report of duplicates
    duplicate_hosts: list[str] = []
    for host, gpus in per_host.items():
        counts = Counter(gpus)
        # Keep only GPUs bound more than once
        dup_gpus = {gpu: n for gpu, n in counts.items() if n > 1}
        if dup_gpus:
            lines = [f" GPU {gpu} → {n} workers" for gpu, n in sorted(dup_gpus.items())]
            summary = f"\nHost: {host} (total workers: {len(gpus)})\n" + "\n".join(lines)
            duplicate_hosts.append(summary)

    if duplicate_hosts:
        report = (
            "Duplicate GPU assignment detected!\n"
            + "\n".join(duplicate_hosts)
            + "\nEach worker on a host must own a distinct GPU."
        )
        raise RuntimeError(report)
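
For reference, a self-contained sketch of the duplicate-detection logic above, run against hypothetical worker data (no cluster required; the addresses and host name are made up for illustration):

from collections import Counter, defaultdict

# Hypothetical mapping shaped like the output of client.run(_worker_gpu_tuple):
# worker address -> (hostname, gpu_index)
info = {
    "tcp://10.0.0.1:34567": ("node-a", 0),
    "tcp://10.0.0.1:34568": ("node-a", 0),  # second worker on GPU 0 -> duplicate
    "tcp://10.0.0.1:34569": ("node-a", 1),
}

per_host = defaultdict(list)
for host, gpu in info.values():
    per_host[host].append(gpu)

for host, gpus in per_host.items():
    dup_gpus = {gpu: n for gpu, n in Counter(gpus).items() if n > 1}
    print(host, dup_gpus)  # prints: node-a {0: 2}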


def start_dask_gpu_local_cluster( # noqa: PLR0913
    nvlink_only: bool = False,
    protocol: str = "tcp",
@@ -245,15 +325,13 @@ def get_client( # noqa: PLR0913
if get_num_workers(client) <= 0:
msg = "No workers are currently connected."
raise NoWorkerError(msg)
return client
elif scheduler_file:
client = Client(scheduler_file=scheduler_file, timeout="30s")
if get_num_workers(client) <= 0:
msg = "No workers are currently connected."
raise NoWorkerError(msg)
return client
elif cluster_type == "gpu":
return start_dask_gpu_local_cluster(
client = start_dask_gpu_local_cluster(
nvlink_only=nvlink_only,
protocol=protocol,
rmm_pool_size=rmm_pool_size,
Expand All @@ -266,11 +344,14 @@ def get_client( # noqa: PLR0913
**cluster_kwargs,
)
else:
return start_dask_cpu_local_cluster(
client = start_dask_cpu_local_cluster(
n_workers=n_workers,
threads_per_worker=threads_per_worker,
**cluster_kwargs,
)
if cluster_type == "gpu":
_assert_unique_gpu_per_host(client)
return client


def _set_torch_to_use_rmm() -> None:
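
For context, a hedged usage sketch of the behavior this PR introduces; the call matches get_client as shown in the diff above, while the handling in the except branch is illustrative:

from nemo_curator.utils.distributed_utils import get_client

# With this change, creating a GPU cluster validates worker-to-GPU bindings up
# front: if two Dask-CUDA workers on the same host are bound to the same GPU,
# get_client raises RuntimeError immediately instead of failing later on.
try:
    client = get_client(cluster_type="gpu")
    print(f"Cluster started with {len(client.scheduler_info()['workers'])} workers")
except RuntimeError as err:
    print(f"Cluster start-up failed loudly: {err}")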