Skip to content

Commit e83615a

Browse files
authored
Allow any index type when creating sparse matrix (#18)
1 parent 5b23d8f commit e83615a

5 files changed

Lines changed: 21 additions & 82 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ jobs:
1818
uses: Swatinem/rust-cache@v2
1919

2020
- name: Test Rust package
21-
run: |
22-
cd ${GITHUB_WORKSPACE}/anndata-hdf5 && cargo test --no-fail-fast
23-
cd ${GITHUB_WORKSPACE}/anndata && cargo test --no-fail-fast
24-
cd ${GITHUB_WORKSPACE}/anndata-test-utils && cargo test --no-fail-fast
21+
run: cargo test --no-fail-fast
2522

2623
#- name: benchmark
2724
# run: |

pyanndata/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ rayon = "1.11"
2828

2929
[dependencies.pyo3]
3030
version = "0.25"
31-
features = ["extension-module", "multiple-pymethods", "anyhow"]
31+
features = ["multiple-pymethods", "anyhow"]
32+
33+
[features]
34+
extension-module = ["pyo3/extension-module"]
3235

3336
[lib]
3437
crate-type = ["lib"]

pyanndata/src/data/array.rs

Lines changed: 11 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -86,37 +86,19 @@ pub(super) fn to_array(ob: &Bound<'_, PyAny>) -> PyResult<DynArray> {
8686
Ok(arr)
8787
}
8888

89-
pub(super) fn to_csr(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrMatrix> {
90-
fn extract_csr_indicies(indicies: Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
91-
let res = match indicies
92-
.getattr("dtype")?
93-
.getattr("name")?
94-
.extract::<&str>()?
95-
{
96-
"int32" => indicies
97-
.extract::<PyReadonlyArrayDyn<i32>>()?
98-
.as_array()
99-
.iter()
100-
.map(|x| (*x).try_into().unwrap())
101-
.collect(),
102-
"int64" => indicies
103-
.extract::<PyReadonlyArrayDyn<i64>>()?
104-
.as_array()
105-
.iter()
106-
.map(|x| (*x).try_into().unwrap())
107-
.collect(),
108-
other => panic!("CSR indicies type '{}' is not supported", other),
109-
};
110-
Ok(res)
111-
}
89+
fn extract_array_as_usize(arr: Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
90+
arr.call_method1("astype", ("uintp",))?
91+
.extract::<Vec<usize>>()
92+
}
11293

94+
pub(super) fn to_csr(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrMatrix> {
11395
if !isinstance_of_csr(ob)? {
11496
return Err(PyTypeError::new_err("not a csr matrix"));
11597
}
11698

11799
let shape: Vec<usize> = ob.getattr("shape")?.extract()?;
118-
let indices = extract_csr_indicies(ob.getattr("indices")?)?;
119-
let indptr = extract_csr_indicies(ob.getattr("indptr")?)?;
100+
let indices = extract_array_as_usize(ob.getattr("indices")?)?;
101+
let indptr = extract_array_as_usize(ob.getattr("indptr")?)?;
120102
let ty = ob.getattr("data")?.getattr("dtype")?.getattr("name")?;
121103
let ty = ty.extract::<&str>()?;
122104

@@ -139,36 +121,13 @@ pub(super) fn to_csr(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrMatrix> {
139121
}
140122

141123
pub(super) fn to_csr_noncanonical(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrNonCanonical> {
142-
fn extract_csr_indicies(indicies: Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
143-
let res = match indicies
144-
.getattr("dtype")?
145-
.getattr("name")?
146-
.extract::<&str>()?
147-
{
148-
"int32" => indicies
149-
.extract::<PyReadonlyArrayDyn<i32>>()?
150-
.as_array()
151-
.iter()
152-
.map(|x| (*x).try_into().unwrap())
153-
.collect(),
154-
"int64" => indicies
155-
.extract::<PyReadonlyArrayDyn<i64>>()?
156-
.as_array()
157-
.iter()
158-
.map(|x| (*x).try_into().unwrap())
159-
.collect(),
160-
other => panic!("CSR indicies type '{}' is not supported", other),
161-
};
162-
Ok(res)
163-
}
164-
165124
if !isinstance_of_csr(ob)? {
166125
return Err(PyTypeError::new_err("not a csr matrix"));
167126
}
168127

169128
let shape: Vec<usize> = ob.getattr("shape")?.extract()?;
170-
let indices = extract_csr_indicies(ob.getattr("indices")?)?;
171-
let indptr = extract_csr_indicies(ob.getattr("indptr")?)?;
129+
let indices = extract_array_as_usize(ob.getattr("indices")?)?;
130+
let indptr = extract_array_as_usize(ob.getattr("indptr")?)?;
172131
let ty = ob.getattr("data")?.getattr("dtype")?.getattr("name")?;
173132
let ty = ty.extract::<&str>()?;
174133

@@ -190,36 +149,13 @@ pub(super) fn to_csr_noncanonical(ob: &Bound<'_, PyAny>) -> PyResult<DynCsrNonCa
190149
}
191150

192151
pub(super) fn to_csc(ob: &Bound<'_, PyAny>) -> PyResult<DynCscMatrix> {
193-
fn extract_csc_indicies(indicies: Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
194-
let res = match indicies
195-
.getattr("dtype")?
196-
.getattr("name")?
197-
.extract::<&str>()?
198-
{
199-
"int32" => indicies
200-
.extract::<PyReadonlyArrayDyn<i32>>()?
201-
.as_array()
202-
.iter()
203-
.map(|x| (*x).try_into().unwrap())
204-
.collect(),
205-
"int64" => indicies
206-
.extract::<PyReadonlyArrayDyn<i64>>()?
207-
.as_array()
208-
.iter()
209-
.map(|x| (*x).try_into().unwrap())
210-
.collect(),
211-
other => panic!("CSC indicies type '{}' is not supported", other),
212-
};
213-
Ok(res)
214-
}
215-
216152
if !isinstance_of_csc(ob)? {
217153
return Err(PyTypeError::new_err("not a csc matrix"));
218154
}
219155

220156
let shape: Vec<usize> = ob.getattr("shape")?.extract()?;
221-
let indices = extract_csc_indicies(ob.getattr("indices")?)?;
222-
let indptr = extract_csc_indicies(ob.getattr("indptr")?)?;
157+
let indices = extract_array_as_usize(ob.getattr("indices")?)?;
158+
let indptr = extract_array_as_usize(ob.getattr("indptr")?)?;
223159
let ty = ob.getattr("data")?.getattr("dtype")?.getattr("name")?;
224160
let ty = ty.extract::<&str>()?;
225161

python/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ pyo3-log = "0.12"
1515

1616
[dependencies.pyo3]
1717
version = "0.25"
18-
features = ["extension-module", "multiple-pymethods"]
18+
features = ["multiple-pymethods"]
19+
20+
[features]
21+
extension-module = ["pyo3/extension-module", "pyanndata/extension-module"]
1922

2023
[lib]
2124
name = "anndata_rs"

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = ["maturin>=1.4,<2.0"]
33
build-backend = "maturin"
44

55
[tool.maturin]
6-
features = ["pyo3/extension-module"]
6+
features = ["extension-module"]
77
module-name = "anndata_rs"
88

99
[project]

0 commit comments

Comments
 (0)