1+ from __future__ import annotations
2+
13import warnings
4+
25warnings .filterwarnings ("ignore" )
3- import cupy as cp
4- import numpy as np
5- import anndata as ad
6- from pathlib import Path
76import os
87import time
8+ from pathlib import Path
9+
10+ import anndata as ad
11+ import cupy as cp
12+ import numpy as np
13+
914import rapids_singlecell as rsc
1015
1116# Add utils to path for GPU version
@@ -18,64 +23,71 @@ def compare_indices(adata_cpu):
1823 Compare the saturated/unsaturated indices between CPU and GPU versions.
1924 """
2025 print ("🔍 Comparing CPU vs GPU index computation..." )
21-
26+
2227 # Import the helper functions
2328 from utils .sepal_cpu import _compute_idxs
2429 from utils .sepal_gpu import _compute_idxs_gpu
25-
30+
2631 # Get connectivity and spatial data
2732 adata_gpu = ad .read_h5ad (HOME / "data/visium_hne_adata.h5ad" )
2833 rsc .get .anndata_to_GPU (adata_gpu , convert_all = True )
29- g = adata_cpu .obsp [' spatial_connectivities' ]
30- g_gpu = adata_gpu .obsp [' spatial_connectivities' ]
34+ g = adata_cpu .obsp [" spatial_connectivities" ]
35+ g_gpu = adata_gpu .obsp [" spatial_connectivities" ]
3136 degrees = cp .diff (g_gpu .indptr )
32- spatial_cpu = adata_cpu .obsm [' spatial' ].astype (np .float64 )
33- spatial_gpu = adata_gpu .obsm [' spatial' ].astype (cp .float32 )
34-
37+ spatial_cpu = adata_cpu .obsm [" spatial" ].astype (np .float64 )
38+ spatial_gpu = adata_gpu .obsm [" spatial" ].astype (cp .float32 )
39+
3540 g_gpu = rsc .get .X_to_GPU (g_gpu )
3641 spatial_gpu = rsc .get .X_to_GPU (spatial_gpu )
3742
3843 # Compute indices with both methods
3944 start = time .time ()
4045 degrees = cp .diff (g_gpu .indptr )
41- sat_gpu , sat_idx_gpu , unsat_gpu , unsat_idx_gpu = _compute_idxs_gpu (g_gpu , degrees , spatial_gpu , 6 )
46+ sat_gpu , sat_idx_gpu , unsat_gpu , unsat_idx_gpu = _compute_idxs_gpu (
47+ g_gpu , degrees , spatial_gpu , 6
48+ )
4249 end = time .time ()
4350 print ("GPU indices computed in " , end - start , "seconds" )
44-
51+
4552 start = time .time ()
46- sat_cpu , sat_idx_cpu , unsat_cpu , unsat_idx_cpu = _compute_idxs (g , spatial_cpu , 6 , "l1" )
53+ sat_cpu , sat_idx_cpu , unsat_cpu , unsat_idx_cpu = _compute_idxs (
54+ g , spatial_cpu , 6 , "l1"
55+ )
4756 end = time .time ()
4857 print ("CPU indices computed in " , end - start , "seconds" )
4958 # Convert GPU results to CPU for comparison
5059 sat_gpu_cpu = sat_gpu .get ()
5160 sat_idx_gpu_cpu = sat_idx_gpu .get ()
5261 unsat_gpu_cpu = unsat_gpu .get ()
5362 unsat_idx_gpu_cpu = unsat_idx_gpu .get ()
54-
63+
5564 print (f"Saturated nodes - CPU: { len (sat_cpu )} , GPU: { len (sat_gpu_cpu )} " )
5665 print (f"Saturated nodes identical: { np .array_equal (sat_cpu , sat_gpu_cpu )} " )
57-
58- print (f"Unsaturated nodes - CPU: { len (unsat_cpu )} , GPU: { len (unsat_gpu_cpu )} " )
66+
67+ print (f"Unsaturated nodes - CPU: { len (unsat_cpu )} , GPU: { len (unsat_gpu_cpu )} " )
5968 print (f"Unsaturated nodes identical: { np .array_equal (unsat_cpu , unsat_gpu_cpu )} " )
60-
61- print (f"Saturated indices identical: { np .array_equal (sat_idx_cpu , sat_idx_gpu_cpu )} " )
62-
69+
70+ print (
71+ f"Saturated indices identical: { np .array_equal (sat_idx_cpu , sat_idx_gpu_cpu )} "
72+ )
73+
6374 # Check unsat_idx differences (these might differ due to tie-breaking)
6475 unsat_idx_diff = np .sum (unsat_idx_cpu != unsat_idx_gpu_cpu )
65- print (f"Unsaturated index differences: { unsat_idx_diff } /{ len (unsat_idx_cpu )} ({ 100 * unsat_idx_diff / len (unsat_idx_cpu ):.1f} %)" )
66-
76+ print (
77+ f"Unsaturated index differences: { unsat_idx_diff } /{ len (unsat_idx_cpu )} ({ 100 * unsat_idx_diff / len (unsat_idx_cpu ):.1f} %)"
78+ )
79+
6780 return {
68- ' sat_identical' : np .array_equal (sat_cpu , sat_gpu_cpu ),
69- ' unsat_identical' : np .array_equal (unsat_cpu , unsat_gpu_cpu ),
70- ' sat_idx_identical' : np .array_equal (sat_idx_cpu , sat_idx_gpu_cpu ),
71- ' unsat_idx_diff_count' : unsat_idx_diff ,
72- ' unsat_idx_diff_percent' : 100 * unsat_idx_diff / len (unsat_idx_cpu )
81+ " sat_identical" : np .array_equal (sat_cpu , sat_gpu_cpu ),
82+ " unsat_identical" : np .array_equal (unsat_cpu , unsat_gpu_cpu ),
83+ " sat_idx_identical" : np .array_equal (sat_idx_cpu , sat_idx_gpu_cpu ),
84+ " unsat_idx_diff_count" : unsat_idx_diff ,
85+ " unsat_idx_diff_percent" : 100 * unsat_idx_diff / len (unsat_idx_cpu ),
7386 }
7487
88+
7589if __name__ == "__main__" :
7690 # Run comparison
7791 adata_cpu = ad .read_h5ad (HOME / "data/visium_hne_adata.h5ad" )
7892 res = compare_indices (adata_cpu )
7993 print (res )
80-
81-
0 commit comments