@@ -427,8 +427,7 @@ mod csc_matrix_index_tests {
427427 I : Iterator < Item = usize > ,
428428 {
429429 let i = row_indices. collect :: < Vec < _ > > ( ) ;
430- let mut dm = DMatrix :: < i64 > :: zeros ( csc. nrows ( ) , csc. ncols ( ) ) ;
431- csc. triplet_iter ( ) . for_each ( |( r, c, v) | dm[ ( r, c) ] = * v) ;
430+ let dm = csc_to_dmat ( csc) ;
432431 CscMatrix :: from ( & dm. select_rows ( & i) )
433432 }
434433
@@ -437,9 +436,18 @@ mod csc_matrix_index_tests {
437436 I : Iterator < Item = usize > ,
438437 {
439438 let j = col_indices. collect :: < Vec < _ > > ( ) ;
439+ let dm = csc_to_dmat ( csc) ;
440+ CscMatrix :: from ( & dm. select_columns ( & j) )
441+ }
442+
443+ fn csc_to_dmat ( csc : & CscMatrix < i64 > ) -> DMatrix < i64 > {
440444 let mut dm = DMatrix :: < i64 > :: zeros ( csc. nrows ( ) , csc. ncols ( ) ) ;
441445 csc. triplet_iter ( ) . for_each ( |( r, c, v) | dm[ ( r, c) ] = * v) ;
442- CscMatrix :: from ( & dm. select_columns ( & j) )
446+ dm
447+ }
448+
449+ fn assert_csc_eq ( mat1 : CscMatrix < i64 > , mat2 : CscMatrix < i64 > ) {
450+ assert_eq ! ( csc_to_dmat( & mat1) , csc_to_dmat( & mat2) ) ;
443451 }
444452
445453 #[ test]
@@ -448,58 +456,60 @@ mod csc_matrix_index_tests {
448456 let m: usize = 200 ;
449457 let nnz: usize = 1000 ;
450458
451- let ridx = Array :: random ( 220 , Uniform :: new ( 0 , n) ) . to_vec ( ) ;
452- let cidx = Array :: random ( 100 , Uniform :: new ( 0 , m) ) . to_vec ( ) ;
453-
454- let row_indices = Array :: random ( nnz, Uniform :: new ( 0 , n) ) . to_vec ( ) ;
455- let col_indices = Array :: random ( nnz, Uniform :: new ( 0 , m) ) . to_vec ( ) ;
456- let values = Array :: random ( nnz, Uniform :: new ( -10000 , 10000 ) ) . to_vec ( ) ;
457-
458- let csc_matrix: CscMatrix < i64 > =
459- ( & CooMatrix :: try_from_triplets ( n, m, row_indices, col_indices, values) . unwrap ( ) ) . into ( ) ;
460-
461- // Row slice
462- assert_eq ! (
463- csc_matrix. select( s![ 2 ..177 , ..] . as_ref( ) ) ,
464- csc_select_rows( & csc_matrix, 2 ..177 ) ,
465- ) ;
466- assert_eq ! (
467- csc_matrix. select( s![ 0 ..2 , ..] . as_ref( ) ) ,
468- csc_select_rows( & csc_matrix, 0 ..2 ) ,
469- ) ;
470-
471- // Row fancy indexing
472- assert_eq ! (
473- csc_matrix. select( s![ & ridx, ..] . as_ref( ) ) ,
474- csc_select_rows( & csc_matrix, ridx. iter( ) . cloned( ) ) ,
475- ) ;
476-
477- // Column slice
478- assert_eq ! (
479- csc_matrix. select( s![ .., 77 ..200 ] . as_ref( ) ) ,
480- csc_select_cols( & csc_matrix, 77 ..200 ) ,
481- ) ;
482-
483- // Column fancy indexing
484- assert_eq ! (
485- csc_matrix. select( s![ .., & cidx] . as_ref( ) ) ,
486- csc_select_cols( & csc_matrix, cidx. iter( ) . cloned( ) ) ,
487- ) ;
488-
489- // Both
490- assert_eq ! (
491- csc_matrix. select( s![ 2 ..49 , 0 ..77 ] . as_ref( ) ) ,
492- csc_select( & csc_matrix, 2 ..49 , 0 ..77 ) ,
493- ) ;
494-
495- assert_eq ! (
496- csc_matrix. select( s![ 2 ..177 , & cidx] . as_ref( ) ) ,
497- csc_select( & csc_matrix, 2 ..177 , cidx. iter( ) . cloned( ) ) ,
498- ) ;
499-
500- assert_eq ! (
501- csc_matrix. select( s![ & ridx, & cidx] . as_ref( ) ) ,
502- csc_select( & csc_matrix, ridx. iter( ) . cloned( ) , cidx. iter( ) . cloned( ) ) ,
503- ) ;
459+ for _ in 0 ..50 {
460+ let ridx = Array :: random ( 220 , Uniform :: new ( 0 , n) ) . to_vec ( ) ;
461+ let cidx = Array :: random ( 100 , Uniform :: new ( 0 , m) ) . to_vec ( ) ;
462+
463+ let row_indices = Array :: random ( nnz, Uniform :: new ( 0 , n) ) . to_vec ( ) ;
464+ let col_indices = Array :: random ( nnz, Uniform :: new ( 0 , m) ) . to_vec ( ) ;
465+ let values = Array :: random ( nnz, Uniform :: new ( -10000 , 10000 ) ) . to_vec ( ) ;
466+
467+ let csc_matrix: CscMatrix < i64 > =
468+ ( & CooMatrix :: try_from_triplets ( n, m, row_indices, col_indices, values) . unwrap ( ) ) . into ( ) ;
469+
470+ // Row slice
471+ assert_csc_eq (
472+ csc_matrix. select ( s ! [ 2 ..177 , ..] . as_ref ( ) ) ,
473+ csc_select_rows ( & csc_matrix, 2 ..177 ) ,
474+ ) ;
475+ assert_csc_eq (
476+ csc_matrix. select ( s ! [ 0 ..2 , ..] . as_ref ( ) ) ,
477+ csc_select_rows ( & csc_matrix, 0 ..2 ) ,
478+ ) ;
479+
480+ // Row fancy indexing
481+ assert_csc_eq (
482+ csc_matrix. select ( s ! [ & ridx, ..] . as_ref ( ) ) ,
483+ csc_select_rows ( & csc_matrix, ridx. iter ( ) . cloned ( ) ) ,
484+ ) ;
485+
486+ // Column slice
487+ assert_csc_eq (
488+ csc_matrix. select ( s ! [ .., 77 ..200 ] . as_ref ( ) ) ,
489+ csc_select_cols ( & csc_matrix, 77 ..200 ) ,
490+ ) ;
491+
492+ // Column fancy indexing
493+ assert_csc_eq (
494+ csc_matrix. select ( s ! [ .., & cidx] . as_ref ( ) ) ,
495+ csc_select_cols ( & csc_matrix, cidx. iter ( ) . cloned ( ) ) ,
496+ ) ;
497+
498+ // Both
499+ assert_csc_eq (
500+ csc_matrix. select ( s ! [ 2 ..49 , 0 ..77 ] . as_ref ( ) ) ,
501+ csc_select ( & csc_matrix, 2 ..49 , 0 ..77 ) ,
502+ ) ;
503+
504+ assert_csc_eq (
505+ csc_matrix. select ( s ! [ 2 ..177 , & cidx] . as_ref ( ) ) ,
506+ csc_select ( & csc_matrix, 2 ..177 , cidx. iter ( ) . cloned ( ) ) ,
507+ ) ;
508+
509+ assert_csc_eq (
510+ csc_matrix. select ( s ! [ & ridx, & cidx] . as_ref ( ) ) ,
511+ csc_select ( & csc_matrix, ridx. iter ( ) . cloned ( ) , cidx. iter ( ) . cloned ( ) ) ,
512+ ) ;
513+ }
504514 }
505515}
0 commit comments