Skip to content

Commit 680c39c

Browse files
authored
add ensemble and consensus tests
1 parent 88c66dd commit 680c39c

1 file changed

Lines changed: 49 additions & 0 deletions

File tree

tests/test_eschr.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,55 @@ def test_consensus_cluster_leiden(bipartite_graph_array):
266266
assert np.allclose(soft_membership_matrix.sum(axis=1), 1.0)
267267
assert resolution == 1.0
268268

269+
# Test ensemble function
270+
@pytest.fixture
271+
def ensemble_args(zarr_loc_static):
272+
return {
273+
"zarr_loc": zarr_loc_static,
274+
"ensemble_size": 3, # Small size for testing
275+
"nprocs": 1,
276+
"sparse": False
277+
}
278+
279+
def test_ensemble(ensemble_args):
280+
result = es.tl.main.ensemble(**ensemble_args)
281+
assert isinstance(result, coo_matrix)
282+
283+
# The shape should be (n_cells, n_clusters_total)
284+
z1 = zarr.open(ensemble_args["zarr_loc"], mode="r")
285+
n_cells = z1["X"].shape[0]
286+
assert result.shape[0] == n_cells
287+
288+
# There should be at least one cluster for each member in the ensemble
289+
assert result.shape[1] >= 3
290+
291+
# Test consensus function
292+
@pytest.fixture
293+
def consensus_args(bipartite_graph_array):
294+
n = np.max(bipartite_graph_array.row) + 1
295+
return {
296+
"n": n,
297+
"bg": bipartite_graph_array,
298+
"nprocs": 1
299+
}
300+
301+
def test_consensus(consensus_args):
302+
hard_clusters, soft_membership_matrix, all_clusterings = es.tl.main.consensus(**consensus_args)
303+
304+
# Check hard clusters
305+
assert len(hard_clusters) == consensus_args["n"]
306+
assert isinstance(hard_clusters, np.ndarray)
307+
308+
# Check soft membership matrix
309+
assert soft_membership_matrix.shape[0] == consensus_args["n"]
310+
assert np.allclose(soft_membership_matrix.sum(axis=1), 1.0)
311+
312+
# Check all_clusterings
313+
assert isinstance(all_clusterings, np.ndarray)
314+
assert all_clusterings.shape[0] == consensus_args["n"]
315+
# Should have multiple resolutions tested
316+
assert all_clusterings_df.shape[1] > 1
317+
269318
# Test main consensus_cluster function
270319
def test_consensus_cluster_basic(adata, zarr_loc):
271320

0 commit comments

Comments
 (0)