def _observed_spectra_from_psms(psm_list: PSMList) -> List[ObservedSpectrum]:
    """Validate the attached ``MS2Spectrum`` objects and convert them to ``ObservedSpectrum``."""
    # Read the attribute straight from each PSM: this builds the list once and avoids
    # round-tripping arbitrary spectrum objects through a NumPy array.
    ms2_spectra = [psm.spectrum for psm in psm_list]

    # Check every entry (not just the first) so that a missing or foreign spectrum
    # surfaces as the documented ValueError rather than an AttributeError further down.
    if not all(isinstance(spectrum, MS2Spectrum) for spectrum in ms2_spectra):
        raise ValueError("PSMs must contain MS2Spectrum objects in the 'spectrum' attribute.")

    return [
        ObservedSpectrum(
            mz=np.array(spectrum.mz, dtype=np.float32),
            intensity=np.array(spectrum.intensity, dtype=np.float32),
            identifier=str(spectrum.identifier),
            precursor_mz=float(spectrum.precursor.mz),
            # Kept integral: this value may be copied onto
            # ``psm.peptidoform.precursor_charge`` by the caller.
            precursor_charge=int(spectrum.precursor.charge),
            retention_time=float(spectrum.precursor.rt),
        )
        for spectrum in ms2_spectra
    ]


def process_MS2_spectra(
    psms: Union[PSMList, List[PSM]],
    compute_correlations: bool = False,
    model: Optional[str] = "HCD",
    model_dir: Optional[Union[str, Path]] = None,
    ms2_tolerance: float = 0.02,
    processes: Optional[int] = None,
) -> List[ProcessingResult]:
    """
    Process PSMs with MS2Spectrum objects already attached.\f

    This function processes PSMs that already have :py:class:`ms2rescore_rs.MS2Spectrum`
    objects in their ``spectrum`` attribute. It extracts the spectra, performs predictions,
    and optionally computes correlations.

    Parameters
    ----------
    psms
        PSMList or list of PSM objects. Each PSM must have an
        :py:class:`ms2rescore_rs.MS2Spectrum` object in its ``spectrum`` attribute.
        PSMs without a precursor charge on their peptidoform get the charge of the
        attached spectrum assigned in place.
    compute_correlations
        Compute correlations between predictions and targets. Default: False.
    model
        Model to use for prediction. Default: "HCD".
    model_dir
        Directory where XGBoost model files are stored. Default: ``~/.ms2pip``.
    ms2_tolerance
        MS2 tolerance in Da for observed spectrum peak annotation. By default, 0.02 Da.
    processes
        Number of parallel processes for multiprocessing steps. By default, all available.

    Returns
    -------
    results: List[ProcessingResult]
        One ProcessingResult per PSM, in input order, with theoretical m/z, predicted
        intensity, and observed intensity values, and optionally, correlations. PSMs whose
        peptidoform cannot be encoded or predicted yield a result holding only the PSM and
        its index. An empty input yields an empty list.

    Raises
    ------
    ValueError
        If any PSM does not contain a :py:class:`ms2rescore_rs.MS2Spectrum` object in its
        ``spectrum`` attribute.

    """
    psm_list = psms if isinstance(psms, PSMList) else PSMList(psm_list=list(psms))

    observed_spectra = _observed_spectra_from_psms(psm_list)
    if not observed_spectra:
        # Nothing to encode or predict.
        return []

    model_config = MODELS[model]
    ion_types = [ion_type.lower() for ion_type in model_config["ion_types"]]
    model_dir = model_dir or Path.home() / ".ms2pip"
    # Reporter ions are only stripped for models whose name mentions the label type.
    reporter_labels = [label for label in ("iTRAQ", "TMT") if label in model]

    with Encoder.from_psm_list(psm_list) as encoder:
        ms2pip_pyx.ms2pip_init(*encoder.encoder_files)
        results = []

        for psm_index, (psm, spectrum) in enumerate(zip(psm_list, observed_spectra)):
            # Spectrum preprocessing (same as in _process_spectra)
            for label_type in reporter_labels:
                spectrum.remove_reporter_ions(label_type)
            spectrum.tic_norm()
            spectrum.log2_transform()

            # Encode peptidoform; unsupported residues give an empty result for this PSM.
            try:
                enc_peptidoform = encoder.encode_peptidoform(psm.peptidoform)
            except exceptions.InvalidAminoAcidError:
                results.append(ProcessingResult(psm_index=psm_index, psm=psm))
                continue

            # Get observed intensities (targets)
            targets = ms2pip_pyx.get_targets(
                enc_peptidoform,
                spectrum.mz.astype(np.float32),
                spectrum.intensity.astype(np.float32),
                float(ms2_tolerance),
                model_config["peaks_version"],
            )
            observed_intensity = {
                ion_type: np.array(target, dtype=np.float32)
                for ion_type, target in zip(ion_types, targets)
            }

            # Fall back to the spectrum's precursor charge when the PSM has none.
            if not psm.peptidoform.precursor_charge:
                psm.peptidoform.precursor_charge = spectrum.precursor_charge

            # Get predictions
            try:
                result = _process_peptidoform(psm_index, psm, model, encoder, ion_types)
            except (
                exceptions.InvalidPeptidoformError,
                exceptions.InvalidAminoAcidError,
            ):
                result = ProcessingResult(psm_index=psm_index, psm=psm)
            else:
                result.observed_intensity = observed_intensity

            results.append(result)

        # Add XGBoost predictions if needed
        if "xgboost_model_files" in model_config:
            results_to_predict = [r for r in results if r.feature_vectors is not None]
            if results_to_predict:
                import xgboost as xgb

                validate_requested_xgb_model(
                    model_config["xgboost_model_files"],
                    model_config["model_hash"],
                    model_dir,
                )

                num_ions = [
                    len(r.psm.peptidoform.parsed_sequence) - 1 for r in results_to_predict
                ]
                xgb_vector = xgb.DMatrix(
                    np.vstack([r.feature_vectors for r in results_to_predict])
                )

                predictions = get_predictions_xgb(
                    xgb_vector,
                    num_ions,
                    model_config,
                    model_dir,
                    processes=processes,
                )

                for result, preds in zip(results_to_predict, predictions):
                    result.predicted_intensity = preds
                    # Feature vectors are no longer needed once predictions are stored.
                    result.feature_vectors = None

    # Compute correlations if requested
    if compute_correlations:
        logger.info("Computing correlations")
        calculate_correlations(results)
        correlations = [r.correlation for r in results if r.correlation is not None]
        # np.median of an empty list is nan and emits a RuntimeWarning; skip it.
        if correlations:
            logger.info("Median correlation: %s", np.median(correlations))

    return results


# ms2pip/__init__.py imports ``process_observed_spectra`` from this module; expose the
# function under that name as well so the package-level import resolves.
process_observed_spectra = process_MS2_spectra