Skip to content

Commit 075dc9c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e0a9cd5 commit 075dc9c

6 files changed

Lines changed: 72 additions & 64 deletions

File tree

docs/notebooks/paul15_mouse_hematopoiesis.ipynb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1158,4 +1158,4 @@
11581158
},
11591159
"nbformat": 4,
11601160
"nbformat_minor": 0
1161-
}
1161+
}

src/eschr/tl/_clustering.py

Lines changed: 29 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,13 @@
11
import math
22
import random
3-
import time
43
import traceback
54
import warnings
65

76
import igraph as ig
87
import leidenalg as la
9-
import zarr
108
import numpy as np
11-
import pandas as pd
12-
from scipy.sparse import coo_matrix, csr_matrix, lil_matrix, diags
9+
import zarr
10+
from scipy.sparse import coo_matrix, lil_matrix
1311
from sklearn_ann.kneighbors.annoy import AnnoyTransformer
1412

1513
from ._prune_features import run_pca_dim_reduction
@@ -20,6 +18,7 @@
2018
# Hyperparameter Utils
2119
########################################################################################################################################################
2220

21+
2322
def get_subsamp_size(n): # n==data.shape[0]
2423
"""
2524
Generate subsample size.
@@ -91,27 +90,30 @@ def get_hyperparameters(k_range, la_res_range, metric=None):
9190
metric = ["euclidean", "cosine"][random.sample(range(2), 1)[0]]
9291
return k, la_res, metric
9392

93+
9494
########################################################################################################################################################
9595
# Clustering Utils
9696
########################################################################################################################################################
9797

98+
9899
def sparse_put_clusters(n_orig, subsample_ids, cluster_values):
99100
"""Create a sparse cluster matrix without using put_along_axis"""
100-
101+
101102
# Get number of clusters (accounting for zero as non-cluster)
102103
n_clusters = len(np.unique(cluster_values))
103-
104+
104105
# Create COO matrix directly from indices and values
105106
# For each data point in subsample_ids, create a 1 in its cluster column
106107
rows = subsample_ids
107108
cols = cluster_values
108109
data = np.ones_like(subsample_ids, dtype=np.uint8)
109-
110+
110111
# Create the sparse matrix
111112
c = coo_matrix((data, (rows, cols)), shape=(n_orig, n_clusters))
112-
113+
113114
return c
114115

116+
115117
# Util adapted from scanpy:
116118
def get_igraph_from_adjacency(adjacency, directed=None):
117119
"""Get igraph graph from adjacency matrix."""
@@ -180,6 +182,7 @@ def run_la_clustering(X, k, la_res, metric="euclidean", method="sw-graph"):
180182
# print ("time to run leiden clustering: " + str(time_leiden))
181183
return np.array([leiden_out.membership])
182184

185+
183186
def get_hard_soft_clusters(n, clustering, bg):
184187
"""
185188
Generate hard and soft clusters for a single bipartite clustering.
@@ -199,68 +202,70 @@ def get_hard_soft_clusters(n, clustering, bg):
199202
Hard cluster assignments for every sample.
200203
soft_membership_matrix : :class:`scipy.sparse.csr_matrix`
201204
Contains membership values for each sample in each consensus cluster.
202-
"""
205+
"""
203206
# Identify cluster vertices
204207
clusters_vertex_ids = np.array(bg.vs.indices)[[x >= n for x in bg.vs.indices]]
205208
# Get unique cluster assignments
206209
cells_clusts = np.unique(clustering)
207210
# Create mapping from cluster ID to column index
208211
clust_id_to_idx = {clust_id: idx for idx, clust_id in enumerate(cells_clusts)}
209-
212+
210213
# Initialize sparse matrix in LIL format (efficient for incremental construction)
211214
clust_occ_mat = lil_matrix((n, len(cells_clusts)), dtype=int)
212-
215+
213216
# Process each cluster
214217
for cluster_id in cells_clusts:
215218
# Get the vertices corresponding to this cluster
216219
cluster_memb = [
217220
clusters_vertex_ids[i] for i, j in enumerate(clustering) if j == cluster_id
218221
]
219-
222+
220223
# Get the edges from cells to this cluster
221224
edges = bg.es.select(_source_in=cluster_memb)
222-
225+
223226
if edges:
224227
# Get the source nodes and their counts
225228
sources = [e.source for e in edges]
226229
source_nodes, counts = np.unique(sources, return_counts=True)
227-
230+
228231
# Update the sparse matrix for this cluster
229232
col_idx = clust_id_to_idx[cluster_id]
230233
clust_occ_mat[source_nodes, col_idx] = counts
231-
234+
232235
# Convert to CSR format for efficient row operations
233236
clust_occ_csr = clust_occ_mat.tocsr()
234-
237+
235238
# Find the max value index for each row (for hard assignments)
236-
row_maxes = []
237239
hard_clusters = np.zeros(n, dtype=int)
238-
240+
239241
# Process each row to find max value index
240242
for i in range(n):
241243
row = clust_occ_csr[i].toarray().flatten()
242244
if np.any(row > 0): # Check if row has any non-zero values
243245
max_indices = np.where(row == row.max())[0]
244246
hard_clusters[i] = np.random.choice(max_indices)
245-
247+
246248
# Create the soft membership matrix (normalize rows)
247249
row_sums = clust_occ_csr.sum(axis=1).A.flatten()
248250
# Avoid division by zero
249251
row_sums[row_sums == 0] = 1
250-
252+
251253
# Create a diagonal matrix with 1/row_sum
252254
from scipy.sparse import diags
255+
253256
row_sum_diag_inv = diags(1.0 / row_sums, 0)
254-
257+
255258
# Multiply to normalize rows
256259
soft_membership_matrix = row_sum_diag_inv @ clust_occ_csr
257-
260+
258261
return hard_clusters, soft_membership_matrix
259262

