Skip to content

Commit 7d0952a

Browse files
adapt ms2pip features to handle all input arrays and then convert
1 parent 1033f2b commit 7d0952a

1 file changed

Lines changed: 57 additions & 31 deletions

File tree

src/ms2pip_features.rs

Lines changed: 57 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,22 @@
11
// src/ms2pip_features.rs
22
//
3-
// MS2PIP feature calculation in Rust (batch + NumPy inputs), with memory-focused optimisations:
4-
// - Chunked (blocked) copying from NumPy to cap peak memory.
5-
// - In-place sorting for quantiles (no clone+sort).
6-
// - Avoid concatenating "all-ion" vectors for Pearson/MSE/Dot/Cos where possible.
3+
// MS2PIP feature calculation in Rust (batch + flexible NumPy inputs):
4+
//
5+
// - Minimize peak memory usage by processing in blocks, and keeping f32 arrays only
6+
// while holding the GIL.
7+
// - Use Rayon for parallelism outside the GIL.
8+
// - Support arbitrary array-like inputs (lists, different dtypes, non-contiguous arrays)
9+
// by converting to contiguous np.ndarray(float32) once per input spectrum.
710

811
use std::collections::HashMap;
912

1013
use numpy::{PyArray1, PyArrayMethods};
1114
use pyo3::exceptions::PyValueError;
1215
use pyo3::prelude::*;
16+
use pyo3::types::{PyAny, PyModule};
17+
1318
use rayon::prelude::*;
1419

15-
/// Clip lower bound in log2 space: log2(0.001)
1620
const CLIP_LOG2_MIN: f64 = -9.965_784_284_662_087; // (0.001_f64).log2()
1721

1822
#[inline]
@@ -33,7 +37,33 @@ fn pow2_unlog(x: f64) -> f64 {
3337

3438
#[inline]
3539
fn finite_or_zero(x: f64) -> f64 {
36-
if x.is_finite() { x } else { 0.0 }
40+
if x.is_finite() {
41+
x
42+
} else {
43+
0.0
44+
}
45+
}
46+
47+
48+
#[inline]
49+
fn any_to_vec_f32<'py>(
50+
np: &'py Bound<'py, PyModule>,
51+
obj: &Bound<'py, PyAny>,
52+
) -> PyResult<Vec<f32>> {
53+
// np.ascontiguousarray(obj, dtype="float32")
54+
let arr_any = np
55+
.getattr("ascontiguousarray")?
56+
.call1((obj, "float32"))?;
57+
58+
let arr = arr_any.extract::<&PyArray1<f32>>()?;
59+
let ro = arr.readonly();
60+
61+
// With contiguity enforced, as_slice() should usually work; keep safe fallback.
62+
if let Ok(slice) = ro.as_slice() {
63+
Ok(slice.to_vec())
64+
} else {
65+
Ok(ro.as_array().iter().copied().collect())
66+
}
3767
}
3868

3969
fn pearson(x: &[f64], y: &[f64]) -> f64 {
@@ -117,7 +147,9 @@ fn dot2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 {
117147
}
118148
let d1 = dot(a1, b1);
119149
let d2 = dot(a2, b2);
120-
if !d1.is_finite() || !d2.is_finite() { return f64::NAN; }
150+
if !d1.is_finite() || !d2.is_finite() {
151+
return f64::NAN;
152+
}
121153
d1 + d2
122154
}
123155

