diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ab65c46..9089e8c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added --- +- `m_chi2` attribute and `get_m_chi2` method for `TimeSeries` +- `take_mut` dependency ### Changed diff --git a/Cargo.toml b/Cargo.toml index 5d600284..9c735456 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ num-traits = "^0.2" paste = "1" schemars = "^0.8" serde = { version = "1", features = ["derive"] } +take_mut = "0.2.2" thiserror = "1" thread_local = "1.1" unzip3 = "1" @@ -66,6 +67,7 @@ clap = { version = "3.2.6", features = ["std", "color", "suggestions", "derive", criterion = "0.4" hyperdual = "1.1" light-curve-common = "0.1.0" +ndarray = { version = "^0.15", features = ["approx-0_5"] } plotters = { version = "0.3.5", default-features = false, features = ["errorbar", "line_series", "ttf"] } plotters-bitmap = "0.3.3" rand = "0.7" diff --git a/src/data/data_sample.rs b/src/data/data_sample.rs index 44e5073f..f7753a54 100644 --- a/src/data/data_sample.rs +++ b/src/data/data_sample.rs @@ -21,6 +21,15 @@ where std2: Option, } +impl<'a, T> PartialEq for DataSample<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.sample == other.sample + } +} + macro_rules! data_sample_getter { ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { // This lint is false-positive in macros diff --git a/src/data/mod.rs b/src/data/mod.rs index aafaeb99..469dee0b 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,12 +1,11 @@ mod data_sample; pub use data_sample::DataSample; -mod multi_color_time_series; +pub(crate) mod multi_color_time_series; pub use multi_color_time_series::MultiColorTimeSeries; mod sorted_array; pub use sorted_array::SortedArray; mod time_series; - pub use time_series::TimeSeries; diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index 535543aa..787d3bc7 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -3,12 +3,14 @@ use crate::float_trait::Float; use crate::multicolor::PassbandTrait; use crate::{DataSample, PassbandSet}; +use conv::prelude::*; use itertools::Either; use itertools::EitherOrBoth; use itertools::Itertools; use std::collections::{BTreeMap, BTreeSet}; use std::ops::{Deref, DerefMut}; +#[derive(Clone, Debug)] pub enum MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { Mapping(MappedMultiColorTimeSeries<'a, P, T>), Flat(FlatMultiColorTimeSeries<'a, P, T>), @@ -23,6 +25,31 @@ where P: PassbandTrait + 'p, T: Float, { + pub fn total_lenu(&self) -> usize { + match self { + Self::Mapping(mapping) => mapping.total_lenu(), + Self::Flat(flat) => flat.total_lenu(), + Self::MappingFlat { flat, .. } => flat.total_lenu(), + } + } + + pub fn total_lenf(&self) -> T { + match self { + Self::Mapping(mapping) => mapping.total_lenf(), + Self::Flat(flat) => flat.total_lenf(), + Self::MappingFlat { flat, .. } => flat.total_lenf(), + } + } + + pub fn passband_count(&self) -> usize { + match self { + Self::Mapping(mapping) => mapping.passband_count(), + Self::Flat(flat) => flat.passband_count(), + // Both flat and mapping have the same number of passbands and should be equally fast + Self::MappingFlat { flat, .. } => flat.passband_count(), + } + } + pub fn from_map(map: impl Into>>) -> Self { Self::Mapping(MappedMultiColorTimeSeries::new(map)) } @@ -36,21 +63,43 @@ where Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands)) } - pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + fn ensure_mapping(&mut self) -> &mut Self { if matches!(self, MultiColorTimeSeries::Flat(_)) { - let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); - *self = match std::mem::replace(self, dummy_self) { + take_mut::take(self, |slf| match slf { Self::Flat(mut flat) => { let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat); Self::MappingFlat { mapping, flat } } - _ => unreachable!(), - } + _ => unreachable!("We just checked that we are in ::Flat variant"), + }); } + self + } + + fn enforce_mapping(&mut self) -> &mut Self { match self { + Self::Mapping(_) => {} + Self::Flat(_flat) => take_mut::take(self, |slf| match slf { + Self::Flat(flat) => Self::Mapping(flat.into()), + _ => unreachable!("We just checked that we are in ::Flat variant"), + }), + Self::MappingFlat { .. } => { + take_mut::take(self, |slf| match slf { + Self::MappingFlat { mapping, .. } => Self::Mapping(mapping), + _ => unreachable!("We just checked that we are in ::MappingFlat variant"), + }); + } + } + self + } + + pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + match self.ensure_mapping() { Self::Mapping(mapping) => mapping, Self::Flat(_flat) => { - unreachable!("::Flat variant is already transofrmed to ::MappingFlat") + unreachable!( + "::Flat variant is already transformed to ::MappingFlat in ensure_mapping" + ) } Self::MappingFlat { mapping, .. } => mapping, } @@ -64,20 +113,25 @@ where } } - pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + fn ensure_flat(&mut self) -> &mut Self { if matches!(self, MultiColorTimeSeries::Mapping(_)) { - let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); - *self = match std::mem::replace(self, dummy_self) { + take_mut::take(self, |slf| match slf { Self::Mapping(mut mapping) => { let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping); Self::MappingFlat { mapping, flat } } - _ => unreachable!(), - } + _ => unreachable!("We just checked that we are in ::Mapping variant"), + }); } - match self { + self + } + + pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + match self.ensure_flat() { Self::Mapping(_mapping) => { - unreachable!("::Mapping veriant is already transformed to ::MappingFlat") + unreachable!( + "::Mapping variant is already transformed to ::MappingFlat in ensure_flat" + ) } Self::Flat(flat) => flat, Self::MappingFlat { flat, .. } => flat, @@ -107,12 +161,45 @@ where Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()), } } + + /// Inserts new pair of passband and time series into the multicolor time series. + /// + /// It always converts [MultiColorTimeSeries] to [MultiColorTimeSeries::Mapping] variant. + /// Also it replaces existing time series if passband is already present, and returns old time + /// series. + pub fn insert(&mut self, passband: P, ts: TimeSeries<'a, T>) -> Option> { + match self.enforce_mapping() { + Self::Mapping(mapping) => mapping.0.insert(passband, ts), + _ => unreachable!("We just converted self to ::Mapping variant"), + } + } +} + +impl<'a, P, T> Default for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn default() -> Self { + Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())) + } } +#[derive(Debug, Clone)] pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>( BTreeMap>, ); +impl<'a, P, T> PartialEq for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T> where P: PassbandTrait + 'p, @@ -148,13 +235,34 @@ where ) } + pub fn total_lenu(&self) -> usize { + self.0.values().map(|ts| ts.lenu()).sum() + } + + pub fn total_lenf(&self) -> T { + self.total_lenu().value_as::().unwrap() + } + + pub fn passband_count(&self) -> usize { + self.0.len() + } + pub fn passbands<'slf>( &'slf self, ) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>> where 'a: 'slf, { - self.keys() + self.0.keys() + } + + pub fn iter_ts<'slf>( + &'slf self, + ) -> std::collections::btree_map::Values<'slf, P, TimeSeries<'a, T>> + where + 'a: 'slf, + { + self.0.values() } pub fn iter_passband_set<'slf, 'ps>( @@ -233,6 +341,7 @@ impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a, } } +#[derive(Debug, Clone)] pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { pub t: DataSample<'a, T>, pub m: DataSample<'a, T>, @@ -241,6 +350,19 @@ pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { passband_set: BTreeSet