263+
260264
########################################################################################################################################################
261265
# Main clustering
262266
########################################################################################################################################################
263267

268+
264269
def run_base_clustering(args_in):
265270
"""
266271
Run a single iteration of leiden clustering.
@@ -269,7 +274,7 @@ def run_base_clustering(args_in):
269274
----------
270275
args_in : zip
271276
List containing each hyperparameter required for one round of
272-
clustering (k, la_res, metric, subsample_size) as well as the
277+
clustering (k, la_res, metric, subsample_size) as well as the
273278
sparse boolean and the path to the zarr data store.
274279
275280
Returns
@@ -284,7 +289,7 @@ def run_base_clustering(args_in):
284289
zarr_loc = args_in[0]
285290
hyperparams_ls = args_in[1]
286291
sparse = args_in[2]
287-
292+
288293
z1 = zarr.open(zarr_loc, mode="r")
289294

290295
if sparse:

src/eschr/tl/_prune_features.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Feature selection and dimensionality reduction functions."""
2+
23
import time
34
import warnings
45
from typing import Optional

src/eschr/tl/_zarr_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
import os
2-
import traceback
3-
import warnings
2+
43
import zarr
54
from scipy.sparse import coo_matrix
65

6+
77
def make_zarr_sparse(adata, zarr_loc):
88
"""
99
Make zarr data store.

src/eschr/tl/main.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,20 @@
11
## Import packages=============================================================
2-
import math
32
import multiprocessing
43
import os
5-
import random
64
import time
7-
import traceback
85
import warnings
96
from itertools import repeat
107

118
import numpy as np
129
import pandas as pd
13-
import zarr
1410
from scipy.sparse import coo_matrix, csr_matrix, hstack
1511
from scipy.spatial.distance import pdist, squareform
1612
from sklearn import metrics
1713

18-
from ._zarr_utils import (make_zarr_sparse, make_zarr_dense)
19-
from ._clustering import (run_base_clustering, consensus_cluster_leiden)
14+
from ._clustering import consensus_cluster_leiden, run_base_clustering
2015
from ._prune_features import ( # ADD BACK PRECEDING DOTS
2116
calc_highly_variable_genes, calc_pca)
17+
from ._zarr_utils import make_zarr_dense, make_zarr_sparse
2218

2319
## End Import packages section=================================================
2420

@@ -182,7 +178,7 @@ def ensemble(
182178
out = parmap(run_base_clustering, args, nprocs=nprocs)
183179

184180
try:
185-
clust_out = hstack(out)
181+
clust_out = hstack(out)
186182
except Exception:
187183
print(
188184
"consensus_cluster.py, line 599, in ensemble: clust_out = hstack(out[:,0])"
@@ -250,7 +246,11 @@ def consensus(n, bg, nprocs):
250246
finish_time = time.perf_counter()
251247
print(f"Consensus clustering finished in {finish_time-start_time} seconds")
252248

253-
return hard_clusters, soft_membership_matrix, all_clusterings_df.to_numpy(dtype=np.uint16)
249+
return (
250+
hard_clusters,
251+
soft_membership_matrix,
252+
all_clusterings_df.to_numpy(dtype=np.uint16),
253+
)
254254

255255

256256
def consensus_cluster(
@@ -346,7 +346,7 @@ def consensus_cluster(
346346
la_res_range = (
347347
int(la_res_range[0]),
348348
int(la_res_range[1]),
349-
)
349+
)
350350
bipartite = ensemble(
351351
zarr_loc=zarr_loc,
352352
reduction=reduction,

tests/test_eschr.py

Lines changed: 30 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ def test_make_zarr_custom_path(adata, zarr_loc):
7272
assert os.path.exists(zarr_loc)
7373
shutil.rmtree(zarr_loc)
7474

75+
7576
@pytest.mark.skip(reason="Update to be testing zarr dense data structure")
7677
def test_make_zarr_content(adata, zarr_loc):
7778
es.tl._zarr_utils.make_zarr_dense(adata, zarr_loc)
@@ -242,7 +243,7 @@ def test_get_hard_soft_clusters_single_cluster(setup_data):
242243
)
243244

244245
soft_membership_matrix = soft_membership_matrix.toarray()
245-
246+
246247
# Test hard cluster assignments
247248
assert np.all(hard_clusters == 0)
248249

@@ -261,14 +262,15 @@ def test_consensus_cluster_leiden(bipartite_graph_array):
261262
resolution,
262263
) = es.tl._clustering.consensus_cluster_leiden(in_args)
263264

264-
#assert isinstance(hard_clusters, pd.Categorical)
265+
# assert isinstance(hard_clusters, pd.Categorical)
265266
assert len(hard_clusters) == n
266267
assert isinstance(soft_membership_matrix, csr_matrix)
267268
assert soft_membership_matrix.shape[0] == n
268269
assert soft_membership_matrix.shape[1] >= np.unique(hard_clusters).shape[0]
269270
assert np.allclose(soft_membership_matrix.sum(axis=1), 1.0)
270271
assert resolution == 1.0
271272

273+
272274
# Test ensemble function
273275
@pytest.fixture
274276
def ensemble_args(zarr_loc_static):
@@ -280,83 +282,83 @@ def ensemble_args(zarr_loc_static):
280282
"k_range": (15, 150),
281283
"la_res_range": (25, 175),
282284
"nprocs": 1,
283-
"sparse": False
285+
"sparse": False,
284286
}
285287

288+
286289
def test_ensemble(ensemble_args):
287290
result = es.tl.main.ensemble(**ensemble_args)
288291
assert isinstance(result, coo_matrix)
289-
292+
290293
# The shape should be (n_cells, n_clusters_total)
291294
z1 = zarr.open(ensemble_args["zarr_loc"], mode="r")
292295
n_cells = z1["X"].shape[0]
293296
assert result.shape[0] == n_cells
294-
297+
295298
# There should be at least one cluster for each member in the ensemble
296299
assert result.shape[1] >= 3
297-
300+
301+
298302
# Test consensus function
299303
@pytest.fixture
300304
def consensus_args(bipartite_graph_array):
301305
n = np.max(bipartite_graph_array.row) + 1
302-
return {
303-
"n": n,
304-
"bg": bipartite_graph_array,
305-
"nprocs": 1
306-
}
306+
return {"n": n, "bg": bipartite_graph_array, "nprocs": 1}
307+
307308

308309
def test_consensus(consensus_args):
309-
hard_clusters, soft_membership_matrix, all_clusterings = es.tl.main.consensus(**consensus_args)
310-
310+
hard_clusters, soft_membership_matrix, all_clusterings = es.tl.main.consensus(
311+
**consensus_args
312+
)
313+
311314
# Check hard clusters
312315
assert len(hard_clusters) == consensus_args["n"]
313316
assert isinstance(hard_clusters, np.ndarray)
314-
317+
315318
# Check soft membership matrix
316319
assert soft_membership_matrix.shape[0] == consensus_args["n"]
317320
assert np.allclose(soft_membership_matrix.sum(axis=1), 1.0)
318-
321+
319322
# Check all_clusterings
320323
assert isinstance(all_clusterings, np.ndarray)
321324
assert all_clusterings.shape[0] == consensus_args["n"]
322325
# Should have multiple resolutions tested
323326
assert all_clusterings.shape[1] > 1
324-
327+
328+
325329
# Test main consensus_cluster function
326330
def test_consensus_cluster_basic(adata, zarr_loc):
327-
331+
328332
# Run the full pipeline with minimal parameters
329333
result_adata = es.tl.consensus_cluster(
330-
adata,
331-
zarr_loc=zarr_loc,
332-
ensemble_size=3, # Small for testing
333-
nprocs=1
334+
adata, zarr_loc=zarr_loc, ensemble_size=3, nprocs=1 # Small for testing
334335
)
335-
336+
336337
# Check that results are added to adata object
337338
assert "hard_clusters" in result_adata.obs
338339
assert "soft_membership_matrix" in result_adata.obsm
339340
assert "uncertainty_score" in result_adata.obs
340341
assert "bipartite" in result_adata.obsm
341-
342+
342343
# Check shapes
343344
assert len(result_adata.obs["hard_clusters"]) == adata.shape[0]
344345
assert result_adata.obsm["soft_membership_matrix"].shape[0] == adata.shape[0]
345-
346+
346347
# Check that multiresolution results are not included by default
347348
assert "multiresolution_clusters" not in result_adata.obsm
348349

350+
349351
def test_consensus_cluster_with_multires(adata, zarr_loc):
350-
352+
351353
# Run with return_multires=True
352354
result_adata = es.tl.consensus_cluster(
353-
adata,
355+
adata,
354356
zarr_loc=zarr_loc,
355357
ensemble_size=3, # Small for testing
356358
nprocs=1,
357-
return_multires=True
359+
return_multires=True,
358360
)
359-
361+
360362
# Check that multiresolution results are included
361363
assert "multiresolution_clusters" in result_adata.obsm
362364

0 commit comments

Comments
 (0)