@@ -246,7 +278,6 @@ fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
246278
if n == 1 {
247279
return sorted[0];
248280
}
249-
// numpy default: linear interpolation on (n-1)*q
250281
let pos = (n as f64 - 1.0) * q;
251282
let lo = pos.floor() as usize;
252283
let hi = pos.ceil() as usize;
@@ -261,7 +292,7 @@ fn ranks_average_ties(values: &[f64]) -> Vec<f64> {
261292
let n = values.len();
262293
let mut idx: Vec<usize> = (0..n).collect();
263294

264-
// Deterministic ordering (handles NaN consistently). If NaNs exist, caller should typically return NaN.
295+
// Deterministic ordering
265296
idx.sort_by(|&i, &j| values[i].total_cmp(&values[j]));
266297

267298
let mut ranks = vec![0.0; n];
@@ -287,7 +318,6 @@ fn spearman(x: &[f64], y: &[f64]) -> f64 {
287318
if x.len() != y.len() || x.len() < 2 {
288319
return f64::NAN;
289320
}
290-
// pandas rank/corr will propagate NaNs; we emulate that by returning NaN if any non-finite
291321
if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
292322
return f64::NAN;
293323
}
@@ -300,10 +330,10 @@ fn spearman(x: &[f64], y: &[f64]) -> f64 {
300330
pub fn batch_ms2pip_features_numpy(
301331
py: Python<'_>,
302332
psm_indices: Vec<usize>,
303-
predicted_b: Vec<Py<PyArray1<f32>>>,
304-
predicted_y: Vec<Py<PyArray1<f32>>>,
305-
observed_b: Vec<Py<PyArray1<f32>>>,
306-
observed_y: Vec<Py<PyArray1<f32>>>,
333+
predicted_b: Vec<Py<PyAny>>,
334+
predicted_y: Vec<Py<PyAny>>,
335+
observed_b: Vec<Py<PyAny>>,
336+
observed_y: Vec<Py<PyAny>>,
307337
) -> PyResult<Vec<(usize, HashMap<String, f64>)>> {
308338
let n = psm_indices.len();
309339
if predicted_b.len() != n
@@ -325,28 +355,29 @@ pub fn batch_ms2pip_features_numpy(
325355
oy: Vec<f32>,
326356
}
327357

328-
// Main output: keep capacity to avoid reallocations.
329358
let mut out: Vec<(usize, HashMap<String, f64>)> = Vec::with_capacity(n);
330359

331-
// Chunking keeps peak memory bounded. Tune as needed.
360+
// Tune for peak memory vs overhead.
332361
let block_size: usize = 4096;
333362

363+
// Import numpy once per call.
364+
let np = PyModule::import_bound(py, "numpy")?;
365+
334366
for start in (0..n).step_by(block_size) {
335367
let end = (start + block_size).min(n);
336368

337-
// ---- Copy out of NumPy while holding the GIL (only this block) ----
369+
// ---- Convert/copy while holding the GIL (only this block) ----
338370
let mut owned: Vec<Owned> = Vec::with_capacity(end - start);
339371
for i in start..end {
340-
let pb = predicted_b[i].bind(py);
341-
let pyv = predicted_y[i].bind(py);
342-
let ob = observed_b[i].bind(py);
343-
let oy = observed_y[i].bind(py);
372+
let pb_obj = predicted_b[i].bind(py);
373+
let py_obj = predicted_y[i].bind(py);
374+
let ob_obj = observed_b[i].bind(py);
375+
let oy_obj = observed_y[i].bind(py);
344376

345-
// Supports non-contiguous by iterating; contiguous arrays will still be fast.
346-
let pb_vec: Vec<f32> = pb.readonly().as_array().iter().copied().collect();
347-
let py_vec: Vec<f32> = pyv.readonly().as_array().iter().copied().collect();
348-
let ob_vec: Vec<f32> = ob.readonly().as_array().iter().copied().collect();
349-
let oy_vec: Vec<f32> = oy.readonly().as_array().iter().copied().collect();
377+
let pb_vec = any_to_vec_f32(&np, &pb_obj)?;
378+
let py_vec = any_to_vec_f32(&np, &py_obj)?;
379+
let ob_vec = any_to_vec_f32(&np, &ob_obj)?;
380+
let oy_vec = any_to_vec_f32(&np, &oy_obj)?;
350381

351382
owned.push(Owned {
352383
idx: psm_indices[i],
@@ -362,7 +393,6 @@ pub fn batch_ms2pip_features_numpy(
362393
owned
363394
.into_par_iter()
364395
.map(|it| {
365-
// mimic Python behavior: if mismatched, return empty dict
366396
if it.pb.len() != it.ob.len() || it.py.len() != it.oy.len() {
367397
return (it.idx, HashMap::new());
368398
}
@@ -400,7 +430,6 @@ pub fn batch_ms2pip_features_numpy(
400430
abs_all_u.extend_from_slice(&abs_b_u);
401431
abs_all_u.extend_from_slice(&abs_y_u);
402432

403-
// mean/std before sorting
404433
let (mean_abs_all, std_abs_all) = mean_std(&abs_all);
405434
let (mean_abs_b, std_abs_b) = mean_std(&abs_b);
406435
let (mean_abs_y, std_abs_y) = mean_std(&abs_y);
@@ -409,7 +438,6 @@ pub fn batch_ms2pip_features_numpy(
409438
let (mean_abs_b_u, std_abs_b_u) = mean_std(&abs_b_u);
410439
let (mean_abs_y_u, std_abs_y_u) = mean_std(&abs_y_u);
411440

412-
// sort in place for quantiles + min/max (no clone)
413441
abs_all.sort_by(|a, b| a.total_cmp(b));
414442
abs_b.sort_by(|a, b| a.total_cmp(b));
415443
abs_y.sort_by(|a, b| a.total_cmp(b));
@@ -454,7 +482,6 @@ pub fn batch_ms2pip_features_numpy(
454482
let q2_y_u = quantile_sorted(&abs_y_u, 0.5);
455483
let q3_y_u = quantile_sorted(&abs_y_u, 0.75);
456484

457-
// correlations and similarities (avoid concatenating for "all" where possible)
458485
let spec_pearson_norm = pearson2(&tb, &ty, &pb, &pyv);
459486
let ionb_pearson_norm = pearson(&tb, &pb);
460487
let iony_pearson_norm = pearson(&ty, &pyv);
@@ -475,7 +502,6 @@ pub fn batch_ms2pip_features_numpy(
475502
let ionb_pearson = pearson(&tb_u, &pb_u);
476503
let iony_pearson = pearson(&ty_u, &py_u);
477504

478-
// Spearman "all ions": concatenate only for this metric (keeps parity, limits memory)
479505
let mut t_all_u = Vec::with_capacity(tb_u.len() + ty_u.len());
480506
t_all_u.extend_from_slice(&tb_u);
481507
t_all_u.extend_from_slice(&ty_u);

0 commit comments

Comments
 (0)