@@ -86,37 +86,19 @@ pub(super) fn to_array(ob: &Bound<'_, PyAny>) -> PyResult<DynArray> {
8686 Ok ( arr)
8787}
8888
89- pub ( super ) fn to_csr ( ob : & Bound < ' _ , PyAny > ) -> PyResult < DynCsrMatrix > {
90- fn extract_csr_indicies ( indicies : Bound < ' _ , PyAny > ) -> PyResult < Vec < usize > > {
91- let res = match indicies
92- . getattr ( "dtype" ) ?
93- . getattr ( "name" ) ?
94- . extract :: < & str > ( ) ?
95- {
96- "int32" => indicies
97- . extract :: < PyReadonlyArrayDyn < i32 > > ( ) ?
98- . as_array ( )
99- . iter ( )
100- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
101- . collect ( ) ,
102- "int64" => indicies
103- . extract :: < PyReadonlyArrayDyn < i64 > > ( ) ?
104- . as_array ( )
105- . iter ( )
106- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
107- . collect ( ) ,
108- other => panic ! ( "CSR indicies type '{}' is not supported" , other) ,
109- } ;
110- Ok ( res)
111- }
89+ fn extract_array_as_usize ( arr : Bound < ' _ , PyAny > ) -> PyResult < Vec < usize > > {
90+ arr. call_method1 ( "astype" , ( "uintp" , ) ) ?
91+ . extract :: < Vec < usize > > ( )
92+ }
11293
94+ pub ( super ) fn to_csr ( ob : & Bound < ' _ , PyAny > ) -> PyResult < DynCsrMatrix > {
11395 if !isinstance_of_csr ( ob) ? {
11496 return Err ( PyTypeError :: new_err ( "not a csr matrix" ) ) ;
11597 }
11698
11799 let shape: Vec < usize > = ob. getattr ( "shape" ) ?. extract ( ) ?;
118- let indices = extract_csr_indicies ( ob. getattr ( "indices" ) ?) ?;
119- let indptr = extract_csr_indicies ( ob. getattr ( "indptr" ) ?) ?;
100+ let indices = extract_array_as_usize ( ob. getattr ( "indices" ) ?) ?;
101+ let indptr = extract_array_as_usize ( ob. getattr ( "indptr" ) ?) ?;
120102 let ty = ob. getattr ( "data" ) ?. getattr ( "dtype" ) ?. getattr ( "name" ) ?;
121103 let ty = ty. extract :: < & str > ( ) ?;
122104
@@ -139,36 +121,13 @@ pub(super) fn to_csr(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrMatrix> {
139121}
140122
141123pub ( super ) fn to_csr_noncanonical ( ob : & Bound < ' _ , PyAny > ) -> PyResult < DynCsrNonCanonical > {
142- fn extract_csr_indicies ( indicies : Bound < ' _ , PyAny > ) -> PyResult < Vec < usize > > {
143- let res = match indicies
144- . getattr ( "dtype" ) ?
145- . getattr ( "name" ) ?
146- . extract :: < & str > ( ) ?
147- {
148- "int32" => indicies
149- . extract :: < PyReadonlyArrayDyn < i32 > > ( ) ?
150- . as_array ( )
151- . iter ( )
152- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
153- . collect ( ) ,
154- "int64" => indicies
155- . extract :: < PyReadonlyArrayDyn < i64 > > ( ) ?
156- . as_array ( )
157- . iter ( )
158- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
159- . collect ( ) ,
160- other => panic ! ( "CSR indicies type '{}' is not supported" , other) ,
161- } ;
162- Ok ( res)
163- }
164-
165124 if !isinstance_of_csr ( ob) ? {
166125 return Err ( PyTypeError :: new_err ( "not a csr matrix" ) ) ;
167126 }
168127
169128 let shape: Vec < usize > = ob. getattr ( "shape" ) ?. extract ( ) ?;
170- let indices = extract_csr_indicies ( ob. getattr ( "indices" ) ?) ?;
171- let indptr = extract_csr_indicies ( ob. getattr ( "indptr" ) ?) ?;
129+ let indices = extract_array_as_usize ( ob. getattr ( "indices" ) ?) ?;
130+ let indptr = extract_array_as_usize ( ob. getattr ( "indptr" ) ?) ?;
172131 let ty = ob. getattr ( "data" ) ?. getattr ( "dtype" ) ?. getattr ( "name" ) ?;
173132 let ty = ty. extract :: < & str > ( ) ?;
174133
@@ -190,36 +149,13 @@ pub(super) fn to_csr_noncanonical(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrNonCa
190149}
191150
192151pub ( super ) fn to_csc ( ob : & Bound < ' _ , PyAny > ) -> PyResult < DynCscMatrix > {
193- fn extract_csc_indicies ( indicies : Bound < ' _ , PyAny > ) -> PyResult < Vec < usize > > {
194- let res = match indicies
195- . getattr ( "dtype" ) ?
196- . getattr ( "name" ) ?
197- . extract :: < & str > ( ) ?
198- {
199- "int32" => indicies
200- . extract :: < PyReadonlyArrayDyn < i32 > > ( ) ?
201- . as_array ( )
202- . iter ( )
203- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
204- . collect ( ) ,
205- "int64" => indicies
206- . extract :: < PyReadonlyArrayDyn < i64 > > ( ) ?
207- . as_array ( )
208- . iter ( )
209- . map ( |x| ( * x) . try_into ( ) . unwrap ( ) )
210- . collect ( ) ,
211- other => panic ! ( "CSC indicies type '{}' is not supported" , other) ,
212- } ;
213- Ok ( res)
214- }
215-
216152 if !isinstance_of_csc ( ob) ? {
217153 return Err ( PyTypeError :: new_err ( "not a csc matrix" ) ) ;
218154 }
219155
220156 let shape: Vec < usize > = ob. getattr ( "shape" ) ?. extract ( ) ?;
221- let indices = extract_csc_indicies ( ob. getattr ( "indices" ) ?) ?;
222- let indptr = extract_csc_indicies ( ob. getattr ( "indptr" ) ?) ?;
157+ let indices = extract_array_as_usize ( ob. getattr ( "indices" ) ?) ?;
158+ let indptr = extract_array_as_usize ( ob. getattr ( "indptr" ) ?) ?;
223159 let ty = ob. getattr ( "data" ) ?. getattr ( "dtype" ) ?. getattr ( "name" ) ?;
224160 let ty = ty. extract :: < & str > ( ) ?;
225161
0 commit comments