11import math
22import random
3- import time
43import traceback
54import warnings
65
76import igraph as ig
87import leidenalg as la
9- import zarr
108import numpy as np
11- import pandas as pd
12- from scipy .sparse import coo_matrix , csr_matrix , lil_matrix , diags
9+ import zarr
10+ from scipy .sparse import coo_matrix , lil_matrix
1311from sklearn_ann .kneighbors .annoy import AnnoyTransformer
1412
1513from ._prune_features import run_pca_dim_reduction
2018# Hyperparameter Utils
2119########################################################################################################################################################
2220
21+
2322def get_subsamp_size (n ): # n==data.shape[0]
2423 """
2524 Generate subsample size.
@@ -91,27 +90,30 @@ def get_hyperparameters(k_range, la_res_range, metric=None):
9190 metric = ["euclidean" , "cosine" ][random .sample (range (2 ), 1 )[0 ]]
9291 return k , la_res , metric
9392
93+
9494########################################################################################################################################################
9595# Clustering Utils
9696########################################################################################################################################################
9797
98+
def sparse_put_clusters(n_orig, subsample_ids, cluster_values):
    """Create a sparse cluster matrix without using put_along_axis.

    Parameters
    ----------
    n_orig : int
        Number of rows of the returned matrix.
    subsample_ids : array-like of int
        Row index of each entry to set.
    cluster_values : array-like of int
        Column index (cluster label) of each entry, aligned element-wise
        with ``subsample_ids``.

    Returns
    -------
    c : :class:`scipy.sparse.coo_matrix`
        ``uint8`` matrix holding a 1 at every
        ``(subsample_ids[i], cluster_values[i])`` position.
    """
    labels = np.asarray(cluster_values)

    # One column per distinct label ...
    n_clusters = len(np.unique(labels))
    # ... but labels are used directly as column indices, so the matrix must
    # also be wide enough for the largest one. For contiguous labels 0..k-1
    # this changes nothing; for labels with gaps it avoids an out-of-bounds
    # column index error from coo_matrix.
    if labels.size:
        n_clusters = max(n_clusters, int(labels.max()) + 1)

    # Every (row, column) pair contributes a single 1.
    data = np.ones_like(subsample_ids, dtype=np.uint8)

    return coo_matrix(
        (data, (subsample_ids, cluster_values)), shape=(n_orig, n_clusters)
    )
114115
116+
115117# Util adapted from scanpy:
116118def get_igraph_from_adjacency (adjacency , directed = None ):
117119 """Get igraph graph from adjacency matrix."""
@@ -180,6 +182,7 @@ def run_la_clustering(X, k, la_res, metric="euclidean", method="sw-graph"):
180182 # print ("time to run leiden clustering: " + str(time_leiden))
181183 return np .array ([leiden_out .membership ])
182184
185+
def get_hard_soft_clusters(n, clustering, bg):
    """
    Generate hard and soft clusters for a single bipartite clustering.

    Parameters
    ----------
    n : int
        Number of rows in both outputs. Vertices of ``bg`` whose index is
        ``>= n`` are treated as cluster vertices.
    clustering : iterable
        One label per cluster vertex, in the order those vertices appear in
        ``bg.vs.indices``.
    bg : igraph graph
        Graph queried through ``bg.vs.indices`` and ``bg.es.select``
        (presumably the bipartite sample/cluster graph -- confirm against
        the caller).

    Returns
    -------
    hard_clusters : :class:`numpy.ndarray`
        Hard cluster assignments for every sample. Each value is a column
        index into ``np.unique(clustering)``; ties between equally large
        counts are broken with ``np.random.choice``, and rows without any
        count keep the value 0.
    soft_membership_matrix : :class:`scipy.sparse.csr_matrix`
        Contains membership values for each sample in each consensus cluster.
        Rows are the count rows divided by their sum; all-zero rows stay zero.
    """
    # The top-level scipy.sparse import of this file does not provide `diags`,
    # so it is imported locally.
    from scipy.sparse import diags

    # Identify cluster vertices
    vertex_ids = np.array(bg.vs.indices)
    clusters_vertex_ids = vertex_ids[vertex_ids >= n]

    # Sorted unique labels; the position of a label is its column index
    cells_clusts = np.unique(clustering)

    # LIL format is efficient for incremental construction
    clust_occ_mat = lil_matrix((n, len(cells_clusts)), dtype=int)

    for col_idx, cluster_id in enumerate(cells_clusts):
        # Vertices corresponding to this consensus cluster
        cluster_memb = [
            clusters_vertex_ids[i] for i, j in enumerate(clustering) if j == cluster_id
        ]

        # NOTE(review): edges are selected by source in `cluster_memb` (vertex
        # ids >= n) and `e.source` is then used as a row index into an n-row
        # matrix. Whether this yields sample rows depends on how `bg` orients
        # its edges -- confirm against the code that builds `bg`.
        edges = bg.es.select(_source_in=cluster_memb)

        if edges:
            sources = [e.source for e in edges]
            source_nodes, counts = np.unique(sources, return_counts=True)
            clust_occ_mat[source_nodes, col_idx] = counts

    # Convert to CSR format for efficient row operations
    clust_occ_csr = clust_occ_mat.tocsr()
    # Column indices must be ascending within each row so that tie candidates
    # are offered to np.random.choice in column order.
    clust_occ_csr.sort_indices()

    hard_clusters = np.zeros(n, dtype=int)
    indptr = clust_occ_csr.indptr
    indices = clust_occ_csr.indices
    data = clust_occ_csr.data

    # Read each row's stored entries straight from the CSR buffers instead of
    # densifying one row at a time.
    for i in range(n):
        start, stop = indptr[i], indptr[i + 1]
        if start == stop:
            continue  # no counts for this row: assignment stays 0
        row_vals = data[start:stop]
        row_max = row_vals.max()
        if row_max > 0:
            max_indices = indices[start:stop][row_vals == row_max]
            hard_clusters[i] = np.random.choice(max_indices)

    # Row sums for normalization; empty rows get a divisor of 1 to avoid
    # division by zero (they remain all-zero after scaling).
    row_sums = np.asarray(clust_occ_csr.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1

    # Left-multiplying by diag(1 / row_sum) scales every row to sum to 1
    soft_membership_matrix = diags(1.0 / row_sums, 0) @ clust_occ_csr

    return hard_clusters, soft_membership_matrix
259262
263+
260264########################################################################################################################################################
261265# Main clustering
262266########################################################################################################################################################
263267
268+
264269def run_base_clustering (args_in ):
265270 """
266271 Run a single iteration of leiden clustering.
@@ -269,7 +274,7 @@ def run_base_clustering(args_in):
269274 ----------
270275 args_in : zip
271276 List containing each hyperparameter required for one round of
272- clustering (k, la_res, metric, subsample_size) as well as the
277+ clustering (k, la_res, metric, subsample_size) as well as the
273278 sparse boolean and the path to the zarr data store.
274279
275280 Returns
@@ -284,7 +289,7 @@ def run_base_clustering(args_in):
284289 zarr_loc = args_in [0 ]
285290 hyperparams_ls = args_in [1 ]
286291 sparse = args_in [2 ]
287-
292+
288293 z1 = zarr .open (zarr_loc , mode = "r" )
289294
290295 if sparse :
0 commit comments