diff --git a/Cargo.lock b/Cargo.lock index 5854c8a..678fad7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,26 @@ dependencies = [ "vsimd", ] +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -300,6 +320,20 @@ name = "bytemuck" version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] [[package]] name = "byteorder" @@ -410,6 +444,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +[[package]] +name = "context_error" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da7e1b8dc6f4cdc4f6b897d6aa1b7eaec6d95331bdb765d2a51cdd948e157ee0" +dependencies = [ + "serde", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -745,6 +788,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "font-types" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a654f404bbcbd48ea58c617c2993ee91d1cb63727a37bf2323a4edeed1b8c5" +dependencies = [ + "bytemuck", +] + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1232,6 +1284,15 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -1429,6 +1490,16 @@ dependencies = [ "twox-hash 2.1.2", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md5" version = "0.7.0" @@ -1484,7 +1555,11 @@ name = "ms2rescore-rs" version = "0.4.3" dependencies = [ "mzdata", + "numpy", + "ordered-float 5.1.0", "pyo3", + "rayon", + "rustyms", "timsrust", ] @@ -1506,8 +1581,10 @@ dependencies = [ "md5", "memchr", "mzpeaks", + "mzsignal", "num-traits", "quick-xml", + "rayon", "regex", "sha1", "thermorawfilereader", @@ -1523,6 +1600,36 @@ dependencies = [ "num-traits", ] +[[package]] +name = "mzsignal" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1bba49ea594cc8898df3977f9c2195ada0da4ee2ef2d057c41f00d1402d271" +dependencies = [ + "cfg-if", + "libm", + "log", + "mzpeaks", + "num-traits", + "rayon", + "thiserror 2.0.17", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "netcorehost" version = "0.18.0" @@ -1581,6 +1688,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "serde", ] [[package]] @@ -1590,6 +1698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "num-traits", + "serde", ] [[package]] @@ -1627,6 +1736,7 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", + "serde", ] [[package]] @@ -1660,6 +1770,21 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "numpy" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94caae805f998a07d33af06e6a3891e38556051b8045c615470a71590e13e78" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "rustc-hash", +] + [[package]] name = "object" version = "0.37.3" @@ -1684,6 +1809,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-float" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +dependencies = [ + "num-traits", + "rand 0.8.5", + "serde", +] + [[package]] name = "outref" version = "0.5.2" @@ -1769,6 +1905,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.3" @@ -1799,6 +1944,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "probability" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42746b805e424b759d46c22c65dc66ccca057a2db96e9db4fda6c337a287e485" +dependencies = [ + "random", + "special", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1911,7 +2066,7 @@ dependencies = [ "bytes", "getrandom 0.3.3", "lru-slab", - "rand", + "rand 0.9.2", "ring", "rustc-hash", "rustls", @@ -1952,6 +2107,16 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core 0.6.4", + "serde", +] + [[package]] name = "rand" version = "0.9.2" @@ -1959,7 +2124,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -1969,7 +2134,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "serde", ] [[package]] @@ -1981,6 +2155,18 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "random" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "474c42c904f04dfe2a595a02f71e1a0e5e92ffb5761cc9a4c02140b93b8dd504" + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -2001,6 +2187,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "read-fonts" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6717cf23b488adf64b9d711329542ba34de147df262370221940dfabc2c91358" +dependencies = [ + "bytemuck", + "font-types", +] + [[package]] name = "regex" version = "1.11.3" @@ -2173,6 +2369,33 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "rustyms" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "011d3d672ae44d5e07db0488d855f2b5ed178e3d6bb7ef5b18c6415c20bbd61e" +dependencies = [ + "bincode", + "context_error", + "flate2", + "itertools", + "mzdata", + "ndarray", + "ordered-float 5.1.0", + "paste", + "probability", + "rand 0.9.2", + "rayon", + "regex", + "serde", + "serde_json", + "similar", + "swash", + "thin-vec", + "uom", + "zeno", +] + [[package]] name = "ryu" version = "1.0.20" @@ -2269,6 +2492,22 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + +[[package]] +name = "skrifa" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c31071dedf532758ecf3fed987cdb4bd9509f900e026ab684b4ecb81ea49841" +dependencies = [ + "bytemuck", + "read-fonts", +] + [[package]] name = "slab" version = "0.4.11" @@ -2297,6 +2536,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "special" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b89cf0d71ae639fdd8097350bfac415a41aabf1d5ddd356295fdc95f09760382" +dependencies = [ + "libm", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2315,6 +2563,17 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "swash" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47846491253e976bdd07d0f9cc24b7daf24720d11309302ccbbc6e6b6e53550a" +dependencies = [ + "skrifa", + "yazi", + "zeno", +] + [[package]] name = "syn" version = "1.0.109" @@ -2388,6 +2647,15 @@ dependencies = [ "netcorehost", ] +[[package]] +name = "thin-vec" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "144f754d318415ac792f9d69fc87abbbfc043ce2ef041c60f16ad828f638717d" +dependencies = [ + "serde", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2436,7 +2704,7 @@ checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" dependencies = [ "byteorder", "integer-encoding", - "ordered-float", + "ordered-float 2.10.1", ] [[package]] @@ -2653,6 +2921,26 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + +[[package]] +name = "uom" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd5cfe7d84f6774726717f358a37f5bca8fca273bed4de40604ad129d1107b49" +dependencies = [ + "num-bigint", + "num-complex", + "num-rational", + "num-traits", + "serde", + "typenum", +] + [[package]] name = "url" version = "2.5.7" @@ -2683,6 +2971,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "vsimd" version = "0.8.0" @@ -3087,6 +3381,12 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "yazi" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01738255b5a16e78bbb83e7fbba0a1e7dd506905cfc53f4622d89015a03fbb5" + [[package]] name = "yoke" version = "0.8.0" @@ -3111,6 +3411,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeno" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6df3dc4292935e51816d896edcd52aa30bc297907c26167fec31e2b0c6a32524" + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/Cargo.toml b/Cargo.toml index 93b1d71..a4e25c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ms2rescore-rs" -version = "0.4.3" +version = "0.5.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -12,6 +12,10 @@ crate-type = ["cdylib"] default = [] [dependencies] -mzdata = { version = "0.59.2", features = ["thermo"] } +mzdata = { version = "0.59.2", features = ["thermo", "parallelism"] } pyo3 = { version = "0.23.3", features = ["anyhow"] } +rayon = "1.10" timsrust = "0.4.1" +rustyms = "0.11" +ordered-float = "5" +numpy = "0.23" diff --git a/src/lib.rs b/src/lib.rs index f096566..fe31bb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,12 @@ mod ms2_spectrum; mod parse_mzdata; mod parse_timsrust; mod precursor; +mod ms2_features; +mod ms2pip_features; use std::collections::HashMap; -use pyo3::exceptions::{PyException, PyValueError}; +use pyo3::exceptions::PyException; use pyo3::prelude::*; use file_types::{match_file_type, SpectrumFileType}; @@ -23,17 +25,20 @@ pub fn is_supported_file_type(spectrum_path: String) -> bool { /// Get mapping of spectrum identifiers to precursor information. #[pyfunction] -pub fn get_precursor_info(spectrum_path: String) -> PyResult> { +pub fn get_precursor_info(py: Python<'_>, spectrum_path: String) -> PyResult> { let file_type = match_file_type(&spectrum_path); - let precursors = match file_type { + let precursors = py.allow_threads(|| match file_type { SpectrumFileType::MascotGenericFormat | SpectrumFileType::MzML | SpectrumFileType::MzMLb | SpectrumFileType::ThermoRaw => parse_mzdata::parse_precursor_info(&spectrum_path), SpectrumFileType::BrukerRaw => parse_timsrust::parse_precursor_info(&spectrum_path), - SpectrumFileType::Unknown => return Err(PyValueError::new_err("Unsupported file type")), - }; + SpectrumFileType::Unknown => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Unsupported file type", + )), + }); match precursors { Ok(precursors) => Ok(precursors), @@ -43,17 +48,20 @@ pub fn get_precursor_info(spectrum_path: String) -> PyResult PyResult> { +pub fn get_ms2_spectra(py: Python<'_>, spectrum_path: String) -> PyResult> { let file_type = match_file_type(&spectrum_path); - let spectra = match file_type { + let spectra = py.allow_threads(|| match file_type { SpectrumFileType::MascotGenericFormat | SpectrumFileType::MzML | SpectrumFileType::MzMLb | SpectrumFileType::ThermoRaw => parse_mzdata::read_ms2_spectra(&spectrum_path), SpectrumFileType::BrukerRaw => parse_timsrust::read_ms2_spectra(&spectrum_path), - SpectrumFileType::Unknown => return Err(PyValueError::new_err("Unsupported file type")), - }; + SpectrumFileType::Unknown => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Unsupported file type", + )), + }); match spectra { Ok(spectra) => Ok(spectra), @@ -69,5 +77,10 @@ fn ms2rescore_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(is_supported_file_type, m)?)?; m.add_function(wrap_pyfunction!(get_precursor_info, m)?)?; m.add_function(wrap_pyfunction!(get_ms2_spectra, m)?)?; + m.add_function(wrap_pyfunction!( + ms2_features::batch_ms2_features_from_spectra, + m + )?)?; + m.add_function(wrap_pyfunction!(ms2pip_features::batch_ms2pip_features_numpy, m)?)?; Ok(()) } diff --git a/src/ms2_features.rs b/src/ms2_features.rs new file mode 100644 index 0000000..206dbe6 --- /dev/null +++ b/src/ms2_features.rs @@ -0,0 +1,358 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use pyo3::exceptions::{PyException, PyValueError}; +use pyo3::prelude::*; +use rayon::prelude::*; + +use crate::ms2_spectrum::MS2Spectrum; + +use rustyms::annotation::model::FragmentationModel; +use rustyms::chemistry::MassMode; +use rustyms::prelude::CompoundPeptidoformIon; +use rustyms::annotation::AnnotatableSpectrum; +use rustyms::spectrum::{RawPeak, RawSpectrum, PeakSpectrum}; +use rustyms::system::f64::MassOverCharge; +use rustyms::system::mass_over_charge::thomson; +use ordered_float::OrderedFloat; + + +fn parse_fragmentation_model(s: &str) -> PyResult { + match s.trim().to_ascii_lowercase().as_str() { + "cidhcd" | "cid_hcd" | "cid-hcd" => Ok((*FragmentationModel::cid_hcd()).clone()), + "etd" => Ok((*FragmentationModel::etd()).clone()), + "ethcd" | "et+hcd" | "et_hcd" => Ok((*FragmentationModel::ethcd()).clone()), + "all" => Ok((*FragmentationModel::all()).clone()), + other => Err(PyValueError::new_err(format!( + "Unsupported fragmentation_model: {other}. Expected one of: cidhcd, etd, ethcd, all." + ))), + } +} + +fn parse_mass_mode(s: &str) -> PyResult { + match s.trim().to_ascii_lowercase().as_str() { + "monoisotopic" | "mono" => Ok(MassMode::Monoisotopic), + "average" | "avg" => Ok(MassMode::Average), + other => Err(PyValueError::new_err(format!( + "Unsupported mass_mode: {other}. Expected: monoisotopic, average." + ))), + } +} + + +// ---- Hyperscore helpers (stable; avoids factorial overflow) ---- + +fn ln_factorial(n: usize) -> f64 { + (1..=n).map(|k| (k as f64).ln()).sum() +} + +fn hyperscore(ny: usize, nb: usize, sum_y: f64, sum_b: f64) -> f64 { + let sum = if (sum_y + sum_b) > 0.0 { sum_y + sum_b } else { 1.0 }; + ln_factorial(ny) + ln_factorial(nb) + sum.ln() +} + +// ---- Feature helpers ---- + +fn longest_true_run(flags: &[bool]) -> usize { + let mut max_run = 0usize; + let mut cur = 0usize; + for &v in flags { + if v { + cur += 1; + max_run = max_run.max(cur); + } else { + cur = 0; + } + } + max_run +} + +/// Parse an ion string like "b5", "y7", possibly with extra suffixes (e.g. charge notation). +/// We extract the leading series letter and the first contiguous digit run. +fn parse_ion_series_and_index(ion: &str) -> Option<(char, usize)> { + let ion = ion.trim(); + let mut chars = ion.chars(); + let series = chars.next()?; + if series != 'b' && series != 'y' { + return None; + } + let rest: String = chars.collect(); + let digits: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect(); + if digits.is_empty() { + return None; + } + let idx = digits.parse::().ok()?; + Some((series, idx)) +} + +// ---- Core batch function ---- + +#[pyfunction] +pub fn batch_ms2_features_from_spectra( + py: Python<'_>, + spectra: Vec>, + proformas: Vec, + seq_lens: Vec, + fragmentation_model: String, + mass_mode: String, + calculate_hyperscore: bool, +) -> PyResult>> { + + let n = spectra.len(); + if proformas.len() != n || seq_lens.len() != n { + return Err(PyException::new_err( + "Input arrays must have identical length: spectra, proformas, seq_lens", + )); + } + + // ---- Copy spectrum data out of Python objects (must hold GIL) ---- + #[derive(Clone)] + struct OwnedSpec { + id: String, + mz: Vec, + intensity: Vec, + seq_len: usize, + precursor_charge: i32, + proforma: String, + } + + let mut owned: Vec = Vec::with_capacity(n); + for i in 0..n { + let spec_ref = spectra[i].bind(py); + let spec = spec_ref.borrow(); + + owned.push(OwnedSpec { + id: spec.identifier.clone(), + mz: spec.mz.iter().map(|&x| x as f64).collect(), + intensity: spec.intensity.iter().map(|&x| x as f64).collect(), + seq_len: seq_lens[i], + precursor_charge: spec + .precursor + .as_ref() + .map(|p| p.charge as i32) + .unwrap_or(0), + // mimic your Python: psm.peptidoform.proforma.split("/")[0] + proforma: proformas[i].split('/').next().unwrap_or(&proformas[i]).to_string(), + }); + } + + // ---- Configure rustyms model/mode ---- + let model = parse_fragmentation_model(&fragmentation_model)?; + let mode = parse_mass_mode(&mass_mode)?; + + + + // Matching parameters (tolerance etc.). Start with default; expose knobs later. + let params = rustyms::annotation::model::MatchingParameters::default(); + + // ---- Precompute theoretical fragments per unique peptide+charge+model ---- + // Charge included because fragment charge handling depends on precursor charge. + type FragList = Vec; + + let mut frag_cache: HashMap<(String, i32), Arc> = HashMap::new(); + for item in &owned { + let key = (item.proforma.clone(), item.precursor_charge); + if item.precursor_charge <= 0 { + frag_cache.insert((item.proforma.clone(), item.precursor_charge), Arc::new(Vec::new())); + continue; + } + if frag_cache.contains_key(&key) { + continue; + } + + // Parse peptide + let peptide = match CompoundPeptidoformIon::pro_forma(&item.proforma, None) + { + Ok(p) => p, + Err(_) => { + // Store empty fragments to mimic your Python behavior: return [] on parse/annotate failure + frag_cache.insert(key, Arc::new(Vec::new())); + continue; + } + }; + + // Generate theoretical fragments up to the precursor charge. + // If you want to cap this (e.g., min(charge, 2)) for speed, do it here. + let frag_charge = rustyms::system::isize::Charge::new::( + item.precursor_charge as isize, + ); + let frags = peptide.generate_theoretical_fragments(frag_charge, &model); + frag_cache.insert(key, Arc::new(frags)); + } + + let frag_cache = Arc::new(frag_cache); + let params = Arc::new(params); + + // ---- Heavy work: release GIL and parallelize ---- + let results: Result>, String> = py.allow_threads(|| { + owned + .into_par_iter() + .map(|item| { + if item.mz.len() != item.intensity.len() { + return Err(format!("Spectrum {}: mz/intensity length mismatch", item.id)); + } + if item.seq_len == 0 { + return Ok(HashMap::new()); + } + if item.precursor_charge <= 0 { + return Ok(HashMap::new()); + } + + let key = (item.proforma.clone(), item.precursor_charge); + let empty: FragList = Vec::new(); + let frags: &FragList = frag_cache + .get(&key) + .map(|x| x.as_ref()) + .unwrap_or(&empty); + + + if frags.is_empty() { + return Ok(HashMap::new()); + } + + // Parse peptide again for annotation call (cheap compared to fragment generation; you can cache peptide too later) + let peptide = match CompoundPeptidoformIon::pro_forma(&item.proforma, None) { + Ok(p) => p, + Err(_) => return Ok(HashMap::new()), + }; + + // Build a RawSpectrum (rustyms) + let mut spectrum = RawSpectrum::default(); + spectrum.title = item.id.clone(); + spectrum.num_scans = 1; + + // Build peaks and extend into spectrum + let peaks: Vec = item + .mz + .iter() + .zip(item.intensity.iter()) + .map(|(&mz, &inten)| RawPeak { + mz: MassOverCharge::new::(mz), + intensity: OrderedFloat(inten), + }) + .collect(); + + spectrum.extend(peaks); + + + + // Annotate against precomputed fragments + let annotated = spectrum.annotate(peptide, frags.as_slice(), ¶ms, mode); + + // ---- Compute features matching your Python behavior ---- + let seq_len = item.seq_len; + let mut b_flags = vec![false; seq_len]; + let mut y_flags = vec![false; seq_len]; + + let pseudo = 1e-5_f64; + let mut total_intensity = 0.0_f64; + let mut matched_intensity = 0.0_f64; + + let mut b_sum = 0.0_f64; + let mut y_sum = 0.0_f64; + + // For hyperscore parity: count “matched b/y intensities” per fragment annotation, not per peak + let mut b_ints: Vec = Vec::new(); + let mut y_ints: Vec = Vec::new(); + + for peak in annotated.spectrum() { + let inten = peak.intensity.into_inner(); + total_intensity += inten; + + if !peak.annotation.is_empty() { + matched_intensity += inten; + + for frag in peak.annotation.iter() { + // Convert ion id to string and parse leading b/y + index + let ion_str = frag.ion.to_string(); + if let Some((series, idx)) = parse_ion_series_and_index(&ion_str) { + if idx < seq_len { + match series { + 'b' => { + b_sum += inten; + b_flags[idx] = true; + if calculate_hyperscore { + b_ints.push(inten); + } + } + 'y' => { + y_sum += inten; + y_flags[idx] = true; + if calculate_hyperscore { + y_ints.push(inten); + } + } + _ => {} + } + } + } + } + } + } + + let matched_b = b_flags.iter().filter(|&&x| x).count(); + let matched_y = y_flags.iter().filter(|&&x| x).count(); + + let mut feats: HashMap = HashMap::new(); + feats.insert("ln_explained_intensity".to_string(), (matched_intensity + pseudo).ln()); + feats.insert("ln_total_intensity".to_string(), (total_intensity + pseudo).ln()); + + let explained_ratio = if total_intensity > 0.0 { + (matched_intensity / total_intensity + pseudo).ln() + } else { + pseudo.ln() + }; + feats.insert("ln_explained_intensity_ratio".to_string(), explained_ratio); + + let b_ratio = if matched_intensity > 0.0 { + (b_sum / matched_intensity + pseudo).ln() + } else { + pseudo.ln() + }; + feats.insert("ln_explained_b_ion_ratio".to_string(), b_ratio); + + let y_ratio = if matched_intensity > 0.0 { + (y_sum / matched_intensity + pseudo).ln() + } else { + pseudo.ln() + }; + feats.insert("ln_explained_y_ion_ratio".to_string(), y_ratio); + + feats.insert( + "longest_b_ion_sequence".to_string(), + longest_true_run(&b_flags) as f64, + ); + feats.insert( + "longest_y_ion_sequence".to_string(), + longest_true_run(&y_flags) as f64, + ); + + feats.insert("matched_b_ions".to_string(), matched_b as f64); + feats.insert("matched_b_ions_pct".to_string(), matched_b as f64 / seq_len as f64); + + feats.insert("matched_y_ions".to_string(), matched_y as f64); + feats.insert("matched_y_ions_pct".to_string(), matched_y as f64 / seq_len as f64); + + feats.insert( + "matched_ions_pct".to_string(), + (matched_b + matched_y) as f64 / (2.0 * seq_len as f64), + ); + + if calculate_hyperscore { + let ny = y_ints.len(); + let nb = b_ints.len(); + let sum_y: f64 = y_ints.iter().sum(); + let sum_b: f64 = b_ints.iter().sum(); + feats.insert("hyperscore".to_string(), hyperscore(ny, nb, sum_y, sum_b)); + } + + Ok(feats) + }) + .collect() + }); + + match results { + Ok(v) => Ok(v), + Err(e) => Err(PyException::new_err(e)), + } +} diff --git a/src/ms2_spectrum.rs b/src/ms2_spectrum.rs index 45d63c1..51a89c2 100644 --- a/src/ms2_spectrum.rs +++ b/src/ms2_spectrum.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use crate::precursor::Precursor; -#[pyclass(get_all, set_all)] +#[pyclass(module = "ms2rescore_rs", get_all, set_all)] #[derive(Debug, Clone)] pub struct MS2Spectrum { pub identifier: String, @@ -29,10 +29,26 @@ impl MS2Spectrum { #[pymethods] impl MS2Spectrum { + #[new] + #[pyo3(signature = (identifier="".to_string(), mz=vec![], intensity=vec![], precursor=None))] + pub fn py_new( + identifier: String, + mz: Vec, + intensity: Vec, + precursor: Option, + ) -> Self { + MS2Spectrum::new(identifier, mz, intensity, precursor) + } + fn __repr__(&self) -> String { format!( "MS2Spectrum(identifier='{}', mz=[..], intensity=[..], precursor={:?})", self.identifier, self.precursor ) } + + pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, (String, Vec, Vec, Option))> { + let cls = py.import("ms2rescore_rs")?.getattr("MS2Spectrum")?; + Ok((cls.into(), (self.identifier.clone(), self.mz.clone(), self.intensity.clone(), self.precursor.clone()))) + } } diff --git a/src/ms2pip_features.rs b/src/ms2pip_features.rs new file mode 100644 index 0000000..1f3dea7 --- /dev/null +++ b/src/ms2pip_features.rs @@ -0,0 +1,601 @@ +// src/ms2pip_features.rs +// +// MS2PIP feature calculation in Rust (batch + NumPy inputs), with memory-focused optimisations: +// - Chunked (blocked) copying from NumPy to cap peak memory. +// - In-place sorting for quantiles (no clone+sort). +// - Avoid concatenating "all-ion" vectors for Pearson/MSE/Dot/Cos where possible. + +use std::collections::HashMap; + +use numpy::{PyArray1, PyArrayMethods}; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use rayon::prelude::*; + +/// Clip lower bound in log2 space: log2(0.001) +const CLIP_LOG2_MIN: f64 = -9.965_784_284_662_087; // (0.001_f64).log2() + +#[inline] +fn clip_min_f32(x: f32) -> f64 { + let xf = x as f64; + if xf < CLIP_LOG2_MIN { + CLIP_LOG2_MIN + } else { + xf + } +} + +#[inline] +fn pow2_unlog(x: f64) -> f64 { + // matches Python: 2**x - 0.001 + (2.0_f64).powf(x) - 0.001 +} + +#[inline] +fn finite_or_zero(x: f64) -> f64 { + if x.is_finite() { x } else { 0.0 } +} + +fn pearson(x: &[f64], y: &[f64]) -> f64 { + let n = x.len(); + if n != y.len() || n < 2 { + return f64::NAN; + } + if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) { + return f64::NAN; + } + + let mean_x = x.iter().sum::() / n as f64; + let mean_y = y.iter().sum::() / n as f64; + + let mut num = 0.0; + let mut den_x = 0.0; + let mut den_y = 0.0; + + for i in 0..n { + let dx = x[i] - mean_x; + let dy = y[i] - mean_y; + num += dx * dy; + den_x += dx * dx; + den_y += dy * dy; + } + + if den_x <= 0.0 || den_y <= 0.0 { + return f64::NAN; + } + num / (den_x.sqrt() * den_y.sqrt()) +} + +fn mse(x: &[f64], y: &[f64]) -> f64 { + let n = x.len(); + if n != y.len() || n == 0 { + return f64::NAN; + } + if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) { + return f64::NAN; + } + + let mut s = 0.0; + for i in 0..n { + let d = x[i] - y[i]; + s += d * d; + } + s / n as f64 +} + +fn dot(x: &[f64], y: &[f64]) -> f64 { + if x.len() != y.len() { + return f64::NAN; + } + if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) { + return f64::NAN; + } + x.iter().zip(y.iter()).map(|(a, b)| a * b).sum() +} + +fn l2_norm(x: &[f64]) -> f64 { + if x.iter().any(|v| !v.is_finite()) { + return f64::NAN; + } + x.iter().map(|v| v * v).sum::().sqrt() +} + +fn cosine_similarity(x: &[f64], y: &[f64]) -> f64 { + let d = dot(x, y); + let nx = l2_norm(x); + let ny = l2_norm(y); + if !d.is_finite() || !nx.is_finite() || !ny.is_finite() || nx <= 0.0 || ny <= 0.0 { + return f64::NAN; + } + d / (nx * ny) +} + +#[inline] +fn dot2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 { + if a1.len() != b1.len() || a2.len() != b2.len() { + return f64::NAN; + } + let d1 = dot(a1, b1); + let d2 = dot(a2, b2); + if !d1.is_finite() || !d2.is_finite() { return f64::NAN; } + d1 + d2 +} + +#[inline] +fn mse2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 { + if a1.len() != b1.len() || a2.len() != b2.len() { + return f64::NAN; + } + let n = a1.len() + a2.len(); + if n == 0 { + return f64::NAN; + } + if a1.iter().any(|v| !v.is_finite()) + || a2.iter().any(|v| !v.is_finite()) + || b1.iter().any(|v| !v.is_finite()) + || b2.iter().any(|v| !v.is_finite()) + { + return f64::NAN; + } + + let mut s = 0.0; + for (x, y) in a1.iter().zip(b1.iter()) { + let d = x - y; + s += d * d; + } + for (x, y) in a2.iter().zip(b2.iter()) { + let d = x - y; + s += d * d; + } + s / (n as f64) +} + +#[inline] +fn pearson2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 { + if a1.len() != b1.len() || a2.len() != b2.len() { + return f64::NAN; + } + let n = a1.len() + a2.len(); + if n < 2 { + return f64::NAN; + } + if a1.iter().any(|v| !v.is_finite()) + || a2.iter().any(|v| !v.is_finite()) + || b1.iter().any(|v| !v.is_finite()) + || b2.iter().any(|v| !v.is_finite()) + { + return f64::NAN; + } + + let sum_x = a1.iter().sum::() + a2.iter().sum::(); + let sum_y = b1.iter().sum::() + b2.iter().sum::(); + let mean_x = sum_x / n as f64; + let mean_y = sum_y / n as f64; + + let mut num = 0.0; + let mut den_x = 0.0; + let mut den_y = 0.0; + + for (x, y) in a1.iter().zip(b1.iter()) { + let dx = x - mean_x; + let dy = y - mean_y; + num += dx * dy; + den_x += dx * dx; + den_y += dy * dy; + } + for (x, y) in a2.iter().zip(b2.iter()) { + let dx = x - mean_x; + let dy = y - mean_y; + num += dx * dy; + den_x += dx * dx; + den_y += dy * dy; + } + + if den_x <= 0.0 || den_y <= 0.0 { + return f64::NAN; + } + num / (den_x.sqrt() * den_y.sqrt()) +} + +#[inline] +fn cosine2(a1: &[f64], a2: &[f64], b1: &[f64], b2: &[f64]) -> f64 { + let d = dot2(a1, a2, b1, b2); + if !d.is_finite() { + return f64::NAN; + } + let nx1 = l2_norm(a1); + let nx2 = l2_norm(a2); + let ny1 = l2_norm(b1); + let ny2 = l2_norm(b2); + if !nx1.is_finite() || !nx2.is_finite() || !ny1.is_finite() || !ny2.is_finite() { + return f64::NAN; + } + let nx = (nx1 * nx1 + nx2 * nx2).sqrt(); + let ny = (ny1 * ny1 + ny2 * ny2).sqrt(); + if nx <= 0.0 || ny <= 0.0 { + return f64::NAN; + } + d / (nx * ny) +} + +fn mean_std(x: &[f64]) -> (f64, f64) { + let n = x.len(); + if n == 0 { + return (f64::NAN, f64::NAN); + } + if x.iter().any(|v| !v.is_finite()) { + return (f64::NAN, f64::NAN); + } + let mean = x.iter().sum::() / n as f64; + let var = x + .iter() + .map(|v| { + let d = v - mean; + d * d + }) + .sum::() + / n as f64; // ddof=0 like numpy default + (mean, var.sqrt()) +} + +fn quantile_sorted(sorted: &[f64], q: f64) -> f64 { + let n = sorted.len(); + if n == 0 { + return f64::NAN; + } + if n == 1 { + return sorted[0]; + } + // numpy default: linear interpolation on (n-1)*q + let pos = (n as f64 - 1.0) * q; + let lo = pos.floor() as usize; + let hi = pos.ceil() as usize; + if lo == hi { + return sorted[lo]; + } + let w = pos - lo as f64; + sorted[lo] * (1.0 - w) + sorted[hi] * w +} + +fn ranks_average_ties(values: &[f64]) -> Vec { + let n = values.len(); + let mut idx: Vec = (0..n).collect(); + + // Deterministic ordering (handles NaN consistently). If NaNs exist, caller should typically return NaN. + idx.sort_by(|&i, &j| values[i].total_cmp(&values[j])); + + let mut ranks = vec![0.0; n]; + let mut i = 0usize; + while i < n { + let mut j = i + 1; + while j < n && values[idx[j]] == values[idx[i]] { + j += 1; + } + // average rank for ties, ranks are 1-based + let r_lo = (i + 1) as f64; + let r_hi = j as f64; + let r_avg = (r_lo + r_hi) / 2.0; + for k in i..j { + ranks[idx[k]] = r_avg; + } + i = j; + } + ranks +} + +fn spearman(x: &[f64], y: &[f64]) -> f64 { + if x.len() != y.len() || x.len() < 2 { + return f64::NAN; + } + // pandas rank/corr will propagate NaNs; we emulate that by returning NaN if any non-finite + if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) { + return f64::NAN; + } + let rx = ranks_average_ties(x); + let ry = ranks_average_ties(y); + pearson(&rx, &ry) +} + +#[pyfunction] +pub fn batch_ms2pip_features_numpy( + py: Python<'_>, + psm_indices: Vec, + predicted_b: Vec>>, + predicted_y: Vec>>, + observed_b: Vec>>, + observed_y: Vec>>, +) -> PyResult)>> { + let n = psm_indices.len(); + if predicted_b.len() != n + || predicted_y.len() != n + || observed_b.len() != n + || observed_y.len() != n + { + return Err(PyValueError::new_err( + "All inputs must have identical length: psm_indices, predicted_b, predicted_y, observed_b, observed_y", + )); + } + + #[derive(Clone)] + struct Owned { + idx: usize, + pb: Vec, + py: Vec, + ob: Vec, + oy: Vec, + } + + // Main output: keep capacity to avoid reallocations. + let mut out: Vec<(usize, HashMap)> = Vec::with_capacity(n); + + // Chunking keeps peak memory bounded. Tune as needed. + let block_size: usize = 4096; + + for start in (0..n).step_by(block_size) { + let end = (start + block_size).min(n); + + // ---- Copy out of NumPy while holding the GIL (only this block) ---- + let mut owned: Vec = Vec::with_capacity(end - start); + for i in start..end { + let pb = predicted_b[i].bind(py); + let pyv = predicted_y[i].bind(py); + let ob = observed_b[i].bind(py); + let oy = observed_y[i].bind(py); + + // Supports non-contiguous by iterating; contiguous arrays will still be fast. + let pb_vec: Vec = pb.readonly().as_array().iter().copied().collect(); + let py_vec: Vec = pyv.readonly().as_array().iter().copied().collect(); + let ob_vec: Vec = ob.readonly().as_array().iter().copied().collect(); + let oy_vec: Vec = oy.readonly().as_array().iter().copied().collect(); + + owned.push(Owned { + idx: psm_indices[i], + pb: pb_vec, + py: py_vec, + ob: ob_vec, + oy: oy_vec, + }); + } + + // ---- Compute this block without the GIL ---- + let mut block_out: Vec<(usize, HashMap)> = py.allow_threads(|| { + owned + .into_par_iter() + .map(|it| { + // mimic Python behavior: if mismatched, return empty dict + if it.pb.len() != it.ob.len() || it.py.len() != it.oy.len() { + return (it.idx, HashMap::new()); + } + if it.pb.is_empty() && it.py.is_empty() { + return (it.idx, HashMap::new()); + } + + // clip in log2 space + let tb: Vec = it.pb.into_iter().map(clip_min_f32).collect(); + let ty: Vec = it.py.into_iter().map(clip_min_f32).collect(); + let pb: Vec = it.ob.into_iter().map(clip_min_f32).collect(); + let pyv: Vec = it.oy.into_iter().map(clip_min_f32).collect(); + + // unlog arrays + let tb_u: Vec = tb.iter().copied().map(pow2_unlog).collect(); + let ty_u: Vec = ty.iter().copied().map(pow2_unlog).collect(); + let pb_u: Vec = pb.iter().copied().map(pow2_unlog).collect(); + let py_u: Vec = pyv.iter().copied().map(pow2_unlog).collect(); + + // abs diffs (log) + let mut abs_b: Vec = + tb.iter().zip(pb.iter()).map(|(a, b)| (a - b).abs()).collect(); + let mut abs_y: Vec = + ty.iter().zip(pyv.iter()).map(|(a, b)| (a - b).abs()).collect(); + let mut abs_all: Vec = Vec::with_capacity(abs_b.len() + abs_y.len()); + abs_all.extend_from_slice(&abs_b); + abs_all.extend_from_slice(&abs_y); + + // abs diffs (unlog) + let mut abs_b_u: Vec = + tb_u.iter().zip(pb_u.iter()).map(|(a, b)| (a - b).abs()).collect(); + let mut abs_y_u: Vec = + ty_u.iter().zip(py_u.iter()).map(|(a, b)| (a - b).abs()).collect(); + let mut abs_all_u: Vec = Vec::with_capacity(abs_b_u.len() + abs_y_u.len()); + abs_all_u.extend_from_slice(&abs_b_u); + abs_all_u.extend_from_slice(&abs_y_u); + + // mean/std before sorting + let (mean_abs_all, std_abs_all) = mean_std(&abs_all); + let (mean_abs_b, std_abs_b) = mean_std(&abs_b); + let (mean_abs_y, std_abs_y) = mean_std(&abs_y); + + let (mean_abs_all_u, std_abs_all_u) = mean_std(&abs_all_u); + let (mean_abs_b_u, std_abs_b_u) = mean_std(&abs_b_u); + let (mean_abs_y_u, std_abs_y_u) = mean_std(&abs_y_u); + + // sort in place for quantiles + min/max (no clone) + abs_all.sort_by(|a, b| a.total_cmp(b)); + abs_b.sort_by(|a, b| a.total_cmp(b)); + abs_y.sort_by(|a, b| a.total_cmp(b)); + + abs_all_u.sort_by(|a, b| a.total_cmp(b)); + abs_b_u.sort_by(|a, b| a.total_cmp(b)); + abs_y_u.sort_by(|a, b| a.total_cmp(b)); + + let min_abs_all = abs_all.first().copied().unwrap_or(f64::NAN); + let max_abs_all = abs_all.last().copied().unwrap_or(f64::NAN); + let q1_all = quantile_sorted(&abs_all, 0.25); + let q2_all = quantile_sorted(&abs_all, 0.5); + let q3_all = quantile_sorted(&abs_all, 0.75); + + let min_abs_b = abs_b.first().copied().unwrap_or(f64::NAN); + let max_abs_b = abs_b.last().copied().unwrap_or(f64::NAN); + let q1_b = quantile_sorted(&abs_b, 0.25); + let q2_b = quantile_sorted(&abs_b, 0.5); + let q3_b = quantile_sorted(&abs_b, 0.75); + + let min_abs_y = abs_y.first().copied().unwrap_or(f64::NAN); + let max_abs_y = abs_y.last().copied().unwrap_or(f64::NAN); + let q1_y = quantile_sorted(&abs_y, 0.25); + let q2_y = quantile_sorted(&abs_y, 0.5); + let q3_y = quantile_sorted(&abs_y, 0.75); + + let min_abs_all_u = abs_all_u.first().copied().unwrap_or(f64::NAN); + let max_abs_all_u = abs_all_u.last().copied().unwrap_or(f64::NAN); + let q1_all_u = quantile_sorted(&abs_all_u, 0.25); + let q2_all_u = quantile_sorted(&abs_all_u, 0.5); + let q3_all_u = quantile_sorted(&abs_all_u, 0.75); + + let min_abs_b_u = abs_b_u.first().copied().unwrap_or(f64::NAN); + let max_abs_b_u = abs_b_u.last().copied().unwrap_or(f64::NAN); + let q1_b_u = quantile_sorted(&abs_b_u, 0.25); + let q2_b_u = quantile_sorted(&abs_b_u, 0.5); + let q3_b_u = quantile_sorted(&abs_b_u, 0.75); + + let min_abs_y_u = abs_y_u.first().copied().unwrap_or(f64::NAN); + let max_abs_y_u = abs_y_u.last().copied().unwrap_or(f64::NAN); + let q1_y_u = quantile_sorted(&abs_y_u, 0.25); + let q2_y_u = quantile_sorted(&abs_y_u, 0.5); + let q3_y_u = quantile_sorted(&abs_y_u, 0.75); + + // correlations and similarities (avoid concatenating for "all" where possible) + let spec_pearson_norm = pearson2(&tb, &ty, &pb, &pyv); + let ionb_pearson_norm = pearson(&tb, &pb); + let iony_pearson_norm = pearson(&ty, &pyv); + + let spec_mse_norm = mse2(&tb, &ty, &pb, &pyv); + let ionb_mse_norm = mse(&tb, &pb); + let iony_mse_norm = mse(&ty, &pyv); + + let dotprod_norm = dot2(&tb, &ty, &pb, &pyv); + let dotprod_ionb_norm = dot(&tb, &pb); + let dotprod_iony_norm = dot(&ty, &pyv); + + let cos_norm = cosine2(&tb, &ty, &pb, &pyv); + let cos_ionb_norm = cosine_similarity(&tb, &pb); + let cos_iony_norm = cosine_similarity(&ty, &pyv); + + let spec_pearson = pearson2(&tb_u, &ty_u, &pb_u, &py_u); + let ionb_pearson = pearson(&tb_u, &pb_u); + let iony_pearson = pearson(&ty_u, &py_u); + + // Spearman "all ions": concatenate only for this metric (keeps parity, limits memory) + let mut t_all_u = Vec::with_capacity(tb_u.len() + ty_u.len()); + t_all_u.extend_from_slice(&tb_u); + t_all_u.extend_from_slice(&ty_u); + let mut p_all_u = Vec::with_capacity(pb_u.len() + py_u.len()); + p_all_u.extend_from_slice(&pb_u); + p_all_u.extend_from_slice(&py_u); + + let spec_spearman = spearman(&t_all_u, &p_all_u); + let ionb_spearman = spearman(&tb_u, &pb_u); + let iony_spearman = spearman(&ty_u, &py_u); + + let spec_mse = mse2(&tb_u, &ty_u, &pb_u, &py_u); + let ionb_mse = mse(&tb_u, &pb_u); + let iony_mse = mse(&ty_u, &py_u); + + let dotprod = dot2(&tb_u, &ty_u, &pb_u, &py_u); + let dotprod_ionb = dot(&tb_u, &pb_u); + let dotprod_iony = dot(&ty_u, &py_u); + + let cos = cosine2(&tb_u, &ty_u, &pb_u, &py_u); + let cos_ionb = cosine_similarity(&tb_u, &pb_u); + let cos_iony = cosine_similarity(&ty_u, &py_u); + + // iontype min/max (unlog) like Python: 0 if b else 1 if y + let min_abs_diff_iontype = if min_abs_b_u <= min_abs_y_u { 0.0 } else { 1.0 }; + let max_abs_diff_iontype = if max_abs_b_u >= max_abs_y_u { 0.0 } else { 1.0 }; + + let mut feats: HashMap = HashMap::with_capacity(66); + + // log space + feats.insert("spec_pearson_norm".into(), finite_or_zero(spec_pearson_norm)); + feats.insert("ionb_pearson_norm".into(), finite_or_zero(ionb_pearson_norm)); + feats.insert("iony_pearson_norm".into(), finite_or_zero(iony_pearson_norm)); + feats.insert("spec_mse_norm".into(), finite_or_zero(spec_mse_norm)); + feats.insert("ionb_mse_norm".into(), finite_or_zero(ionb_mse_norm)); + feats.insert("iony_mse_norm".into(), finite_or_zero(iony_mse_norm)); + + feats.insert("min_abs_diff_norm".into(), finite_or_zero(min_abs_all)); + feats.insert("max_abs_diff_norm".into(), finite_or_zero(max_abs_all)); + feats.insert("abs_diff_Q1_norm".into(), finite_or_zero(q1_all)); + feats.insert("abs_diff_Q2_norm".into(), finite_or_zero(q2_all)); + feats.insert("abs_diff_Q3_norm".into(), finite_or_zero(q3_all)); + feats.insert("mean_abs_diff_norm".into(), finite_or_zero(mean_abs_all)); + feats.insert("std_abs_diff_norm".into(), finite_or_zero(std_abs_all)); + + feats.insert("ionb_min_abs_diff_norm".into(), finite_or_zero(min_abs_b)); + feats.insert("ionb_max_abs_diff_norm".into(), finite_or_zero(max_abs_b)); + feats.insert("ionb_abs_diff_Q1_norm".into(), finite_or_zero(q1_b)); + feats.insert("ionb_abs_diff_Q2_norm".into(), finite_or_zero(q2_b)); + feats.insert("ionb_abs_diff_Q3_norm".into(), finite_or_zero(q3_b)); + feats.insert("ionb_mean_abs_diff_norm".into(), finite_or_zero(mean_abs_b)); + feats.insert("ionb_std_abs_diff_norm".into(), finite_or_zero(std_abs_b)); + + feats.insert("iony_min_abs_diff_norm".into(), finite_or_zero(min_abs_y)); + feats.insert("iony_max_abs_diff_norm".into(), finite_or_zero(max_abs_y)); + feats.insert("iony_abs_diff_Q1_norm".into(), finite_or_zero(q1_y)); + feats.insert("iony_abs_diff_Q2_norm".into(), finite_or_zero(q2_y)); + feats.insert("iony_abs_diff_Q3_norm".into(), finite_or_zero(q3_y)); + feats.insert("iony_mean_abs_diff_norm".into(), finite_or_zero(mean_abs_y)); + feats.insert("iony_std_abs_diff_norm".into(), finite_or_zero(std_abs_y)); + + feats.insert("dotprod_norm".into(), finite_or_zero(dotprod_norm)); + feats.insert("dotprod_ionb_norm".into(), finite_or_zero(dotprod_ionb_norm)); + feats.insert("dotprod_iony_norm".into(), finite_or_zero(dotprod_iony_norm)); + feats.insert("cos_norm".into(), finite_or_zero(cos_norm)); + feats.insert("cos_ionb_norm".into(), finite_or_zero(cos_ionb_norm)); + feats.insert("cos_iony_norm".into(), finite_or_zero(cos_iony_norm)); + + // normal space + feats.insert("spec_pearson".into(), finite_or_zero(spec_pearson)); + feats.insert("ionb_pearson".into(), finite_or_zero(ionb_pearson)); + feats.insert("iony_pearson".into(), finite_or_zero(iony_pearson)); + feats.insert("spec_spearman".into(), finite_or_zero(spec_spearman)); + feats.insert("ionb_spearman".into(), finite_or_zero(ionb_spearman)); + feats.insert("iony_spearman".into(), finite_or_zero(iony_spearman)); + feats.insert("spec_mse".into(), finite_or_zero(spec_mse)); + feats.insert("ionb_mse".into(), finite_or_zero(ionb_mse)); + feats.insert("iony_mse".into(), finite_or_zero(iony_mse)); + + feats.insert("min_abs_diff_iontype".into(), min_abs_diff_iontype); + feats.insert("max_abs_diff_iontype".into(), max_abs_diff_iontype); + + feats.insert("min_abs_diff".into(), finite_or_zero(min_abs_all_u)); + feats.insert("max_abs_diff".into(), finite_or_zero(max_abs_all_u)); + feats.insert("abs_diff_Q1".into(), finite_or_zero(q1_all_u)); + feats.insert("abs_diff_Q2".into(), finite_or_zero(q2_all_u)); + feats.insert("abs_diff_Q3".into(), finite_or_zero(q3_all_u)); + feats.insert("mean_abs_diff".into(), finite_or_zero(mean_abs_all_u)); + feats.insert("std_abs_diff".into(), finite_or_zero(std_abs_all_u)); + + feats.insert("ionb_min_abs_diff".into(), finite_or_zero(min_abs_b_u)); + feats.insert("ionb_max_abs_diff".into(), finite_or_zero(max_abs_b_u)); + feats.insert("ionb_abs_diff_Q1".into(), finite_or_zero(q1_b_u)); + feats.insert("ionb_abs_diff_Q2".into(), finite_or_zero(q2_b_u)); + feats.insert("ionb_abs_diff_Q3".into(), finite_or_zero(q3_b_u)); + feats.insert("ionb_mean_abs_diff".into(), finite_or_zero(mean_abs_b_u)); + feats.insert("ionb_std_abs_diff".into(), finite_or_zero(std_abs_b_u)); + + feats.insert("iony_min_abs_diff".into(), finite_or_zero(min_abs_y_u)); + feats.insert("iony_max_abs_diff".into(), finite_or_zero(max_abs_y_u)); + feats.insert("iony_abs_diff_Q1".into(), finite_or_zero(q1_y_u)); + feats.insert("iony_abs_diff_Q2".into(), finite_or_zero(q2_y_u)); + feats.insert("iony_abs_diff_Q3".into(), finite_or_zero(q3_y_u)); + feats.insert("iony_mean_abs_diff".into(), finite_or_zero(mean_abs_y_u)); + feats.insert("iony_std_abs_diff".into(), finite_or_zero(std_abs_y_u)); + + feats.insert("dotprod".into(), finite_or_zero(dotprod)); + feats.insert("dotprod_ionb".into(), finite_or_zero(dotprod_ionb)); + feats.insert("dotprod_iony".into(), finite_or_zero(dotprod_iony)); + feats.insert("cos".into(), finite_or_zero(cos)); + feats.insert("cos_ionb".into(), finite_or_zero(cos_ionb)); + feats.insert("cos_iony".into(), finite_or_zero(cos_iony)); + + (it.idx, feats) + }) + .collect::>() + }); + + out.append(&mut block_out); + } + + Ok(out) +} diff --git a/src/parse_mzdata.rs b/src/parse_mzdata.rs index 8425ded..d461ab9 100644 --- a/src/parse_mzdata.rs +++ b/src/parse_mzdata.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use mzdata::{params::ParamValue, prelude::*, MZReader}; +use rayon::prelude::*; use crate::ms2_spectrum::MS2Spectrum; use crate::precursor::Precursor; @@ -49,8 +50,11 @@ pub fn parse_precursor_info( spectrum_path: &str, ) -> Result, std::io::Error> { let reader = MZReader::open_path(spectrum_path)?; - Ok(reader + let spectra: Vec<_> = reader .filter(|spectrum| spectrum.description.ms_level == 2) + .collect(); + Ok(spectra + .into_par_iter() .filter_map(|spectrum| { spectrum.precursor().as_ref()?; Some((spectrum.description.id.clone(), Precursor::from(&spectrum))) @@ -65,8 +69,11 @@ pub fn read_ms2_spectra(spectrum_path: &str) -> Result, std::io inner.set_centroiding(true); } - Ok(reader + let spectra: Vec<_> = reader .filter(|spectrum| spectrum.description.ms_level == 2) + .collect(); + Ok(spectra + .into_par_iter() .map(MS2Spectrum::from) .collect::>()) } diff --git a/src/parse_timsrust.rs b/src/parse_timsrust.rs index 59997e6..b637710 100644 --- a/src/parse_timsrust.rs +++ b/src/parse_timsrust.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use rayon::prelude::*; use timsrust::readers::SpectrumReaderError; use crate::ms2_spectrum::MS2Spectrum; @@ -61,6 +62,7 @@ pub fn parse_precursor_info( .map_err(|e| std::io::Error::other(e.to_string()))?; let spectra = (0..reader.len()) + .into_par_iter() .map(|index| match reader.get(index) { Ok(spectrum) => Ok(Some(spectrum)), Err(err) => handle_spectrum_reader_error(err).map(|_| None), @@ -74,7 +76,7 @@ pub fn parse_precursor_info( .collect::>(); let precursor_info = spectra - .into_iter() + .into_par_iter() .filter_map(|spectrum| match spectrum.precursor { Some(precursor) => Some((spectrum.index.to_string(), Precursor::from(precursor))), None => None, @@ -90,6 +92,7 @@ pub fn read_ms2_spectra(spectrum_path: &str) -> Result, std::io .map_err(|e| std::io::Error::other(e.to_string()))?; let spectra = (0..reader.len()) + .into_par_iter() .map(|index| match reader.get(index) { Ok(spectrum) => Ok(Some(spectrum)), Err(err) => handle_spectrum_reader_error(err).map(|_| None), @@ -102,5 +105,5 @@ pub fn read_ms2_spectra(spectrum_path: &str) -> Result, std::io .flatten() .collect::>(); - Ok(spectra.into_iter().map(MS2Spectrum::from).collect()) + Ok(spectra.into_par_iter().map(MS2Spectrum::from).collect()) } diff --git a/src/precursor.rs b/src/precursor.rs index e0fbd32..eee7a21 100644 --- a/src/precursor.rs +++ b/src/precursor.rs @@ -1,7 +1,7 @@ use pyo3::prelude::*; /// Precursor information. -#[pyclass(get_all, set_all)] +#[pyclass(module = "ms2rescore_rs", get_all, set_all)] #[derive(Debug, Clone)] pub struct Precursor { pub mz: f64, @@ -13,12 +13,29 @@ pub struct Precursor { #[pymethods] impl Precursor { + #[new] + #[pyo3(signature = (mz=0.0, rt=0.0, im=0.0, charge=0, intensity=0.0))] + pub fn new(mz: f64, rt: f64, im: f64, charge: usize, intensity: f64) -> Self { + Precursor { + mz, + rt, + im, + charge, + intensity, + } + } + pub fn __repr__(&self) -> String { format!( "Precursor(mz={}, rt={}, im={}, charge={}, intensity={})", self.mz, self.rt, self.im, self.charge, self.intensity ) } + + pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, (f64, f64, f64, usize, f64))> { + let cls = py.import("ms2rescore_rs")?.getattr("Precursor")?; + Ok((cls.into(), (self.mz, self.rt, self.im, self.charge, self.intensity))) + } } impl Default for Precursor {