Description
Describe the bug
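rank_genes_groups_logreg now fails with a ValueError raised in cuML's Dask logistic regression: a partition holds more non-zero values than the int32 index dtype used by cupyx supports. It was working before.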
Steps/Code to reproduce bug
Load an AnnData object with Dask arrays for raw (adata.X) and normalized (adata.layers["Norm"]) data.
Set up multi-GPU OOM processing (RMM managed memory, which allows oversubscription).
# Code
import cupy as cp
import rmm
import rapids_singlecell as rsc
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from rmm.allocators.cupy import rmm_cupy_allocator

# Create a local CUDA cluster on GPUs 0 and 1
cluster = LocalCUDACluster(
    CUDA_VISIBLE_DEVICES=[0, 1],
    threads_per_worker=2,
)

def set_mem():
    try:
        rmm.reinitialize(
            managed_memory=True,  # Allows oversubscription
        )
        cp.cuda.set_allocator(rmm_cupy_allocator)
    except Exception as e:
        print(f"Warning: Could not initialize RMM with managed memory: {e}")
        print("Continuing with default memory management...")

# Connect a client and run the memory setup on every worker
client = Client(cluster)
client.run(set_mem)

# Move the normalized layer to the GPU, persist it, and resolve chunk sizes
rsc.get.anndata_to_GPU(adata, layer="Norm")
adata.layers["Norm"] = adata.layers["Norm"].persist()
adata.layers["Norm"].compute_chunk_sizes()
rsc.tl.rank_genes_groups_logreg(adata, groupby="leiden", layer="Norm", use_raw=False)

/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/distributed/client.py:3371: UserWarning: Sending large graph of size 14.08 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
2025-08-21 00:27:14,693 - distributed.worker - ERROR - Compute Failed
Key: _func_fit-5eeac890-7832-4c1e-9143-2cae4eda6e5f
State: executing
Task: <Task '_func_fit-5eeac890-7832-4c1e-9143-2cae4eda6e5f' _func_fit(...)>
Exception: "ValueError('please use scipy csr_matrix because cupyx uses int32 index dtype that does not support 3375391001 non-zero values of a partition')"
Traceback: ' File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/linear_model/logistic_regression.py", line 188, in _func_fit\n raise ValueError(\n'
2025-08-21 00:27:14,694 - distributed.worker - ERROR - Compute Failed
Key: _func_fit-0d6eaf84-ca1e-41bf-868c-75331b6e8a15
State: executing
Task: <Task '_func_fit-0d6eaf84-ca1e-41bf-868c-75331b6e8a15' _func_fit(...)>
Exception: "ValueError('please use scipy csr_matrix because cupyx uses int32 index dtype that does not support 3862347464 non-zero values of a partition')"
Traceback: ' File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/linear_model/logistic_regression.py", line 188, in _func_fit\n raise ValueError(\n'
Traceback (most recent call last):
  File "", line 1, in
  File "/home/mppebworth/rapids_singlecell/src/rapids_singlecell/tools/_rank_gene_groups.py", line 191, in rank_genes_groups_logreg
    clf.fit(X, grouping_logreg)
  File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/linear_model/logistic_regression.py", line 159, in fit
    models = self._fit(
             ^^^^^^^^^^
  File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/common/base.py", line 450, in _fit
    wait_and_raise_from_futures(list(lin_fit.values()))
  File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/common/utils.py", line 164, in wait_and_raise_from_futures
    raise_exception_from_futures(futures)
  File "/home/mppebworth/miniforge3/envs/newrapids/lib/python3.12/site-packages/cuml/dask/common/utils.py", line 152, in raise_exception_from_futures
    raise RuntimeError(
RuntimeError: 2 of 2 worker jobs failed: please use scipy csr_matrix because cupyx uses int32 index dtype that does not support 3375391001 non-zero values of a partition, please use scipy
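For reference, both non-zero counts in the error messages are above the signed 32-bit limit of 2,147,483,647, which is consistent with the int32 complaint. Below is a minimal sketch of that arithmetic only. The helper name `min_pieces` is hypothetical (it is not part of cuML or rapids_singlecell), and it assumes the non-zeros could be divided evenly. It does not show whether a finer split would actually avoid the check, since that depends on how cuML assembles each worker's partition.

```python
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit index can hold

def min_pieces(nnz: int, limit: int = INT32_MAX) -> int:
    """Hypothetical helper: fewest equal pieces so each holds at most `limit` non-zeros."""
    return -(-nnz // limit)  # integer ceiling division

# Non-zero counts copied from the two worker errors above
reported_nnz = [3375391001, 3862347464]

for nnz in reported_nnz:
    over = nnz > INT32_MAX
    print(f"nnz={nnz} exceeds int32 limit: {over}, min pieces: {min_pieces(nnz)}")
```

Under that even-split assumption, each reported partition would need to be at least halved to fit within int32 indexing.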
Expected behavior
rank_genes_groups_logreg should complete without error, as it did before.
Environment details (please complete the following information):
- Environment location: GCP Cloud environment
- Linux Distro/Architecture: Ubuntu 22
- GPU Model/Driver: A100, 570.172.08
- CUDA: 12.8
- Method of RAPIDS install: pip install -e . from the main GitHub repo
pip list:
Package Version
archspec 0.2.5
boltons 24.0.0
Brotli 1.1.0
certifi 2025.1.31
cffi 1.17.1
charset-normalizer 3.4.1
colorama 0.4.6
conda 24.11.3
conda-libmamba-solver 24.9.0
conda-package-handling 2.4.0
conda_package_streaming 0.11.0
distro 1.9.0
frozendict 2.4.6
h2 4.2.0
hpack 4.1.0
hyperframe 6.1.0
idna 3.10
jsonpatch 1.33
jsonpointer 3.0.0
libmambapy 1.5.12
mamba 1.5.12
menuinst 2.2.0
packaging 24.2
pip 25.0.1
platformdirs 4.3.6
pluggy 1.5.0
pycosat 0.6.6
pycparser 2.22
PySocks 1.7.1
requests 2.32.3
ruamel.yaml 0.18.10
ruamel.yaml.clib 0.2.8
setuptools 75.8.2
tqdm 4.67.1
truststore 0.10.1
urllib3 2.3.0
wheel 0.45.1
zstandard 0.23.0
Additional context
Add any other context about the problem here.