, } +impl<'a, P, T> PartialEq for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.t == other.t + && self.m == other.m + && self.w == other.w + && self.passbands == other.passbands + } +} + impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T> where P: PassbandTrait, @@ -305,4 +427,135 @@ where passband_set: mapping.keys().cloned().collect(), } } + + pub fn total_lenu(&self) -> usize { + self.t.sample.len() + } + + pub fn total_lenf(&self) -> T { + self.t.sample.len().value_as::().unwrap() + } + + pub fn passband_count(&self) -> usize { + self.passband_set.len() + } +} + +impl<'a, P, T> From> for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mut flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self { + Self::from_flat(&mut flat) + } +} + +impl<'a, P, T> From> for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mut mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self { + Self::from_mapping(&mut mapped.0) + } +} + +impl<'a, P, T> From> for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self { + Self::Flat(flat) + } +} + +impl<'a, P, T> From> for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self { + Self::Mapping(mapped) + } +} + +impl<'a, P, T> From> for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self { + match mcts { + MultiColorTimeSeries::Flat(flat) => flat, + MultiColorTimeSeries::Mapping(mapped) => mapped.into(), + MultiColorTimeSeries::MappingFlat { flat, .. } => flat, + } + } +} + +impl<'a, P, T> From> for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self { + match mcts { + MultiColorTimeSeries::Flat(flat) => flat.into(), + MultiColorTimeSeries::Mapping(mapping) => mapping, + MultiColorTimeSeries::MappingFlat { mapping, .. } => mapping, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::MonochromePassband; + + use ndarray::Array1; + + #[test] + fn multi_color_ts_insert() { + let mut mcts = MultiColorTimeSeries::default(); + mcts.insert( + MonochromePassband::new(4700.0, "g"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)), + ); + assert_eq!(mcts.passband_count(), 1); + assert_eq!(mcts.total_lenu(), 11); + mcts.insert( + MonochromePassband::new(6200.0, "r"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)), + ); + assert_eq!(mcts.passband_count(), 2); + assert_eq!(mcts.total_lenu(), 17); + } + + fn compare_variants(mcts: MultiColorTimeSeries) { + let flat: FlatMultiColorTimeSeries<_, _> = mcts.clone().into(); + let mapped: MappedMultiColorTimeSeries<_, _> = mcts.clone().into(); + let mapped_from_flat: MappedMultiColorTimeSeries<_, _> = flat.clone().into(); + let flat_from_mapped: FlatMultiColorTimeSeries<_, _> = mapped.clone().into(); + assert_eq!(mapped, mapped_from_flat); + assert_eq!(flat, flat_from_mapped); + } + + #[test] + fn convert_between_variants() { + let mut mcts = MultiColorTimeSeries::default(); + compare_variants(mcts.clone()); + mcts.insert( + MonochromePassband::new(4700.0, "g"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)), + ); + compare_variants(mcts.clone()); + mcts.insert( + MonochromePassband::new(6200.0, "r"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)), + ); + compare_variants(mcts.clone()); + } } diff --git a/src/data/time_series.rs b/src/data/time_series.rs index 855e7819..aaf8230d 100644 --- a/src/data/time_series.rs +++ b/src/data/time_series.rs @@ -21,12 +21,22 @@ where pub m: DataSample<'a, T>, pub w: DataSample<'a, T>, m_weighted_mean: Option, + m_chi2: Option, m_reduced_chi2: Option, t_max_m: Option, t_min_m: Option, plateau: Option, } +impl<'a, T> PartialEq for TimeSeries<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.t == other.t && self.m == other.m && self.w == other.w + } +} + macro_rules! time_series_getter { ($t: ty, $attr: ident, $getter: ident, $func: expr) => { // This lint is false-positive in macros @@ -84,6 +94,7 @@ where m, w, m_weighted_mean: None, + m_chi2: None, m_reduced_chi2: None, t_max_m: None, t_min_m: None, @@ -116,6 +127,7 @@ where m, w, m_weighted_mean: None, + m_chi2: None, m_reduced_chi2: None, t_max_m: None, t_min_m: None, @@ -140,20 +152,23 @@ where |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } ); - time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< - T, - >| { + time_series_getter!(m_chi2, get_m_chi2, |ts: &mut TimeSeries| { let m_weighed_mean = ts.get_m_weighted_mean(); - let m_reduced_chi2 = Zip::from(&ts.m.sample) + let m_chi2 = Zip::from(&ts.m.sample) .and(&ts.w.sample) .fold(T::zero(), |chi2, &m, &w| { chi2 + (m - m_weighed_mean).powi(2) * w - }) - / (ts.lenf() - T::one()); - if m_reduced_chi2.is_zero() { + }); + if m_chi2.is_zero() { ts.plateau = Some(true); } - m_reduced_chi2 + m_chi2 + }); + + time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< + T, + >| { + ts.get_m_chi2() / (ts.lenf() - T::one()) }); time_series_getter!(bool, plateau, is_plateau, |ts: &mut TimeSeries| { diff --git a/src/error.rs b/src/error.rs index 15c6f0a1..9f118304 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,5 @@ +use crate::data::multi_color_time_series::MappedMultiColorTimeSeries; +use crate::float_trait::Float; use crate::PassbandTrait; use std::collections::BTreeSet; @@ -28,6 +30,18 @@ pub enum MultiColorEvaluatorError { actual: BTreeSet, desired: BTreeSet, }, + + #[error("No time-series long enough: maximum length found is {maximum_actual}, while minimum required is {minimum_required}")] + AllTimeSeriesAreShort { + maximum_actual: usize, + minimum_required: usize, + }, + + #[error(r#"Underlying feature caused an error: "{0:?}""#)] + UnderlyingEvaluatorError(#[from] EvaluatorError), + + #[error("All time-series are flat")] + AllTimeSeriesAreFlat, } impl MultiColorEvaluatorError { @@ -43,6 +57,20 @@ impl MultiColorEvaluatorError { desired: desired.map(|p| p.name().into()).collect(), } } + + pub fn all_time_series_short( + mapped: &MappedMultiColorTimeSeries, + minimum_required: usize, + ) -> Self + where + P: PassbandTrait, + T: Float, + { + Self::AllTimeSeriesAreShort { + maximum_actual: mapped.iter_ts().map(|ts| ts.lenu()).max().unwrap_or(0), + minimum_required, + } + } } #[derive(Debug, thiserror::Error, PartialEq, Eq)] diff --git a/src/feature.rs b/src/feature.rs index 73772698..ac284b07 100644 --- a/src/feature.rs +++ b/src/feature.rs @@ -50,7 +50,7 @@ where PercentAmplitude, PercentDifferenceMagnitudePercentile, Periodogram(Periodogram), - _PeriodogramPeaks, + _PeriodogramPeaks(PeriodogramPeaks), ReducedChi2, Skew, StandardDeviation, diff --git a/src/features/_periodogram_peaks.rs b/src/features/_periodogram_peaks.rs new file mode 100644 index 00000000..aab3f6f9 --- /dev/null +++ b/src/features/_periodogram_peaks.rs @@ -0,0 +1,163 @@ +use crate::evaluator::*; +use crate::evaluator::{Deserialize, EvaluatorInfo, EvaluatorProperties, Serialize}; +use crate::peak_indices::peak_indices_reverse_sorted; +use crate::{ + number_ending, EvaluatorError, EvaluatorInfoTrait, FeatureEvaluator, + FeatureNamesDescriptionsTrait, Float, TimeSeries, +}; + +use schemars::JsonSchema; +use std::iter; + +macro_const! { + const PERIODOGRAM_PEAKS_DOC: &'static str = r#" +Peak evaluator for [Periodogram] + +- Depends on: **time**, **magnitude** (which have meaning of frequency and spectral density) +- Minimum number of observations: **1** +- Number of features: **2 * npeaks** +"#; +} + +#[doc(hidden)] +#[doc = PERIODOGRAM_PEAKS_DOC!()] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + from = "PeriodogramPeaksParameters", + into = "PeriodogramPeaksParameters" +)] +pub struct PeriodogramPeaks { + peaks: usize, + properties: Box, +} + +impl PeriodogramPeaks { + pub fn new(peaks: usize) -> Self { + assert!(peaks > 0, "Number of peaks should be at least one"); + let info = EvaluatorInfo { + size: 2 * peaks, + min_ts_length: 1, + t_required: true, + m_required: true, + w_required: false, + sorting_required: true, + variability_required: false, + }; + let names = (0..peaks) + .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) + .collect(); + let descriptions = (0..peaks) + .flat_map(|i| { + vec![ + format!( + "period of the {}{} highest peak", + i + 1, + number_ending(i + 1), + ), + format!( + "Spectral density to spectral density standard deviation ratio of \ + the {}{} highest peak", + i + 1, + number_ending(i + 1) + ), + ] + }) + .collect(); + Self { + properties: EvaluatorProperties { + info, + names, + descriptions, + } + .into(), + peaks, + } + } + + pub fn get_peaks(&self) -> usize { + self.peaks + } + + #[inline] + pub fn default_peaks() -> usize { + 1 + } + + pub const fn doc() -> &'static str { + PERIODOGRAM_PEAKS_DOC + } +} + +impl Default for PeriodogramPeaks { + fn default() -> Self { + Self::new(Self::default_peaks()) + } +} + +impl EvaluatorInfoTrait for PeriodogramPeaks { + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl FeatureEvaluator for PeriodogramPeaks +where + T: Float, +{ + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); + Ok(peak_indices + .iter() + .flat_map(|&i| { + iter::once(T::two() * T::PI() / ts.t.sample[i]) + .chain(iter::once(ts.m.signal_to_noise(ts.m.sample[i]))) + }) + .chain(iter::repeat(T::zero())) + .take(2 * self.peaks) + .collect()) + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename = "PeriodogramPeaks")] +struct PeriodogramPeaksParameters { + peaks: usize, +} + +impl From for PeriodogramPeaksParameters { + fn from(f: PeriodogramPeaks) -> Self { + Self { peaks: f.peaks } + } +} + +impl From for PeriodogramPeaks { + fn from(p: PeriodogramPeaksParameters) -> Self { + Self::new(p.peaks) + } +} + +impl JsonSchema for PeriodogramPeaks { + json_schema!(PeriodogramPeaksParameters, false); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + check_feature!(PeriodogramPeaks); +} diff --git a/src/features/mod.rs b/src/features/mod.rs index d5dea9e7..37a71e4b 100644 --- a/src/features/mod.rs +++ b/src/features/mod.rs @@ -1,5 +1,8 @@ //! Feature sctructs implements [crate::FeatureEvaluator] trait +mod _periodogram_peaks; +pub(crate) use _periodogram_peaks::PeriodogramPeaks; + mod amplitude; pub use amplitude::Amplitude; @@ -82,8 +85,8 @@ mod percent_difference_magnitude_percentile; pub use percent_difference_magnitude_percentile::PercentDifferenceMagnitudePercentile; mod periodogram; +pub use _periodogram_peaks::PeriodogramPeaks as _PeriodogramPeaks; pub use periodogram::Periodogram; -pub use periodogram::PeriodogramPeaks as _PeriodogramPeaks; mod reduced_chi2; pub use reduced_chi2::ReducedChi2; @@ -110,4 +113,5 @@ mod villar_fit; pub use villar_fit::{VillarFit, VillarInitsBounds, VillarLnPrior}; mod weighted_mean; + pub use weighted_mean::WeightedMean; diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 407dd93f..bc71b781 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -1,162 +1,12 @@ use crate::evaluator::*; use crate::extractor::FeatureExtractor; -use crate::peak_indices::peak_indices_reverse_sorted; +use crate::features::_periodogram_peaks::PeriodogramPeaks; use crate::periodogram; use crate::periodogram::{AverageNyquistFreq, NyquistFreq, PeriodogramPower, PeriodogramPowerFft}; +use ndarray::Array1; use std::convert::TryInto; use std::fmt::Debug; -use std::iter; - -fn number_ending(i: usize) -> &'static str { - #[allow(clippy::match_same_arms)] - match (i % 10, i % 100) { - (1, 11) => "th", - (1, _) => "st", - (2, 12) => "th", - (2, _) => "nd", - (3, 13) => "th", - (3, _) => "rd", - (_, _) => "th", - } -} - -macro_const! { - const PERIODOGRAM_PEAK_DOC: &'static str = r#" -Peak evaluator for [Periodogram] -"#; -} - -#[doc(hidden)] -#[doc = PERIODOGRAM_PEAK_DOC!()] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde( - from = "PeriodogramPeaksParameters", - into = "PeriodogramPeaksParameters" -)] -pub struct PeriodogramPeaks { - peaks: usize, - properties: Box, -} - -impl PeriodogramPeaks { - pub fn new(peaks: usize) -> Self { - assert!(peaks > 0, "Number of peaks should be at least one"); - let info = EvaluatorInfo { - size: 2 * peaks, - min_ts_length: 1, - t_required: true, - m_required: true, - w_required: false, - sorting_required: true, - variability_required: false, - }; - let names = (0..peaks) - .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) - .collect(); - let descriptions = (0..peaks) - .flat_map(|i| { - vec![ - format!( - "period of the {}{} highest peak of periodogram", - i + 1, - number_ending(i + 1), - ), - format!( - "Spectral density to spectral density standard deviation ratio of \ - the {}{} highest peak of periodogram", - i + 1, - number_ending(i + 1) - ), - ] - }) - .collect(); - Self { - properties: EvaluatorProperties { - info, - names, - descriptions, - } - .into(), - peaks, - } - } - - #[inline] - pub fn default_peaks() -> usize { - 1 - } - - pub const fn doc() -> &'static str { - PERIODOGRAM_PEAK_DOC - } -} - -impl Default for PeriodogramPeaks { - fn default() -> Self { - Self::new(Self::default_peaks()) - } -} - -impl EvaluatorInfoTrait for PeriodogramPeaks { - fn get_info(&self) -> &EvaluatorInfo { - &self.properties.info - } -} - -impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { - fn get_names(&self) -> Vec<&str> { - self.properties.names.iter().map(String::as_str).collect() - } - - fn get_descriptions(&self) -> Vec<&str> { - self.properties - .descriptions - .iter() - .map(String::as_str) - .collect() - } -} - -impl FeatureEvaluator for PeriodogramPeaks -where - T: Float, -{ - fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); - Ok(peak_indices - .iter() - .flat_map(|&i| { - iter::once(T::two() * T::PI() / ts.t.sample[i]) - .chain(iter::once(ts.m.signal_to_noise(ts.m.sample[i]))) - }) - .chain(iter::repeat(T::zero())) - .take(2 * self.peaks) - .collect()) - } -} - -#[derive(Serialize, Deserialize, JsonSchema)] -#[serde(rename = "PeriodogramPeaks")] -struct PeriodogramPeaksParameters { - peaks: usize, -} - -impl From for PeriodogramPeaksParameters { - fn from(f: PeriodogramPeaks) -> Self { - Self { peaks: f.peaks } - } -} - -impl From for PeriodogramPeaks { - fn from(p: PeriodogramPeaksParameters) -> Self { - Self::new(p.peaks) - } -} - -impl JsonSchema for PeriodogramPeaks { - json_schema!(PeriodogramPeaksParameters, false); -} macro_const! { const DOC: &str = r#" @@ -183,7 +33,7 @@ series without observation errors (unity weights are used if required). You can #[doc = DOC!()] #[derive(Clone, Debug, Deserialize, Serialize)] #[serde( - bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug,", + bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug,", from = "PeriodogramParameters", into = "PeriodogramParameters" )] @@ -194,7 +44,11 @@ where resolution: f32, max_freq_factor: f32, nyquist: NyquistFreq, - feature_extractor: FeatureExtractor, + pub(crate) feature_extractor: FeatureExtractor, + // In can be re-defined in MultiColorPeriodogram + pub(crate) name_prefix: String, + // In can be re-defined in MultiColorPeriodogram + pub(crate) description_suffix: String, periodogram_algorithm: PeriodogramPower, properties: Box, } @@ -250,13 +104,13 @@ where feature .get_names() .iter() - .map(|name| "periodogram_".to_owned() + name), + .map(|name| format!("{}_{}", self.name_prefix, name)), ); self.properties.descriptions.extend( feature .get_descriptions() .into_iter() - .map(|desc| format!("{} of periodogram", desc)), + .map(|desc| format!("{} {}", desc, self.description_suffix)), ); self.feature_extractor.add_feature(feature); self @@ -270,24 +124,24 @@ where self } - fn periodogram(&self, ts: &mut TimeSeries) -> periodogram::Periodogram { + pub(crate) fn periodogram(&self, t: &[T]) -> periodogram::Periodogram { periodogram::Periodogram::from_t( self.periodogram_algorithm.clone(), - ts.t.as_slice(), + t, self.resolution, self.max_freq_factor, self.nyquist.clone(), ) } - pub fn power(&self, ts: &mut TimeSeries) -> Vec { - self.periodogram(ts).power(ts) + pub fn power(&self, ts: &mut TimeSeries) -> Array1 { + self.periodogram(ts.t.as_slice()).power(ts) } - pub fn freq_power(&self, ts: &mut TimeSeries) -> (Vec, Vec) { - let p = self.periodogram(ts); + pub fn freq_power(&self, ts: &mut TimeSeries) -> (Array1, Array1) { + let p = self.periodogram(ts.t.as_slice()); let power = p.power(ts); - let freq = (0..power.len()).map(|i| p.freq(i)).collect::>(); + let freq = (0..power.len()).map(|i| p.freq_by_index(i)).collect(); (freq, power) } } @@ -299,33 +153,44 @@ where { /// New [Periodogram] that finds given number of peaks pub fn new(peaks: usize) -> Self { - let peaks = PeriodogramPeaks::new(peaks); - let peak_names = peaks.properties.names.clone(); - let peak_descriptions = peaks.properties.descriptions.clone(); - let peaks_size_hint = peaks.size_hint(); - let peaks_min_ts_length = peaks.min_ts_length(); + Self::with_name_description( + peaks, + "periodogram", + "of periodogram (interpreting frequency as time, power as magnitude)", + ) + } + + pub(crate) fn with_name_description( + peaks: usize, + name_prefix: impl ToString, + description_suffix: impl ToString, + ) -> Self { let info = EvaluatorInfo { - size: peaks_size_hint, - min_ts_length: usize::max(peaks_min_ts_length, 2), + size: 0, + min_ts_length: 2, t_required: true, m_required: true, w_required: false, sorting_required: true, variability_required: false, }; - Self { + let mut slf = Self { properties: EvaluatorProperties { info, - names: peak_names, - descriptions: peak_descriptions, + names: vec![], + descriptions: vec![], } .into(), resolution: Self::default_resolution(), + name_prefix: name_prefix.to_string(), + description_suffix: description_suffix.to_string(), max_freq_factor: Self::default_max_freq_factor(), nyquist: AverageNyquistFreq.into(), - feature_extractor: FeatureExtractor::new(vec![peaks.into()]), + feature_extractor: FeatureExtractor::new(vec![]), periodogram_algorithm: PeriodogramPowerFft::new().into(), - } + }; + slf.add_feature(PeriodogramPeaks::new(peaks).into()); + slf } } @@ -333,15 +198,12 @@ impl Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + self.check_ts(ts)?; let (freq, power) = self.freq_power(ts); - Ok(TmArrays { - t: freq.into(), - m: power.into(), - }) + Ok(TmArrays { t: freq, m: power }) } } @@ -368,7 +230,7 @@ impl EvaluatorInfoTrait for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn get_info(&self) -> &EvaluatorInfo { &self.properties.info @@ -379,7 +241,7 @@ impl FeatureNamesDescriptionsTrait for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn get_names(&self) -> Vec<&str> { self.properties.names.iter().map(String::as_str).collect() @@ -398,7 +260,7 @@ impl FeatureEvaluator for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { transformer_eval!(); } @@ -422,7 +284,7 @@ impl From> for PeriodogramParameters where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn from(f: Periodogram) -> Self { let Periodogram { @@ -431,13 +293,13 @@ where nyquist, feature_extractor, periodogram_algorithm, - properties: _, + .. } = f; let mut features = feature_extractor.into_vec(); let rest_of_features = features.split_off(1); let periodogram_peaks: PeriodogramPeaks = features.pop().unwrap().try_into().unwrap(); - let peaks = periodogram_peaks.peaks; + let peaks = periodogram_peaks.get_peaks(); Self { resolution, diff --git a/src/features/stetson_k.rs b/src/features/stetson_k.rs index a28d3c06..7110e239 100644 --- a/src/features/stetson_k.rs +++ b/src/features/stetson_k.rs @@ -62,12 +62,11 @@ where T: Float, { fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - let chi2 = ts.get_m_reduced_chi2() * (ts.lenf() - T::one()); let mean = ts.get_m_weighted_mean(); let value = Zip::from(&ts.m.sample) .and(&ts.w.sample) .fold(T::zero(), |acc, &y, &w| acc + T::abs(y - mean) * T::sqrt(w)) - / T::sqrt(ts.lenf() * chi2); + / T::sqrt(ts.lenf() * ts.get_m_chi2()); Ok(vec![value]) } } diff --git a/src/lib.rs b/src/lib.rs index 36c6c79e..8dcbc914 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ pub use float_trait::Float; mod lnerfc; -mod multicolor; +pub mod multicolor; pub use multicolor::*; mod nl_fit; @@ -44,6 +44,9 @@ pub use nl_fit::LmsderCurveFit; pub use nl_fit::{prior, LnPrior, LnPrior1D}; pub use nl_fit::{CurveFitAlgorithm, McmcCurveFit}; +mod number_ending; +pub(crate) use number_ending::number_ending; + #[doc(hidden)] pub mod periodogram; pub use periodogram::recurrent_sin_cos::RecurrentSinCos; diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs index 9639f875..cd20ee0f 100644 --- a/src/multicolor/features/color_of_maximum.rs +++ b/src/multicolor/features/color_of_maximum.rs @@ -11,6 +11,10 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference maximum value magnitudes of two passbands +/// +/// Note that maximum is calculated for each passband separately, and maximum has mathematical +/// meaning, not "magnitudial" (astronomical) one. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMaximum

@@ -27,6 +31,10 @@ impl

ColorOfMaximum

where P: PassbandTrait, { + /// Create new [ColorOfMaximum] evaluator + /// + /// # Arguments + /// - `passbands` - two passbands pub fn new(passbands: [P; 2]) -> Self { let set: BTreeSet<_> = passbands.clone().into(); Self { diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index a6779e4b..83a0e78c 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -14,6 +14,9 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference of median magnitudes in two passbands +/// +/// Note that median is calculated for each passband separately #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMedian

diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs index 72de5c7d..3129544c 100644 --- a/src/multicolor/features/color_of_minimum.rs +++ b/src/multicolor/features/color_of_minimum.rs @@ -11,6 +11,10 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference of minimum magnitudes of two passbands +/// +/// Note that minimum is calculated for each passband separately, and maximum has mathematical +/// meaning, not "magnitudial" (astronomical) one. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMinimum

@@ -27,6 +31,10 @@ impl

ColorOfMinimum

where P: PassbandTrait, { + /// Create new [ColorOfMinimum] evaluator + /// + /// # Arguments + /// - `passbands` - two passbands pub fn new(passbands: [P; 2]) -> Self { let set: BTreeSet<_> = passbands.clone().into(); Self { diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs index 54b5a4c7..10c77d7a 100644 --- a/src/multicolor/features/mod.rs +++ b/src/multicolor/features/mod.rs @@ -6,3 +6,6 @@ pub use color_of_median::ColorOfMedian; mod color_of_minimum; pub use color_of_minimum::ColorOfMinimum; + +mod multi_color_periodogram; +pub use multi_color_periodogram::{MultiColorPeriodogram, MultiColorPeriodogramNormalisation}; diff --git a/src/multicolor/features/multi_color_periodogram.rs b/src/multicolor/features/multi_color_periodogram.rs new file mode 100644 index 00000000..4829335c --- /dev/null +++ b/src/multicolor/features/multi_color_periodogram.rs @@ -0,0 +1,310 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::TmArrays; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait, OwnedArrays, +}; +use crate::features::{Periodogram, PeriodogramPeaks}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; +use crate::periodogram::{self, NyquistFreq, PeriodogramPower}; + +use ndarray::Array1; +use std::fmt::Debug; + +/// Normalisation of periodogram across passbands +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub enum MultiColorPeriodogramNormalisation { + /// Weight individual periodograms by the number of observations in each passband. + /// Useful if no weight is given to observations + Count, + /// Weight individual periodograms by $\chi^2 = \sum \left(\frac{m_i - \bar{m}}{\delta_i}\right)^2$ + /// + /// Be aware that if no weight are given to observations + /// (i.e. via [TimeSeries::new_without_weight]) unity weights are assumed and this is NOT + /// equivalent to [::Count], but weighting by magnitude variance. + Chi2, +} + +/// Multi-passband periodogram +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde( + bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug," +)] +pub struct MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator, +{ + // We use it to not reimplement some internals + monochrome: Periodogram, + normalization: MultiColorPeriodogramNormalisation, +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From, +{ + pub fn new(peaks: usize, normalization: MultiColorPeriodogramNormalisation) -> Self { + let monochrome = Periodogram::with_name_description( + peaks, + "multicolor_periodogram", + "of multi-color periodogram (interpreting frequency as time, power as magnitude)", + ); + Self { + monochrome, + normalization, + } + } + + #[inline] + pub fn default_peaks() -> usize { + PeriodogramPeaks::default_peaks() + } + + #[inline] + pub fn default_resolution() -> f32 { + Periodogram::::default_resolution() + } + + #[inline] + pub fn default_max_freq_factor() -> f32 { + Periodogram::::default_max_freq_factor() + } + + /// Set frequency resolution + /// + /// The larger frequency resolution allows to find peak period with better precision + pub fn set_freq_resolution(&mut self, resolution: f32) -> &mut Self { + self.monochrome.set_freq_resolution(resolution); + self + } + + /// Multiply maximum (Nyquist) frequency + /// + /// Maximum frequency is Nyquist frequncy multiplied by this factor. The larger factor allows + /// to find larger frequency and makes [PeriodogramPowerFft] more precise. However large + /// frequencies can show false peaks + pub fn set_max_freq_factor(&mut self, max_freq_factor: f32) -> &mut Self { + self.monochrome.set_max_freq_factor(max_freq_factor); + self + } + + /// Define Nyquist frequency + pub fn set_nyquist(&mut self, nyquist: NyquistFreq) -> &mut Self { + self.monochrome.set_nyquist(nyquist); + self + } + + /// Extend a feature to extract from periodogram + pub fn add_feature(&mut self, feature: F) -> &mut Self { + self.monochrome.add_feature(feature); + self + } + + pub fn set_periodogram_algorithm( + &mut self, + periodogram_power: PeriodogramPower, + ) -> &mut Self { + self.monochrome.set_periodogram_algorithm(periodogram_power); + self + } +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn power_from_periodogram<'slf, 'a, 'mcts, P>( + &self, + p: &periodogram::Periodogram, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + let ts_weights = { + let mut a: Array1<_> = match self.normalization { + MultiColorPeriodogramNormalisation::Count => { + mcts.mapping_mut().values().map(|ts| ts.lenf()).collect() + } + MultiColorPeriodogramNormalisation::Chi2 => mcts + .mapping_mut() + .values_mut() + .map(|ts| ts.get_m_chi2()) + .collect(), + }; + let norm = a.sum(); + if norm.is_zero() { + match self.normalization { + MultiColorPeriodogramNormalisation::Count => { + return Err(MultiColorEvaluatorError::all_time_series_short( + mcts.mapping_mut(), + self.min_ts_length(), + )); + } + MultiColorPeriodogramNormalisation::Chi2 => { + return Err(MultiColorEvaluatorError::AllTimeSeriesAreFlat); + } + } + } + a /= norm; + a + }; + mcts.mapping_mut() + .values_mut() + .zip(ts_weights.iter()) + .filter(|(ts, _ts_weight)| self.monochrome.check_ts_length(ts).is_ok()) + .map(|(ts, &ts_weight)| { + let mut power = p.power(ts); + power *= ts_weight; + power + }) + .reduce(|mut acc, power| { + acc += &power; + acc + }) + .ok_or_else(|| { + MultiColorEvaluatorError::all_time_series_short( + mcts.mapping_mut(), + self.min_ts_length(), + ) + }) + } + + pub fn power<'slf, 'a, 'mcts, P>( + &self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + self.power_from_periodogram( + &self.monochrome.periodogram(mcts.flat_mut().t.as_slice()), + mcts, + ) + } + + pub fn freq_power<'slf, 'a, 'mcts, P>( + &self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result<(Array1, Array1), MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + let p = self.monochrome.periodogram(mcts.flat_mut().t.as_slice()); + let power = self.power_from_periodogram(&p, mcts)?; + let freq = (0..power.len()).map(|i| p.freq_by_index(i)).collect(); + Ok((freq, power)) + } +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn transform_mcts_to_ts

( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> + where + P: PassbandTrait, + { + let (freq, power) = self.freq_power(mcts)?; + Ok(TmArrays { t: freq, m: power }) + } +} + +impl EvaluatorInfoTrait for MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn get_info(&self) -> &EvaluatorInfo { + self.monochrome.get_info() + } +} + +impl FeatureNamesDescriptionsTrait for MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn get_names(&self) -> Vec<&str> { + self.monochrome.get_names() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.monochrome.get_descriptions() + } +} + +impl MultiColorPassbandSetTrait

for MultiColorPeriodogram +where + T: Float, + P: PassbandTrait, + F: FeatureEvaluator, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &PassbandSet::AllAvailable + } +} + +impl MultiColorEvaluator for MultiColorPeriodogram +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let arrays = self.transform_mcts_to_ts(mcts)?; + let mut ts = arrays.ts(); + self.monochrome + .feature_extractor + .eval(&mut ts) + .map_err(From::from) + } + + /// Returns vector of feature values and fill invalid components with given value + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let arrays = match self.transform_mcts_to_ts(mcts) { + Ok(arrays) => arrays, + Err(_) => return Ok(vec![fill_value; self.size_hint()]), + }; + let mut ts = arrays.ts(); + Ok(self + .monochrome + .feature_extractor + .eval_or_fill(&mut ts, fill_value)) + } +} diff --git a/src/multicolor/mod.rs b/src/multicolor/mod.rs index fb09f714..bcb785b3 100644 --- a/src/multicolor/mod.rs +++ b/src/multicolor/mod.rs @@ -1,4 +1,4 @@ -mod features; +pub mod features; mod monochrome_feature; pub use monochrome_feature::MonochromeFeature; diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs index 2b483d29..7e9e4111 100644 --- a/src/multicolor/monochrome_feature.rs +++ b/src/multicolor/monochrome_feature.rs @@ -15,6 +15,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; use std::marker::PhantomData; +/// Multi-color feature which evaluates non-color dependent feature for each passband. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound( deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" @@ -35,6 +36,11 @@ where T: Float, F: FeatureEvaluator, { + /// Creates a new instance of `MonochromeFeature`. + /// + /// # Arguments + /// - `feature` - non-multi-color feature to evaluate for each passband. + /// - `passband_set` - set of passbands to evaluate the feature for. pub fn new(feature: F, passband_set: BTreeSet

) -> Self { let names = passband_set .iter() @@ -130,3 +136,32 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + + use crate::features::Mean; + use crate::multicolor::passband::MonochromePassband; + use crate::Feature; + + #[test] + fn test_monochrome_feature() { + let feature: MonochromeFeature, f64, Feature<_>> = + MonochromeFeature::new( + Mean::default().into(), + [ + MonochromePassband::new(4700e-8, "g"), + MonochromePassband::new(6200e-8, "r"), + ] + .into_iter() + .collect(), + ); + assert_eq!(feature.get_names(), vec!["mean_g", "mean_r"]); + assert_eq!( + feature.get_descriptions(), + vec!["mean magnitude, passband g", "mean magnitude, passband r"] + ); + assert_eq!(feature.get_info().size, 2); + } +} diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs index c60dddf7..6d67b4b4 100644 --- a/src/multicolor/multicolor_evaluator.rs +++ b/src/multicolor/multicolor_evaluator.rs @@ -16,14 +16,20 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Trait for getting alphabetically sorted passbands #[enum_dispatch] pub trait MultiColorPassbandSetTrait

where P: PassbandTrait, { + /// Get passband set for this evaluator fn get_passband_set(&self) -> &PassbandSet

; } +/// Enum for passband set, which can be either fixed set or all available passbands. +/// This is used for [MultiColorEvaluator]s, which can be evaluated on all available passbands +/// (for example [MultiColorPeriodogram](super::features::MultiColorPeriodogram)) or on fixed set of +/// passbands (for example [ColorOfMaximum](super::ColorOfMaximum)). #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] #[non_exhaustive] @@ -31,7 +37,9 @@ pub enum PassbandSet

where P: Ord, { + /// Fixed set of passbands FixedSet(BTreeSet

), + /// All available passbands AllAvailable, } @@ -44,6 +52,7 @@ where } } +/// Helper error for [MultiColorEvaluator] enum InternalMctsError { MultiColorEvaluatorError(MultiColorEvaluatorError), InternalWrongPassbandSet, @@ -78,6 +87,7 @@ impl InternalMctsError { } } +/// Trait for multi-color feature evaluators #[enum_dispatch] pub trait MultiColorEvaluator: FeatureNamesDescriptionsTrait diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs index 5fe4d70d..3ef1e49e 100644 --- a/src/multicolor/multicolor_extractor.rs +++ b/src/multicolor/multicolor_extractor.rs @@ -12,6 +12,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; use std::marker::PhantomData; +/// Bulk feature evaluator. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde( into = "MultiColorExtractorParameters", @@ -37,6 +38,10 @@ where T: Float, MCF: MultiColorEvaluator, { + /// Create a new [MultiColorExtractor] + /// + /// # Arguments + /// `features` - A vector of multi-color features to be evaluated pub fn new(features: Vec) -> Self { let passband_set = { let set: BTreeSet<_> = features diff --git a/src/multicolor/passband/dump_passband.rs b/src/multicolor/passband/dump_passband.rs index ffd73cde..675d7e6b 100644 --- a/src/multicolor/passband/dump_passband.rs +++ b/src/multicolor/passband/dump_passband.rs @@ -4,6 +4,7 @@ pub use schemars::JsonSchema; pub use serde::{Deserialize, Serialize}; use std::fmt::Debug; +/// A passband for the cases where we don't care about the actual passband. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] pub struct DumpPassband {} @@ -12,3 +13,14 @@ impl PassbandTrait for DumpPassband { "" } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dump_passband() { + let passband = DumpPassband {}; + assert_eq!(passband.name(), ""); + } +} diff --git a/src/multicolor/passband/monochrome_passband.rs b/src/multicolor/passband/monochrome_passband.rs index 987bc79a..09f6f29f 100644 --- a/src/multicolor/passband/monochrome_passband.rs +++ b/src/multicolor/passband/monochrome_passband.rs @@ -7,6 +7,7 @@ pub use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::fmt::Debug; +/// A passband specified by a single wavelength. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct MonochromePassband<'a, T> { pub name: &'a str, @@ -17,6 +18,11 @@ impl<'a, T> MonochromePassband<'a, T> where T: Float, { + /// Create a new `MonochromePassband`. + /// + /// # Arguments + /// - `wavelength`: The wavelength of the passband, panic if it is not a positive normal number. + /// - `name`: The name of the passband. pub fn new(wavelength: T, name: &'a str) -> Self { assert!( wavelength.is_normal(), @@ -67,3 +73,15 @@ where self.name } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_monochrome_passband() { + let passband = MonochromePassband::new(1.0, "test"); + assert_eq!(passband.name(), "test"); + assert_eq!(passband.wavelength, 1.0); + } +} diff --git a/src/number_ending.rs b/src/number_ending.rs new file mode 100644 index 00000000..f994e03f --- /dev/null +++ b/src/number_ending.rs @@ -0,0 +1,51 @@ +/// Return a suffix for a number, like "st", "nd", or "th". +pub(crate) fn number_ending(i: usize) -> &'static str { + #[allow(clippy::match_same_arms)] + match (i % 10, i % 100) { + (1, 11) => "th", + (1, _) => "st", + (2, 12) => "th", + (2, _) => "nd", + (3, 13) => "th", + (3, _) => "rd", + (_, _) => "th", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + assert_eq!(number_ending(0), "th"); + assert_eq!(number_ending(1), "st"); + assert_eq!(number_ending(2), "nd"); + assert_eq!(number_ending(3), "rd"); + assert_eq!(number_ending(4), "th"); + assert_eq!(number_ending(5), "th"); + assert_eq!(number_ending(6), "th"); + assert_eq!(number_ending(7), "th"); + assert_eq!(number_ending(8), "th"); + assert_eq!(number_ending(9), "th"); + assert_eq!(number_ending(10), "th"); + assert_eq!(number_ending(11), "th"); + assert_eq!(number_ending(12), "th"); + assert_eq!(number_ending(13), "th"); + assert_eq!(number_ending(14), "th"); + assert_eq!(number_ending(15), "th"); + assert_eq!(number_ending(16), "th"); + assert_eq!(number_ending(17), "th"); + assert_eq!(number_ending(18), "th"); + assert_eq!(number_ending(19), "th"); + assert_eq!(number_ending(20), "th"); + assert_eq!(number_ending(21), "st"); + assert_eq!(number_ending(22), "nd"); + assert_eq!(number_ending(23), "rd"); + assert_eq!(number_ending(24), "th"); + assert_eq!(number_ending(25), "th"); + assert_eq!(number_ending(100), "th"); + assert_eq!(number_ending(101), "st"); + assert_eq!(number_ending(102), "nd"); + } +} diff --git a/src/periodogram/mod.rs b/src/periodogram/mod.rs index 6bc477ea..b83407e0 100644 --- a/src/periodogram/mod.rs +++ b/src/periodogram/mod.rs @@ -5,6 +5,7 @@ use crate::float_trait::Float; use conv::ConvAsUtil; use enum_dispatch::enum_dispatch; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -92,11 +93,11 @@ where ) } - pub fn freq(&self, i: usize) -> T { + pub fn freq_by_index(&self, i: usize) -> T { self.freq_grid.step * (i + 1).approx().unwrap() } - pub fn power(&self, ts: &mut TimeSeries) -> Vec { + pub fn power(&self, ts: &mut TimeSeries) -> Array1 { self.periodogram_power.power(&self.freq_grid, ts) } } @@ -110,16 +111,17 @@ mod tests { use crate::data::SortedArray; use crate::peak_indices::peak_indices_reverse_sorted; - use light_curve_common::{all_close, linspace}; + use approx::assert_relative_eq; + use ndarray::{arr1, s}; use rand::prelude::*; #[test] fn compr_direct_with_scipy() { const OMEGA_SIN: f64 = 0.07; const N: usize = 100; - let t = linspace(0.0, 99.0, N); - let m: Vec<_> = t.iter().map(|&x| f64::sin(OMEGA_SIN * x)).collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, 99.0, N); + let m = t.mapv(|x| f64::sin(OMEGA_SIN * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let periodogram = Periodogram::new( PeriodogramPowerDirect.into(), FreqGrid { @@ -127,10 +129,10 @@ mod tests { size: 1, }, ); - all_close( - &[periodogram.power(&mut ts)[0] * 2.0 / (N as f64 - 1.0)], - &[1.0], - 1.0 / (N as f64), + assert_relative_eq!( + periodogram.power(&mut ts)[0] * 2.0 / (N as f64 - 1.0), + 1.0, + epsilon = 1.0 / (N as f64), ); // import numpy as np @@ -147,26 +149,28 @@ mod tests { size: 5, }; let periodogram = Periodogram::new(PeriodogramPowerDirect.into(), freq_grid.clone()); - all_close( - &linspace( + assert_relative_eq!( + Array1::linspace( freq_grid.step, freq_grid.step * freq_grid.size as f64, freq_grid.size, - ), - &(0..freq_grid.size) - .map(|i| periodogram.freq(i)) - .collect::>(), - 1e-12, + ) + .view(), + (0..freq_grid.size) + .map(|i| periodogram.freq_by_index(i)) + .collect::>() + .view(), + epsilon = 1e-12, ); - let desired = [ + let desired = arr1(&[ 16.99018018, 18.57722516, 21.96049738, 28.15056806, 36.66519435, - ]; + ]); let actual = periodogram.power(&mut ts); - all_close(&actual[..], &desired[..], 1e-6); + assert_relative_eq!(actual, desired, epsilon = 1e-6); } #[test] @@ -176,14 +180,14 @@ mod tests { const RESOLUTION: f32 = 1.0; const MAX_FREQ_FACTOR: f32 = 1.0; - let t = linspace(0.0, (N - 1) as f64, N); - let m: Vec<_> = t.iter().map(|&x| f64::sin(OMEGA * x)).collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, (N - 1) as f64, N); + let m = t.mapv(|x| f64::sin(OMEGA * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let nyquist: NyquistFreq = AverageNyquistFreq.into(); let direct = Periodogram::from_t( PeriodogramPowerDirect.into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist.clone(), @@ -191,13 +195,17 @@ mod tests { .power(&mut ts); let fft = Periodogram::from_t( PeriodogramPowerFft::new().into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist, ) .power(&mut ts); - all_close(&fft[..direct.len() - 1], &direct[..direct.len() - 1], 1e-8); + assert_relative_eq!( + fft.slice(s![..direct.len() - 1]), + direct.slice(s![..direct.len() - 1]), + epsilon = 1e-8 + ); } #[test] @@ -209,17 +217,14 @@ mod tests { const RESOLUTION: f32 = 4.0; const MAX_FREQ_FACTOR: f32 = 1.0; - let t = linspace(0.0, (N - 1) as f64, N); - let m: Vec<_> = t - .iter() - .map(|&x| f64::sin(OMEGA1 * x) + AMPLITUDE2 * f64::cos(OMEGA2 * x)) - .collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, (N - 1) as f64, N); + let m = t.mapv(|x| f64::sin(OMEGA1 * x) + AMPLITUDE2 * f64::cos(OMEGA2 * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let nyquist: NyquistFreq = AverageNyquistFreq.into(); let direct = Periodogram::from_t( PeriodogramPowerDirect.into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist.clone(), @@ -227,7 +232,7 @@ mod tests { .power(&mut ts); let fft = Periodogram::from_t( PeriodogramPowerFft::new().into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist, diff --git a/src/periodogram/power_direct.rs b/src/periodogram/power_direct.rs index a8020318..7009314e 100644 --- a/src/periodogram/power_direct.rs +++ b/src/periodogram/power_direct.rs @@ -4,6 +4,7 @@ use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use crate::periodogram::recurrent_sin_cos::*; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -22,7 +23,7 @@ impl PeriodogramPowerTrait for PeriodogramPowerDirect where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec { + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1 { let m_mean = ts.m.get_mean(); let sin_cos_omega_tau = SinCosOmegaTau::new(freq.step, ts.t.as_slice().iter()); diff --git a/src/periodogram/power_fft.rs b/src/periodogram/power_fft.rs index 40526df9..dadd6b05 100644 --- a/src/periodogram/power_fft.rs +++ b/src/periodogram/power_fft.rs @@ -5,6 +5,7 @@ use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use conv::{ConvAsUtil, RoundToNearest}; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::cell::RefCell; @@ -73,11 +74,11 @@ impl PeriodogramPowerTrait for PeriodogramPowerFft where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec { + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1 { let m_std2 = ts.m.get_std2(); if m_std2.is_zero() { - return vec![T::zero(); freq.size.next_power_of_two()]; + return Array1::zeros(freq.size.next_power_of_two()); } let grid = TimeGrid::from_freq_grid(freq); diff --git a/src/periodogram/power_trait.rs b/src/periodogram/power_trait.rs index 46c271be..92e03618 100644 --- a/src/periodogram/power_trait.rs +++ b/src/periodogram/power_trait.rs @@ -3,6 +3,7 @@ use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; use enum_dispatch::enum_dispatch; +use ndarray::Array1; use std::fmt::Debug; /// Periodogram execution algorithm @@ -11,5 +12,5 @@ pub trait PeriodogramPowerTrait: Debug + Clone + Send where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec; + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1; } diff --git a/src/time_series.rs b/src/time_series.rs new file mode 100644 index 00000000..e69de29b