Fix NeMo Curator Cluster Creation CUDA context issues #675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 4 commits into main
8 changes: 4 additions & 4 deletions nemo_curator/modules/__init__.py
@@ -43,17 +43,17 @@
)
FuzzyDuplicates = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.fuzzyduplicates", "FuzzyDuplicates")

# PyTorch-related imports must come after all imports that require cuGraph
# because of context cleanup issues between PyTorch and cuGraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718

EmbeddingCreator = gpu_only_import_from("nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator")
ClusteringModel = gpu_only_import_from("nemo_curator.modules.semantic_dedup.clusteringmodel", "ClusteringModel")
SemanticClusterLevelDedup = gpu_only_import_from(
"nemo_curator.modules.semantic_dedup.semanticclusterleveldedup",
"SemanticClusterLevelDedup",
)
SemDedup = gpu_only_import_from("nemo_curator.modules.semantic_dedup.semdedup", "SemDedup")

# PyTorch-related imports must come after all imports that require cuGraph
# because of context cleanup issues between PyTorch and cuGraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718
from .filter import Filter, ParallelScoreFilter, Score, ScoreFilter

__all__ = [
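For illustration, a minimal sketch of the ordering rule described in the comment block about PyTorch and cuGraph in nemo_curator/modules/__init__.py above. This sketch is not part of the diff; it assumes an environment with both cuGraph and PyTorch installed and only restates the constraint the comment cites (rapidsai/cugraph#2718).

# Minimal sketch (not from this PR): import anything that pulls in cuGraph
# before any PyTorch-related import, mirroring the comment in
# nemo_curator/modules/__init__.py about context cleanup issues
# (https://github.com/rapidsai/cugraph/issues/2718).
import cugraph  # noqa: F401  # cuGraph (and the RAPIDS stack) first
import torch  # PyTorch-related imports afterwards


def gpu_is_visible() -> bool:
    # Both libraries target the same device; only the import order above
    # matters for the cleanup issue the comment refers to.
    return torch.cuda.is_available()


if __name__ == "__main__":
    print(gpu_is_visible())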
64 changes: 60 additions & 4 deletions nemo_curator/utils/distributed_utils.py
@@ -16,7 +16,9 @@
import ast
import os
import shutil
import socket
import subprocess
from collections import defaultdict

import dask

@@ -77,6 +79,59 @@ def get_filepath_without_extension(path: str) -> str:
    return filename


def _worker_gpu_tuple() -> tuple[str, int]:
"""
Runs on a Dask-CUDA worker.
Returns (hostname, gpu_index) where `gpu_index` is the index shown by `nvidia-smi`.
"""
import cupy # noqa
from pynvml import (
nvmlDeviceGetHandleByPciBusId,
nvmlDeviceGetIndex,
NVMLError,
)

dev_id = cupy.cuda.runtime.getDevice()
props = cupy.cuda.runtime.getDeviceProperties(dev_id)

pci_bus_id = (f"{props['pciDomainID']:08x}:{props['pciBusID']:02x}:{props['pciDeviceID']:02x}.0").upper()

try:
handle = nvmlDeviceGetHandleByPciBusId(pci_bus_id)
index = nvmlDeviceGetIndex(handle)
except NVMLError as e:
warnings.warn(f"NVML error occurred: {e} while verifying GPU index", stacklevel=2)
index = -1 # fallback - shouldn't happen

return socket.gethostname(), index


def _assert_unique_gpu_per_host(client: Client) -> None:
"""
Raises RuntimeError if two workers on the same host map to the same GPU.
"""
info = client.run(_worker_gpu_tuple) # {worker_addr: (host, gpu)}
per_host = defaultdict(list)
for host, gpu in info.values():
per_host[host].append(gpu)

# Find hosts where GPUs are assigned more than once
dups = {}
for host, gpus in per_host.items():
unique_gpus = set(gpus)
if len(gpus) != len(unique_gpus):
# Find which GPUs are duplicated
duplicated = [gpu for gpu in unique_gpus if gpus.count(gpu) > 1]
dups[host] = duplicated

# If any duplicates are found, raise an error with details
if dups:
duplicate_error = "Duplicate GPU assignment detected on host(s): " + ", ".join(
f"{host}: {dups[host]}" for host in dups
)
raise RuntimeError(duplicate_error)


def start_dask_gpu_local_cluster( # noqa: PLR0913
    nvlink_only: bool = False,
    protocol: str = "tcp",
@@ -245,15 +300,13 @@ def get_client( # noqa: PLR0913
        if get_num_workers(client) <= 0:
            msg = "No workers are currently connected."
            raise NoWorkerError(msg)
        return client
    elif scheduler_file:
        client = Client(scheduler_file=scheduler_file, timeout="30s")
        if get_num_workers(client) <= 0:
            msg = "No workers are currently connected."
            raise NoWorkerError(msg)
        return client
    elif cluster_type == "gpu":
        return start_dask_gpu_local_cluster(
        client = start_dask_gpu_local_cluster(
            nvlink_only=nvlink_only,
            protocol=protocol,
            rmm_pool_size=rmm_pool_size,
@@ -266,11 +319,14 @@
            **cluster_kwargs,
        )
    else:
        return start_dask_cpu_local_cluster(
        client = start_dask_cpu_local_cluster(
            n_workers=n_workers,
            threads_per_worker=threads_per_worker,
            **cluster_kwargs,
        )
    if cluster_type == "gpu":
        _assert_unique_gpu_per_host(client)
    return client


def _set_torch_to_use_rmm() -> None:
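A hedged usage sketch of the new behavior: for GPU clusters, get_client now ends by calling _assert_unique_gpu_per_host, so two Dask-CUDA workers pinned to the same device on one host raise a RuntimeError instead of silently sharing a GPU. The hostname in the example error message below is hypothetical.

# Usage sketch (not part of the diff): after this change, get_client() runs the
# per-host GPU uniqueness check whenever cluster_type == "gpu". A misconfigured
# deployment fails fast with an error of the form:
#   RuntimeError: Duplicate GPU assignment detected on host(s): node01: [0]
# ("node01" is a hypothetical hostname.)
from nemo_curator.utils.distributed_utils import get_client


def main() -> None:
    client = get_client(cluster_type="gpu")  # one worker per visible GPU
    n_workers = len(client.scheduler_info()["workers"])
    print(f"Connected to {n_workers} GPU workers, each bound to a distinct device")
    client.close()


if __name__ == "__main__":
    main()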