11// src/ms2pip_features.rs
22//
3- // MS2PIP feature calculation in Rust (batch + NumPy inputs), with memory-focused optimisations:
4- // - Chunked (blocked) copying from NumPy to cap peak memory.
5- // - In-place sorting for quantiles (no clone+sort).
6- // - Avoid concatenating "all-ion" vectors for Pearson/MSE/Dot/Cos where possible.
3+ // MS2PIP feature calculation in Rust (batch + flexible NumPy inputs):
4+ //
5+ // - Minimize peak memory usage by processing in blocks, and keeping f32 arrays only
6+ // while holding the GIL.
7+ // - Use Rayon for parallelism outside the GIL.
8+ // - Support arbitrary array-like inputs (lists, different dtypes, non-contiguous arrays)
9+ // by converting to contiguous np.ndarray(float32) once per input spectrum.
710
811use std:: collections:: HashMap ;
912
1013use numpy:: { PyArray1 , PyArrayMethods } ;
1114use pyo3:: exceptions:: PyValueError ;
1215use pyo3:: prelude:: * ;
16+ use pyo3:: types:: { PyAny , PyModule } ;
17+
1318use rayon:: prelude:: * ;
1419
15- /// Clip lower bound in log2 space: log2(0.001)
1620const CLIP_LOG2_MIN : f64 = -9.965_784_284_662_087 ; // (0.001_f64).log2()
1721
1822#[ inline]
@@ -33,7 +37,33 @@ fn pow2_unlog(x: f64) -> f64 {
3337
3438#[ inline]
3539fn finite_or_zero ( x : f64 ) -> f64 {
36- if x. is_finite ( ) { x } else { 0.0 }
40+ if x. is_finite ( ) {
41+ x
42+ } else {
43+ 0.0
44+ }
45+ }
46+
47+
48+ #[ inline]
49+ fn any_to_vec_f32 < ' py > (
50+ np : & ' py Bound < ' py , PyModule > ,
51+ obj : & Bound < ' py , PyAny > ,
52+ ) -> PyResult < Vec < f32 > > {
53+ // np.ascontiguousarray(obj, dtype="float32")
54+ let arr_any = np
55+ . getattr ( "ascontiguousarray" ) ?
56+ . call1 ( ( obj, "float32" ) ) ?;
57+
58+ let arr = arr_any. extract :: < & PyArray1 < f32 > > ( ) ?;
59+ let ro = arr. readonly ( ) ;
60+
61+ // With contiguity enforced, as_slice() should usually work; keep safe fallback.
62+ if let Ok ( slice) = ro. as_slice ( ) {
63+ Ok ( slice. to_vec ( ) )
64+ } else {
65+ Ok ( ro. as_array ( ) . iter ( ) . copied ( ) . collect ( ) )
66+ }
3767}
3868
3969fn pearson ( x : & [ f64 ] , y : & [ f64 ] ) -> f64 {
@@ -117,7 +147,9 @@ fn dot2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 {
117147 }
118148 let d1 = dot ( a1, b1) ;
119149 let d2 = dot ( a2, b2) ;
120- if !d1. is_finite ( ) || !d2. is_finite ( ) { return f64:: NAN ; }
150+ if !d1. is_finite ( ) || !d2. is_finite ( ) {
151+ return f64:: NAN ;
152+ }
121153 d1 + d2
122154}
123155
@@ -246,7 +278,6 @@ fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
246278 if n == 1 {
247279 return sorted[ 0 ] ;
248280 }
249- // numpy default: linear interpolation on (n-1)*q
250281 let pos = ( n as f64 - 1.0 ) * q;
251282 let lo = pos. floor ( ) as usize ;
252283 let hi = pos. ceil ( ) as usize ;
@@ -261,7 +292,7 @@ fn ranks_average_ties(values: &[f64]) -> Vec<f64> {
261292 let n = values. len ( ) ;
262293 let mut idx: Vec < usize > = ( 0 ..n) . collect ( ) ;
263294
264- // Deterministic ordering (handles NaN consistently). If NaNs exist, caller should typically return NaN.
295+ // Deterministic ordering
265296 idx. sort_by ( |& i, & j| values[ i] . total_cmp ( & values[ j] ) ) ;
266297
267298 let mut ranks = vec ! [ 0.0 ; n] ;
@@ -287,7 +318,6 @@ fn spearman(x: &[f64], y: &[f64]) -> f64 {
287318 if x. len ( ) != y. len ( ) || x. len ( ) < 2 {
288319 return f64:: NAN ;
289320 }
290- // pandas rank/corr will propagate NaNs; we emulate that by returning NaN if any non-finite
291321 if x. iter ( ) . any ( |v| !v. is_finite ( ) ) || y. iter ( ) . any ( |v| !v. is_finite ( ) ) {
292322 return f64:: NAN ;
293323 }
@@ -300,10 +330,10 @@ fn spearman(x: &[f64], y: &[f64]) -> f64 {
300330pub fn batch_ms2pip_features_numpy (
301331 py : Python < ' _ > ,
302332 psm_indices : Vec < usize > ,
303- predicted_b : Vec < Py < PyArray1 < f32 > > > ,
304- predicted_y : Vec < Py < PyArray1 < f32 > > > ,
305- observed_b : Vec < Py < PyArray1 < f32 > > > ,
306- observed_y : Vec < Py < PyArray1 < f32 > > > ,
333+ predicted_b : Vec < Py < PyAny > > ,
334+ predicted_y : Vec < Py < PyAny > > ,
335+ observed_b : Vec < Py < PyAny > > ,
336+ observed_y : Vec < Py < PyAny > > ,
307337) -> PyResult < Vec < ( usize , HashMap < String , f64 > ) > > {
308338 let n = psm_indices. len ( ) ;
309339 if predicted_b. len ( ) != n
@@ -325,28 +355,29 @@ pub fn batch_ms2pip_features_numpy(
325355 oy : Vec < f32 > ,
326356 }
327357
328- // Main output: keep capacity to avoid reallocations.
329358 let mut out: Vec < ( usize , HashMap < String , f64 > ) > = Vec :: with_capacity ( n) ;
330359
331- // Chunking keeps peak memory bounded. Tune as needed .
360+ // Tune for peak memory vs overhead .
332361 let block_size: usize = 4096 ;
333362
363+ // Import numpy once per call.
364+ let np = PyModule :: import_bound ( py, "numpy" ) ?;
365+
334366 for start in ( 0 ..n) . step_by ( block_size) {
335367 let end = ( start + block_size) . min ( n) ;
336368
337- // ---- Copy out of NumPy while holding the GIL (only this block) ----
369+ // ---- Convert/copy while holding the GIL (only this block) ----
338370 let mut owned: Vec < Owned > = Vec :: with_capacity ( end - start) ;
339371 for i in start..end {
340- let pb = predicted_b[ i] . bind ( py) ;
341- let pyv = predicted_y[ i] . bind ( py) ;
342- let ob = observed_b[ i] . bind ( py) ;
343- let oy = observed_y[ i] . bind ( py) ;
372+ let pb_obj = predicted_b[ i] . bind ( py) ;
373+ let py_obj = predicted_y[ i] . bind ( py) ;
374+ let ob_obj = observed_b[ i] . bind ( py) ;
375+ let oy_obj = observed_y[ i] . bind ( py) ;
344376
345- // Supports non-contiguous by iterating; contiguous arrays will still be fast.
346- let pb_vec: Vec < f32 > = pb. readonly ( ) . as_array ( ) . iter ( ) . copied ( ) . collect ( ) ;
347- let py_vec: Vec < f32 > = pyv. readonly ( ) . as_array ( ) . iter ( ) . copied ( ) . collect ( ) ;
348- let ob_vec: Vec < f32 > = ob. readonly ( ) . as_array ( ) . iter ( ) . copied ( ) . collect ( ) ;
349- let oy_vec: Vec < f32 > = oy. readonly ( ) . as_array ( ) . iter ( ) . copied ( ) . collect ( ) ;
377+ let pb_vec = any_to_vec_f32 ( & np, & pb_obj) ?;
378+ let py_vec = any_to_vec_f32 ( & np, & py_obj) ?;
379+ let ob_vec = any_to_vec_f32 ( & np, & ob_obj) ?;
380+ let oy_vec = any_to_vec_f32 ( & np, & oy_obj) ?;
350381
351382 owned. push ( Owned {
352383 idx : psm_indices[ i] ,
@@ -362,7 +393,6 @@ pub fn batch_ms2pip_features_numpy(
362393 owned
363394 . into_par_iter ( )
364395 . map ( |it| {
365- // mimic Python behavior: if mismatched, return empty dict
366396 if it. pb . len ( ) != it. ob . len ( ) || it. py . len ( ) != it. oy . len ( ) {
367397 return ( it. idx , HashMap :: new ( ) ) ;
368398 }
@@ -400,7 +430,6 @@ pub fn batch_ms2pip_features_numpy(
400430 abs_all_u. extend_from_slice ( & abs_b_u) ;
401431 abs_all_u. extend_from_slice ( & abs_y_u) ;
402432
403- // mean/std before sorting
404433 let ( mean_abs_all, std_abs_all) = mean_std ( & abs_all) ;
405434 let ( mean_abs_b, std_abs_b) = mean_std ( & abs_b) ;
406435 let ( mean_abs_y, std_abs_y) = mean_std ( & abs_y) ;
@@ -409,7 +438,6 @@ pub fn batch_ms2pip_features_numpy(
409438 let ( mean_abs_b_u, std_abs_b_u) = mean_std ( & abs_b_u) ;
410439 let ( mean_abs_y_u, std_abs_y_u) = mean_std ( & abs_y_u) ;
411440
412- // sort in place for quantiles + min/max (no clone)
413441 abs_all. sort_by ( |a, b| a. total_cmp ( b) ) ;
414442 abs_b. sort_by ( |a, b| a. total_cmp ( b) ) ;
415443 abs_y. sort_by ( |a, b| a. total_cmp ( b) ) ;
@@ -454,7 +482,6 @@ pub fn batch_ms2pip_features_numpy(
454482 let q2_y_u = quantile_sorted ( & abs_y_u, 0.5 ) ;
455483 let q3_y_u = quantile_sorted ( & abs_y_u, 0.75 ) ;
456484
457- // correlations and similarities (avoid concatenating for "all" where possible)
458485 let spec_pearson_norm = pearson2 ( & tb, & ty, & pb, & pyv) ;
459486 let ionb_pearson_norm = pearson ( & tb, & pb) ;
460487 let iony_pearson_norm = pearson ( & ty, & pyv) ;
@@ -475,7 +502,6 @@ pub fn batch_ms2pip_features_numpy(
475502 let ionb_pearson = pearson ( & tb_u, & pb_u) ;
476503 let iony_pearson = pearson ( & ty_u, & py_u) ;
477504
478- // Spearman "all ions": concatenate only for this metric (keeps parity, limits memory)
479505 let mut t_all_u = Vec :: with_capacity ( tb_u. len ( ) + ty_u. len ( ) ) ;
480506 t_all_u. extend_from_slice ( & tb_u) ;
481507 t_all_u. extend_from_slice ( & ty_u) ;
0 commit comments