diff --git a/src/aind_dynamic_foraging_basic_analysis/metrics/snr_kurtosis.py b/src/aind_dynamic_foraging_basic_analysis/metrics/snr_kurtosis.py index 920dc2a..7db42f6 100644 --- a/src/aind_dynamic_foraging_basic_analysis/metrics/snr_kurtosis.py +++ b/src/aind_dynamic_foraging_basic_analysis/metrics/snr_kurtosis.py @@ -2,9 +2,10 @@ Utilities for signal quality metrics on 1D fluorescence traces. This module provides: -- :func:`estimate_snr` — an SNR estimator using a derivative-based noise +- :func:`estimate_trace_snr` — an SNR estimator using a derivative-based noise estimate and peak-based signal estimate. -- :func:`estimate_kurtosis` — excess kurtosis of the trace distribution. +- :func:`estimate_trace_kurtosis` — excess kurtosis of the trace distribution. +- :func:`estimate_snr_and_kurtosis` — estimate the SNR and kurtosis using an NWB Notes ----- @@ -32,19 +33,22 @@ from __future__ import annotations import warnings -from typing import Tuple import numpy as np +import pandas as pd from numpy.typing import NDArray from scipy.signal import find_peaks from scipy.stats import kurtosis -__all__ = ["estimate_snr", "estimate_kurtosis"] +from aind_dynamic_foraging_data_utils import nwb_utils as nu -def estimate_snr( +__all__ = ["estimate_trace_snr", "estimate_trace_kurtosis"] + + +def estimate_trace_snr( trace: NDArray[np.floating], fps: float = 20.0 -) -> Tuple[float, float, NDArray[np.intp]]: +) -> tuple[float, float, NDArray[np.intp]]: """ Estimate the signal-to-noise ratio (SNR) of a 1D trace. @@ -110,7 +114,7 @@ def estimate_snr( return snr, noise, peaks -def estimate_kurtosis(trace: NDArray[np.floating]) -> float: +def estimate_trace_kurtosis(trace: NDArray[np.floating]) -> float: """ Compute the **excess kurtosis** of a 1D trace distribution. 
def estimate_snr_and_kurtosis(
    nwb,
    data_column: str = 'data',
    process_suffix: str = 'dff-poly_mc-iso-IRLS',
    fps: float = 20.0,
) -> pd.DataFrame:
    """
    Estimate the signal-to-noise ratio (SNR) and excess kurtosis per channel.

    Accepts a single NWB or a list of NWBs. For each one, the channels listed in
    the ``event`` column of its ``df_fip`` are filtered: names starting with
    ``"FIP"`` or ``"Iso"`` are excluded, and only names ending with
    ``process_suffix`` are kept. Each remaining channel trace is passed to
    :func:`estimate_trace_snr` and :func:`estimate_trace_kurtosis`.

    Parameters
    ----------
    nwb :
        Single NWB (or NWB-like object) or a list of them. If an object has no
        ``df_fip`` attribute, it is computed with
        ``nwb_utils.create_df_fip(nwb, tidy=True)``.
    data_column : str, optional
        Column of ``df_fip`` holding the trace values, by default ``'data'``.
    process_suffix : str, optional
        Suffix of the channel name indicating the processing method,
        by default ``'dff-poly_mc-iso-IRLS'``.
    fps : float, optional
        Sampling frequency (frames per second), by default ``20.0``.

    Returns
    -------
    pd.DataFrame
        One row per (session, channel) with columns ``ses_idx``, ``channel``
        (channel name suffixed with ``_<data_column>``), ``SNR``, ``noise``,
        ``peaks`` (as returned by :func:`estimate_trace_snr`) and ``kurtosis``.
        The frame is empty (but has these columns) when no channel matches.
    """
    nwb_list = nwb if isinstance(nwb, list) else [nwb]

    sess_metrics = []
    for nwb_i in nwb_list:
        if not hasattr(nwb_i, "df_fip"):
            print("You need to compute the df_fip first")
            print("running `nwb.df_fip = create_df_fip(nwb,tidy=True)`")
            # NOTE: the computed table is only used locally here; it is not
            # stored back on the NWB object.
            df_fip = nu.create_df_fip(nwb_i, tidy=True)
        else:
            df_fip = nwb_i.df_fip

        events = df_fip["event"]
        # Keep processed signal channels only: drop raw FIP / isosbestic
        # channels and anything not produced by the requested processing.
        processed_signal_channels = [
            channel
            for channel in events.unique()
            if not channel.startswith(("FIP", "Iso")) and channel.endswith(process_suffix)
        ]
        if not processed_signal_channels:
            # Nothing to measure (this also covers an empty df_fip, where
            # looking up the session index below would raise IndexError).
            continue

        ses_idx = df_fip['ses_idx'].unique()[0]
        for channel in processed_signal_channels:
            # A boolean mask instead of a string-built ``query`` keeps this
            # working for channel names that contain quote characters.
            trace = df_fip.loc[events == channel, data_column].values
            snr, noise, peaks = estimate_trace_snr(trace, fps)
            # Named ``kurt`` so the scipy ``kurtosis`` import is not shadowed.
            kurt = estimate_trace_kurtosis(trace)
            sess_metrics.append(
                [ses_idx, f"{channel}_{data_column}", snr, noise, peaks, kurt]
            )

    # Assemble the per-session, per-channel metrics table.
    return pd.DataFrame(
        sess_metrics,
        columns=['ses_idx', 'channel', 'SNR', 'noise', 'peaks', 'kurtosis'],
    )