Skip to content

Commit 541b240

Browse files
committed
fix random test fail
1 parent e83615a commit 541b240

3 files changed

Lines changed: 95 additions & 57 deletions

File tree

anndata/src/data/array/sparse/csc.rs

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,7 @@ mod csc_matrix_index_tests {
427427
I: Iterator<Item = usize>,
428428
{
429429
let i = row_indices.collect::<Vec<_>>();
430-
let mut dm = DMatrix::<i64>::zeros(csc.nrows(), csc.ncols());
431-
csc.triplet_iter().for_each(|(r, c, v)| dm[(r, c)] = *v);
430+
let dm = csc_to_dmat(csc);
432431
CscMatrix::from(&dm.select_rows(&i))
433432
}
434433

@@ -437,9 +436,18 @@ mod csc_matrix_index_tests {
437436
I: Iterator<Item = usize>,
438437
{
439438
let j = col_indices.collect::<Vec<_>>();
439+
let dm = csc_to_dmat(csc);
440+
CscMatrix::from(&dm.select_columns(&j))
441+
}
442+
443+
fn csc_to_dmat(csc: &CscMatrix<i64>) -> DMatrix<i64> {
440444
let mut dm = DMatrix::<i64>::zeros(csc.nrows(), csc.ncols());
441445
csc.triplet_iter().for_each(|(r, c, v)| dm[(r, c)] = *v);
442-
CscMatrix::from(&dm.select_columns(&j))
446+
dm
447+
}
448+
449+
fn assert_csc_eq(mat1: CscMatrix<i64>, mat2: CscMatrix<i64>) {
450+
assert_eq!(csc_to_dmat(&mat1), csc_to_dmat(&mat2));
443451
}
444452

445453
#[test]
@@ -448,58 +456,60 @@ mod csc_matrix_index_tests {
448456
let m: usize = 200;
449457
let nnz: usize = 1000;
450458

451-
let ridx = Array::random(220, Uniform::new(0, n)).to_vec();
452-
let cidx = Array::random(100, Uniform::new(0, m)).to_vec();
453-
454-
let row_indices = Array::random(nnz, Uniform::new(0, n)).to_vec();
455-
let col_indices = Array::random(nnz, Uniform::new(0, m)).to_vec();
456-
let values = Array::random(nnz, Uniform::new(-10000, 10000)).to_vec();
457-
458-
let csc_matrix: CscMatrix<i64> =
459-
(&CooMatrix::try_from_triplets(n, m, row_indices, col_indices, values).unwrap()).into();
460-
461-
// Row slice
462-
assert_eq!(
463-
csc_matrix.select(s![2..177, ..].as_ref()),
464-
csc_select_rows(&csc_matrix, 2..177),
465-
);
466-
assert_eq!(
467-
csc_matrix.select(s![0..2, ..].as_ref()),
468-
csc_select_rows(&csc_matrix, 0..2),
469-
);
470-
471-
// Row fancy indexing
472-
assert_eq!(
473-
csc_matrix.select(s![&ridx, ..].as_ref()),
474-
csc_select_rows(&csc_matrix, ridx.iter().cloned()),
475-
);
476-
477-
// Column slice
478-
assert_eq!(
479-
csc_matrix.select(s![.., 77..200].as_ref()),
480-
csc_select_cols(&csc_matrix, 77..200),
481-
);
482-
483-
// Column fancy indexing
484-
assert_eq!(
485-
csc_matrix.select(s![.., &cidx].as_ref()),
486-
csc_select_cols(&csc_matrix, cidx.iter().cloned()),
487-
);
488-
489-
// Both
490-
assert_eq!(
491-
csc_matrix.select(s![2..49, 0..77].as_ref()),
492-
csc_select(&csc_matrix, 2..49, 0..77),
493-
);
494-
495-
assert_eq!(
496-
csc_matrix.select(s![2..177, &cidx].as_ref()),
497-
csc_select(&csc_matrix, 2..177, cidx.iter().cloned()),
498-
);
499-
500-
assert_eq!(
501-
csc_matrix.select(s![&ridx, &cidx].as_ref()),
502-
csc_select(&csc_matrix, ridx.iter().cloned(), cidx.iter().cloned()),
503-
);
459+
for _ in 0..50 {
460+
let ridx = Array::random(220, Uniform::new(0, n)).to_vec();
461+
let cidx = Array::random(100, Uniform::new(0, m)).to_vec();
462+
463+
let row_indices = Array::random(nnz, Uniform::new(0, n)).to_vec();
464+
let col_indices = Array::random(nnz, Uniform::new(0, m)).to_vec();
465+
let values = Array::random(nnz, Uniform::new(-10000, 10000)).to_vec();
466+
467+
let csc_matrix: CscMatrix<i64> =
468+
(&CooMatrix::try_from_triplets(n, m, row_indices, col_indices, values).unwrap()).into();
469+
470+
// Row slice
471+
assert_csc_eq(
472+
csc_matrix.select(s![2..177, ..].as_ref()),
473+
csc_select_rows(&csc_matrix, 2..177),
474+
);
475+
assert_csc_eq(
476+
csc_matrix.select(s![0..2, ..].as_ref()),
477+
csc_select_rows(&csc_matrix, 0..2),
478+
);
479+
480+
// Row fancy indexing
481+
assert_csc_eq(
482+
csc_matrix.select(s![&ridx, ..].as_ref()),
483+
csc_select_rows(&csc_matrix, ridx.iter().cloned()),
484+
);
485+
486+
// Column slice
487+
assert_csc_eq(
488+
csc_matrix.select(s![.., 77..200].as_ref()),
489+
csc_select_cols(&csc_matrix, 77..200),
490+
);
491+
492+
// Column fancy indexing
493+
assert_csc_eq(
494+
csc_matrix.select(s![.., &cidx].as_ref()),
495+
csc_select_cols(&csc_matrix, cidx.iter().cloned()),
496+
);
497+
498+
// Both
499+
assert_csc_eq(
500+
csc_matrix.select(s![2..49, 0..77].as_ref()),
501+
csc_select(&csc_matrix, 2..49, 0..77),
502+
);
503+
504+
assert_csc_eq(
505+
csc_matrix.select(s![2..177, &cidx].as_ref()),
506+
csc_select(&csc_matrix, 2..177, cidx.iter().cloned()),
507+
);
508+
509+
assert_csc_eq(
510+
csc_matrix.select(s![&ridx, &cidx].as_ref()),
511+
csc_select(&csc_matrix, ridx.iter().cloned(), cidx.iter().cloned()),
512+
);
513+
}
504514
}
505515
}

anndata/src/data/array/sparse/csr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ mod csr_matrix_index_tests {
545545

546546
#[test]
547547
fn test_csr() {
548-
for _ in 0..100 {
548+
for _ in 0..50 {
549549
let n: usize = 200;
550550
let m: usize = 200;
551551
let nnz: usize = 1000;

pyanndata/src/data/slice.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@ pub fn to_select_info(ob: &Bound<'_, PyAny>, shape: &Shape) -> PyResult<SelectIn
1616
}
1717
}
1818

19+
/// Converts a Python object into a `SelectInfoElem` for array indexing.
20+
///
21+
/// This function handles multiple Python indexing types and converts them to a unified
22+
/// `SelectInfoElem` representation:
23+
///
24+
/// - **PySlice**: Converts Python slice objects (e.g., `1:5:2`) into a `Slice` with
25+
/// start, stop, and step values.
26+
/// - **None**: Represents a full selection (equivalent to `:`).
27+
/// - **PyInt**: Extracts a single integer index.
28+
/// - **Boolean NumPy array**: Validates the array has dtype `bool` and matches the
29+
/// expected length, then converts to indices via `boolean_mask_to_indices`.
30+
/// - **Iterable of booleans**: Attempts to extract as a Python iterable of booleans.
31+
/// If successful and length matches, converts to indices. If length is 0, returns
32+
/// empty selection. If length mismatches, panics.
33+
/// - **Iterable of integers**: Falls back to extracting as a vector of `usize` indices.
34+
///
35+
/// # Arguments
36+
///
37+
/// * `ob` - A Python object to be interpreted as an index or mask.
38+
/// * `length` - The expected length of the dimension being indexed.
39+
///
40+
/// # Panics
41+
///
42+
/// Panics if a boolean mask's length does not match the expected `length`.
43+
///
44+
/// # Returns
45+
///
46+
/// A `SelectInfoElem` representing the selection operation.
1947
pub fn to_select_elem(ob: &Bound<'_, PyAny>, length: usize) -> PyResult<SelectInfoElem> {
2048
let select = if let Ok(slice) = ob.downcast::<pyo3::types::PySlice>() {
2149
let s = slice.indices(length.try_into()?)?;

0 commit comments

Comments
 (0)