diff --git a/CHANGELOG.md b/CHANGELOG.md index a93e02d2..6ab65c46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed --- +- **Breaking:** {Bazin,Villar}Fit require variability now and do not accept flat time series anymore https://github.com/light-curve/light-curve-feature/issues/112 https://github.com/light-curve/light-curve-feature/pull/113 ### Security diff --git a/src/data/data_sample.rs b/src/data/data_sample.rs new file mode 100644 index 00000000..44e5073f --- /dev/null +++ b/src/data/data_sample.rs @@ -0,0 +1,314 @@ +use crate::data::sorted_array::SortedArray; +use crate::float_trait::Float; +use crate::types::CowArray1; + +use conv::prelude::*; +use ndarray::{s, Array1, ArrayView1, Zip}; + +/// A [`TimeSeries`] component +#[derive(Clone, Debug)] +pub struct DataSample<'a, T> +where + T: Float, +{ + pub sample: CowArray1<'a, T>, + sorted: Option>, + min: Option, + max: Option, + mean: Option, + median: Option, + std: Option, + std2: Option, +} + +macro_rules! 
data_sample_getter { + ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> T { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some(match self.sorted.as_ref() { + Some(sorted) => sorted.$method_sorted(), + None => $func(self), + }); + self.$attr.unwrap() + } + } + } + }; + ($attr: ident, $getter: ident, $func: expr) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> T { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some($func(self)); + self.$attr.unwrap() + } + } + } + }; +} + +impl<'a, T> DataSample<'a, T> +where + T: Float, +{ + pub fn new(sample: CowArray1<'a, T>) -> Self { + Self { + sample, + sorted: None, + min: None, + max: None, + mean: None, + median: None, + std: None, + std2: None, + } + } + + pub fn as_slice(&mut self) -> &[T] { + if !self.sample.is_standard_layout() { + let owned: Array1<_> = self.sample.iter().copied().collect::>().into(); + self.sample = owned.into(); + } + self.sample.as_slice().unwrap() + } + + pub fn get_sorted(&mut self) -> &SortedArray { + if self.sorted.is_none() { + self.sorted = Some(self.sample.to_vec().into()); + } + self.sorted.as_ref().unwrap() + } + + fn set_min_max(&mut self) { + let (min, max) = + self.sample + .slice(s![1..]) + .fold((self.sample[0], self.sample[0]), |(min, max), &x| { + if x > max { + (min, x) + } else if x < min { + (x, max) + } else { + (min, max) + } + }); + self.min = Some(min); + self.max = Some(max); + } + + data_sample_getter!( + min, + get_min, + |ds: &mut DataSample<'a, T>| { + ds.set_min_max(); + ds.min.unwrap() + }, + minimum + ); + data_sample_getter!( + max, + get_max, + |ds: &mut DataSample<'a, T>| { + ds.set_min_max(); + 
ds.max.unwrap() + }, + maximum + ); + data_sample_getter!(mean, get_mean, |ds: &mut DataSample<'a, T>| { + ds.sample.mean().expect("time series must be non-empty") + }); + data_sample_getter!(median, get_median, |ds: &mut DataSample<'a, T>| { + ds.get_sorted().median() + }); + data_sample_getter!(std, get_std, |ds: &mut DataSample<'a, T>| { + ds.get_std2().sqrt() + }); + data_sample_getter!(std2, get_std2, |ds: &mut DataSample<'a, T>| { + // Benchmarks show that it is faster than `ndarray::ArrayBase::var(T::one)` + let mean = ds.get_mean(); + ds.sample + .fold(T::zero(), |sum, &x| sum + (x - mean).powi(2)) + / (ds.sample.len() - 1).approx().unwrap() + }); + + pub fn signal_to_noise(&mut self, value: T) -> T { + if self.get_std().is_zero() { + T::zero() + } else { + (value - self.get_mean()) / self.get_std() + } + } + + /// Returns true if all values are equal. Always true for zero- or one- length + pub fn is_all_same(&self) -> bool { + if self.sample.is_empty() { + return true; + } + if self.max.is_some() && self.max == self.min { + return true; + } + if self.std2 == Some(T::zero()) { + return true; + } + if let Some(sorted) = &self.sorted { + return sorted[0] == sorted[sorted.len() - 1]; + } + let x0 = self.sample[0]; + // all() returns true for the empty slice, i.e. 
single-point time series + Zip::from(self.sample.slice(s![1..])).all(|&x| x == x0) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(sorted: SortedArray) -> Self { + let sample = sorted.0.clone().into(); + Self { + sample, + sorted: Some(sorted), + min: None, + max: None, + median: None, + mean: None, + std: None, + std2: None, + } + } +} + +impl<'a, T, Slice: ?Sized> From<&'a Slice> for DataSample<'a, T> +where + T: Float, + Slice: AsRef<[T]>, +{ + fn from(s: &'a Slice) -> Self { + ArrayView1::from(s).into() + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(v: Vec) -> Self { + Array1::from(v).into() + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: ArrayView1<'a, T>) -> Self { + Self::new(a.into()) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: Array1) -> Self { + Self::new(a.into()) + } +} + +impl<'a, T> From> for DataSample<'a, T> +where + T: Float, +{ + fn from(a: CowArray1<'a, T>) -> Self { + Self::new(a) + } +} + +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use super::*; + + use approx::assert_relative_eq; + + macro_rules! data_sample_test { + ($name: ident, $method: ident, $desired: literal, $x: tt $(,)?) 
=> { + #[test] + fn $name() { + let x = $x; + let desired = $desired; + + let mut ds: DataSample<_> = DataSample::from(&x); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + + let mut ds: DataSample<_> = DataSample::from(&x); + ds.get_sorted(); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + assert_relative_eq!(ds.$method(), desired, epsilon = 1e-6); + } + }; + } + + data_sample_test!( + data_sample_min, + get_min, + -7.79420906, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_max, + get_max, + 6.73375373, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_mean, + get_mean, + -0.21613426, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_median_odd, + get_median, + 3.28436964, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + data_sample_test!( + data_sample_median_even, + get_median, + 5.655794743124782, + [9.47981408, 3.86815751, 9.90299294, -2.986894, 7.44343197, 1.52751816], + ); + + data_sample_test!( + data_sample_std, + get_std, + 6.7900544035968435, + [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], + ); + + /// https://github.com/light-curve/light-curve-feature/issues/95 + #[test] + fn std2_overflow() { + const N: usize = (1 << 24) + 2; + // Such a large integer cannot be represented as a float32 + let x = Array1::linspace(0.0_f32, 1.0, N); + let mut ds = DataSample::new(x.into()); + // This should not panic + let _std2 = ds.get_std2(); + } +} diff --git a/src/data/mod.rs b/src/data/mod.rs new file mode 100644 index 00000000..aafaeb99 --- /dev/null +++ b/src/data/mod.rs @@ -0,0 +1,12 @@ +mod data_sample; +pub use data_sample::DataSample; + +mod multi_color_time_series; +pub use multi_color_time_series::MultiColorTimeSeries; + +mod sorted_array; +pub use 
sorted_array::SortedArray; + +mod time_series; + +pub use time_series::TimeSeries; diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs new file mode 100644 index 00000000..535543aa --- /dev/null +++ b/src/data/multi_color_time_series.rs @@ -0,0 +1,308 @@ +use crate::data::TimeSeries; +use crate::float_trait::Float; +use crate::multicolor::PassbandTrait; +use crate::{DataSample, PassbandSet}; + +use itertools::Either; +use itertools::EitherOrBoth; +use itertools::Itertools; +use std::collections::{BTreeMap, BTreeSet}; +use std::ops::{Deref, DerefMut}; + +pub enum MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { + Mapping(MappedMultiColorTimeSeries<'a, P, T>), + Flat(FlatMultiColorTimeSeries<'a, P, T>), + MappingFlat { + mapping: MappedMultiColorTimeSeries<'a, P, T>, + flat: FlatMultiColorTimeSeries<'a, P, T>, + }, +} + +impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait + 'p, + T: Float, +{ + pub fn from_map(map: impl Into>>) -> Self { + Self::Mapping(MappedMultiColorTimeSeries::new(map)) + } + + pub fn from_flat( + t: impl Into>, + m: impl Into>, + w: impl Into>, + passbands: impl Into>, + ) -> Self { + Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands)) + } + + pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + if matches!(self, MultiColorTimeSeries::Flat(_)) { + let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); + *self = match std::mem::replace(self, dummy_self) { + Self::Flat(mut flat) => { + let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat); + Self::MappingFlat { mapping, flat } + } + _ => unreachable!(), + } + } + match self { + Self::Mapping(mapping) => mapping, + Self::Flat(_flat) => { + unreachable!("::Flat variant is already transofrmed to ::MappingFlat") + } + Self::MappingFlat { mapping, .. 
} => mapping, + } + } + + pub fn mapping(&self) -> Option<&MappedMultiColorTimeSeries<'a, P, T>> { + match self { + Self::Mapping(mapping) => Some(mapping), + Self::Flat(_flat) => None, + Self::MappingFlat { mapping, .. } => Some(mapping), + } + } + + pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + if matches!(self, MultiColorTimeSeries::Mapping(_)) { + let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); + *self = match std::mem::replace(self, dummy_self) { + Self::Mapping(mut mapping) => { + let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping); + Self::MappingFlat { mapping, flat } + } + _ => unreachable!(), + } + } + match self { + Self::Mapping(_mapping) => { + unreachable!("::Mapping veriant is already transformed to ::MappingFlat") + } + Self::Flat(flat) => flat, + Self::MappingFlat { flat, .. } => flat, + } + } + + pub fn flat(&self) -> Option<&FlatMultiColorTimeSeries<'a, P, T>> { + match self { + Self::Mapping(_mapping) => None, + Self::Flat(flat) => Some(flat), + Self::MappingFlat { flat, .. } => Some(flat), + } + } + + pub fn passbands<'slf>( + &'slf self, + ) -> Either< + std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>>, + std::collections::btree_set::Iter

, + > + where + 'a: 'slf, + { + match self { + Self::Mapping(mapping) => Either::Left(mapping.passbands()), + Self::Flat(flat) => Either::Right(flat.passband_set.iter()), + Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()), + } + } +} + +pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>( + BTreeMap>, +); + +impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait + 'p, + T: Float, +{ + pub fn new(map: impl Into>>) -> Self { + Self(map.into()) + } + + pub fn from_flat(flat: &mut FlatMultiColorTimeSeries) -> Self { + let mut map = BTreeMap::new(); + let groups = itertools::multizip(( + flat.t.as_slice().iter(), + flat.m.as_slice().iter(), + flat.w.as_slice().iter(), + flat.passbands.iter(), + )) + .group_by(|(_t, _m, _w, p)| (*p).clone()); + for (p, group) in &groups { + let (t_vec, m_vec, w_vec) = map + .entry(p.clone()) + .or_insert_with(|| (vec![], vec![], vec![])); + for (&t, &m, &w, _p) in group { + t_vec.push(t); + m_vec.push(m); + w_vec.push(w); + } + } + Self( + map.into_iter() + .map(|(p, (t, m, w))| (p, TimeSeries::new(t, m, w))) + .collect(), + ) + } + + pub fn passbands<'slf>( + &'slf self, + ) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>> + where + 'a: 'slf, + { + self.keys() + } + + pub fn iter_passband_set<'slf, 'ps>( + &'slf self, + passband_set: &'ps PassbandSet

, + ) -> impl Iterator>)> + 'slf + where + 'a: 'slf, + 'ps: 'a, + { + match passband_set { + PassbandSet::AllAvailable => Either::Left(self.iter().map(|(p, ts)| (p, Some(ts)))), + PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())), + } + } + + pub fn iter_passband_set_mut<'slf, 'ps>( + &'slf mut self, + passband_set: &'ps PassbandSet

, + ) -> impl Iterator>)> + 'slf + where + 'a: 'slf, + 'ps: 'a, + { + match passband_set { + PassbandSet::AllAvailable => Either::Left(self.iter_mut().map(|(p, ts)| (p, Some(ts)))), + PassbandSet::FixedSet(set) => { + Either::Right(self.iter_matched_passbands_mut(set.iter())) + } + } + } + + pub fn iter_matched_passbands( + &self, + passband_it: impl Iterator, + ) -> impl Iterator>)> { + passband_it.map(|p| (p, self.get(p))) + } + + pub fn iter_matched_passbands_mut( + &mut self, + passband_it: impl Iterator, + ) -> impl Iterator>)> { + passband_it + .merge_join_by(self.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2)) + .filter_map(|either_or_both| match either_or_both { + // mcts misses required passband + EitherOrBoth::Left(p) => Some((p, None)), + // mcts has some passbands that passband_set doesn't require + EitherOrBoth::Right(_) => None, + // passbands match + EitherOrBoth::Both(p, (_, ts)) => Some((p, Some(ts))), + }) + } +} + +impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)> + for MappedMultiColorTimeSeries<'a, P, T> +{ + fn from_iter)>>(iter: I) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl<'a, P: PassbandTrait, T: Float> Deref for MappedMultiColorTimeSeries<'a, P, T> { + type Target = BTreeMap>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a, P, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { + pub t: DataSample<'a, T>, + pub m: DataSample<'a, T>, + pub w: DataSample<'a, T>, + pub passbands: Vec

, + passband_set: BTreeSet

, +} + +impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + pub fn new( + t: impl Into>, + m: impl Into>, + w: impl Into>, + passbands: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + let w = w.into(); + let passbands = passbands.into(); + let passband_set = passbands.iter().cloned().collect(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + assert_eq!( + m.sample.len(), + w.sample.len(), + "m and err should have the same size" + ); + assert_eq!( + t.sample.len(), + passbands.len(), + "t and passbands should have the same size" + ); + + Self { + t, + m, + w, + passbands, + passband_set, + } + } + + pub fn from_mapping(mapping: &mut BTreeMap>) -> Self { + let (t, m, w, passbands): (Vec<_>, Vec<_>, Vec<_>, _) = mapping + .iter_mut() + .map(|(p, ts)| { + itertools::multizip(( + ts.t.as_slice().iter().copied(), + ts.m.as_slice().iter().copied(), + ts.w.as_slice().iter().copied(), + std::iter::repeat(p.clone()), + )) + }) + .kmerge_by(|(t1, _m1, _w1, _p1), (t2, _m2, _w2, _p2)| t1 <= t2) + .multiunzip(); + + Self { + t: t.into(), + m: m.into(), + w: w.into(), + passbands, + passband_set: mapping.keys().cloned().collect(), + } + } +} diff --git a/src/sorted_array.rs b/src/data/sorted_array.rs similarity index 99% rename from src/sorted_array.rs rename to src/data/sorted_array.rs index 917fcf84..3686ff82 100644 --- a/src/sorted_array.rs +++ b/src/data/sorted_array.rs @@ -1,5 +1,6 @@ use crate::error::SortedArrayError; use crate::float_trait::Float; + use conv::prelude::*; use itertools::Itertools; use ndarray::Array1; diff --git a/src/data/time_series.rs b/src/data/time_series.rs new file mode 100644 index 00000000..855e7819 --- /dev/null +++ b/src/data/time_series.rs @@ -0,0 +1,252 @@ +use crate::data::data_sample::DataSample; +use crate::float_trait::Float; + +use conv::prelude::*; +use itertools::Itertools; +#[cfg(test)] +use ndarray::Array1; +use ndarray::Zip; 
+use ndarray_stats::SummaryStatisticsExt; + +/// Time series object to be put into [Feature](crate::Feature) +/// +/// This struct caches its properties, like mean magnitude value, etc., that's why mutable +/// reference is required for feature evaluation +#[derive(Clone, Debug)] +pub struct TimeSeries<'a, T> +where + T: Float, +{ + pub t: DataSample<'a, T>, + pub m: DataSample<'a, T>, + pub w: DataSample<'a, T>, + m_weighted_mean: Option, + m_reduced_chi2: Option, + t_max_m: Option, + t_min_m: Option, + plateau: Option, +} + +macro_rules! time_series_getter { + ($t: ty, $attr: ident, $getter: ident, $func: expr) => { + // This lint is false-positive in macros + // https://github.com/rust-lang/rust-clippy/issues/1553 + #[allow(clippy::redundant_closure_call)] + pub fn $getter(&mut self) -> $t { + match self.$attr { + Some(x) => x, + None => { + self.$attr = Some($func(self)); + self.$attr.unwrap() + } + } + } + }; + + ($attr: ident, $getter: ident, $func: expr) => { + time_series_getter!(T, $attr, $getter, $func); + }; +} + +impl<'a, T> TimeSeries<'a, T> +where + T: Float, +{ + /// Construct `TimeSeries` from array-like objects + /// + /// `t` is time, `m` is magnitude (or flux), `w` is weights. + /// + /// All arrays must have the same length, `t` must increase monotonically. Input arrays could be + /// [`ndarray::Array1`], [`ndarray::ArrayView1`], 1-D [`ndarray::CowArray`], or `&[T]`. Several + /// features assume that the `w` array corresponds to inverse square errors of `m`. 
+ pub fn new( + t: impl Into>, + m: impl Into>, + w: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + let w = w.into(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + assert_eq!( + m.sample.len(), + w.sample.len(), + "m and err should have the same size" + ); + + Self { + t, + m, + w, + m_weighted_mean: None, + m_reduced_chi2: None, + t_max_m: None, + t_min_m: None, + plateau: None, + } + } + + /// Construct [`TimeSeries`] from time and magnitude (flux) + /// + /// It is the same as [`TimeSeries::new`], but sets unity weights. It is not recommended to use + /// it for features dependent on weights / observation errors like [`crate::StetsonK`] or + /// [`crate::LinearFit`]. + pub fn new_without_weight( + t: impl Into>, + m: impl Into>, + ) -> Self { + let t = t.into(); + let m = m.into(); + + assert_eq!( + t.sample.len(), + m.sample.len(), + "t and m should have the same size" + ); + + let w = T::array0_unity().broadcast(t.sample.len()).unwrap().into(); + + Self { + t, + m, + w, + m_weighted_mean: None, + m_reduced_chi2: None, + t_max_m: None, + t_min_m: None, + plateau: None, + } + } + + /// Time series length + #[inline] + pub fn lenu(&self) -> usize { + self.t.sample.len() + } + + /// Float approximating time series length + pub fn lenf(&self) -> T { + self.lenu().approx().unwrap() + } + + time_series_getter!( + m_weighted_mean, + get_m_weighted_mean, + |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } + ); + + time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< + T, + >| { + let m_weighed_mean = ts.get_m_weighted_mean(); + let m_reduced_chi2 = Zip::from(&ts.m.sample) + .and(&ts.w.sample) + .fold(T::zero(), |chi2, &m, &w| { + chi2 + (m - m_weighed_mean).powi(2) * w + }) + / (ts.lenf() - T::one()); + if m_reduced_chi2.is_zero() { + ts.plateau = Some(true); + } + m_reduced_chi2 + }); + + time_series_getter!(bool, plateau, is_plateau, |ts: &mut 
TimeSeries| { + ts.m.is_all_same() + }); + + fn set_t_min_max_m(&mut self) { + let (i_min, i_max) = self + .m + .as_slice() + .iter() + .position_minmax() + .into_option() + .expect("time series must be non-empty"); + self.t_min_m = Some(self.t.sample[i_min]); + self.t_max_m = Some(self.t.sample[i_max]); + } + + pub fn get_t_min_m(&mut self) -> T { + if self.t_min_m.is_none() { + self.set_t_min_max_m(); + } + self.t_min_m.unwrap() + } + + pub fn get_t_max_m(&mut self) -> T { + if self.t_max_m.is_none() { + self.set_t_min_max_m(); + } + self.t_max_m.unwrap() + } +} + +// We really don't want it to be public, it is a private helper for test-data functions +#[cfg(test)] +impl<'a, T, D> From<(D, D, D)> for TimeSeries<'a, T> +where + T: Float, + D: Into>, +{ + fn from(v: (D, D, D)) -> Self { + Self::new(v.0, v.1, v.2) + } +} + +#[cfg(test)] +impl<'a, T> From<&'a (Array1, Array1, Array1)> for TimeSeries<'a, T> +where + T: Float, +{ + fn from(v: &'a (Array1, Array1, Array1)) -> Self { + Self::new(v.0.view(), v.1.view(), v.2.view()) + } +} + +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use super::*; + + use approx::assert_relative_eq; + + #[test] + fn time_series_m_weighted_mean() { + let t: Vec<_> = (0..5).map(|i| i as f64).collect(); + let m = [ + 12.77883145, + 18.89988406, + 17.55633632, + 18.36073996, + 11.83854198, + ]; + let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; + let mut ts = TimeSeries::new(&t, &m, &w); + // np.average(m, weights=w) + let desired = 16.31817047752941; + assert_relative_eq!(ts.get_m_weighted_mean(), desired, epsilon = 1e-6); + } + + #[test] + fn time_series_m_reduced_chi2() { + let t: Vec<_> = (0..5).map(|i| i as f64).collect(); + let m = [ + 12.77883145, + 18.89988406, + 17.55633632, + 18.36073996, + 11.83854198, + ]; + let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; + let mut ts = TimeSeries::new(&t, &m, &w); + let desired = 
1.3752251301435465; + assert_relative_eq!(ts.get_m_reduced_chi2(), desired, epsilon = 1e-6); + } +} diff --git a/src/error.rs b/src/error.rs index 38bb58ce..15c6f0a1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,7 @@ +use crate::PassbandTrait; + +use std::collections::BTreeSet; + /// Error returned from [crate::FeatureEvaluator] #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum EvaluatorError { @@ -11,10 +15,41 @@ pub enum EvaluatorError { ZeroDivision(&'static str), } +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum MultiColorEvaluatorError { + #[error("Passband {passband} time-series caused error: {error:?}")] + MonochromeEvaluatorError { + passband: String, + error: EvaluatorError, + }, + + #[error("Wrong passbands {actual:?}, {desired:?} are desired")] + WrongPassbandsError { + actual: BTreeSet, + desired: BTreeSet, + }, +} + +impl MultiColorEvaluatorError { + pub fn wrong_passbands_error<'a, P>( + actual: impl Iterator, + desired: impl Iterator, + ) -> Self + where + P: PassbandTrait + 'a, + { + Self::WrongPassbandsError { + actual: actual.map(|p| p.name().into()).collect(), + desired: desired.map(|p| p.name().into()).collect(), + } + } +} + #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum SortedArrayError { #[error("SortedVec constructors accept sorted arrays only")] Unsorted, + #[error("SortedVec constructors accept contiguous arrays only")] NonContiguous, } diff --git a/src/evaluator.rs b/src/evaluator.rs index aa83d6db..e21f891f 100644 --- a/src/evaluator.rs +++ b/src/evaluator.rs @@ -1,6 +1,6 @@ +pub use crate::data::TimeSeries; pub use crate::error::EvaluatorError; pub use crate::float_trait::Float; -pub use crate::time_series::TimeSeries; pub use conv::errors::GeneralError; use enum_dispatch::enum_dispatch; @@ -12,7 +12,7 @@ use serde::de::DeserializeOwned; pub use serde::{Deserialize, Serialize}; pub use std::fmt::Debug; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, 
Deserialize, JsonSchema)] pub struct EvaluatorInfo { pub size: usize, pub min_ts_length: usize, @@ -20,9 +20,10 @@ pub struct EvaluatorInfo { pub m_required: bool, pub w_required: bool, pub sorting_required: bool, + pub variability_required: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct EvaluatorProperties { pub info: EvaluatorInfo, pub names: Vec, @@ -69,6 +70,47 @@ pub trait EvaluatorInfoTrait { fn is_sorting_required(&self) -> bool { self.get_info().sorting_required } + + /// If feature requires magnitude array elements to be different + fn is_variability_required(&self) -> bool { + self.get_info().variability_required + } + + fn check_ts(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + self.check_ts_length(ts)?; + self.check_ts_variability(ts) + } + + /// Checks if [TimeSeries] has enough points to evaluate the feature + fn check_ts_length(&self, ts: &TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + let length = ts.lenu(); + if length < self.min_ts_length() { + Err(EvaluatorError::ShortTimeSeries { + actual: length, + minimum: self.min_ts_length(), + }) + } else { + Ok(()) + } + } + + /// Checks if [TimeSeries] meets variability requirement + fn check_ts_variability(&self, ts: &mut TimeSeries) -> Result<(), EvaluatorError> + where + F: Float, + { + if self.is_variability_required() && ts.is_plateau() { + Err(EvaluatorError::FlatTimeSeries) + } else { + Ok(()) + } + } } // impl

EvaluatorInfoTrait for P @@ -124,8 +166,14 @@ pub trait FeatureEvaluator: + DeserializeOwned + JsonSchema { + /// Version of [FeatureEvaluator::eval] which can panic for incorrect input + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>; + /// Vector of feature values or `EvaluatorError` - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>; + fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + self.check_ts(ts)?; + self.eval_no_ts_check(ts) + } /// Returns vector of feature values and fill invalid components with given value fn eval_or_fill(&self, ts: &mut TimeSeries, fill_value: T) -> Vec { @@ -134,46 +182,6 @@ pub trait FeatureEvaluator: Err(_) => vec![fill_value; self.size_hint()], } } - - /// Checks if [TimeSeries] has enough points to evaluate the feature - fn check_ts_length(&self, ts: &TimeSeries) -> Result { - let length = ts.lenu(); - if length < self.min_ts_length() { - Err(EvaluatorError::ShortTimeSeries { - actual: length, - minimum: self.min_ts_length(), - }) - } else { - Ok(length) - } - } -} - -pub fn get_nonzero_m_std(ts: &mut TimeSeries) -> Result { - let std = ts.m.get_std(); - if std.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(std) - } -} - -pub fn get_nonzero_m_std2(ts: &mut TimeSeries) -> Result { - let std2 = ts.m.get_std2(); - if std2.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(std2) - } -} - -pub fn get_nonzero_reduced_chi2(ts: &mut TimeSeries) -> Result { - let reduced_chi2 = ts.get_m_reduced_chi2(); - if reduced_chi2.is_zero() || ts.is_plateau() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(reduced_chi2) - } } pub trait OwnedArrays diff --git a/src/extractor.rs b/src/extractor.rs index e7e5a1a1..0eedee08 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::error::EvaluatorError; use crate::evaluator::*; use crate::feature::Feature; use 
crate::float_trait::Float; -use crate::time_series::TimeSeries; use std::marker::PhantomData; @@ -46,6 +46,7 @@ where m_required: features.iter().any(|x| x.is_m_required()), w_required: features.iter().any(|x| x.is_w_required()), sorting_required: features.iter().any(|x| x.is_sorting_required()), + variability_required: features.iter().any(|x| x.is_variability_required()), } .into(); Self { @@ -118,10 +119,10 @@ where T: Float, F: FeatureEvaluator, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let mut vec = Vec::with_capacity(self.size_hint()); for x in &self.features { - vec.extend(x.eval(ts)?); + vec.extend(x.eval_no_ts_check(ts)?); } Ok(vec) } diff --git a/src/feature.rs b/src/feature.rs index 018b7b91..73772698 100644 --- a/src/feature.rs +++ b/src/feature.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::evaluator::*; use crate::extractor::FeatureExtractor; use crate::features::*; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use crate::transformers::Transformer; use enum_dispatch::enum_dispatch; diff --git a/src/features/amplitude.rs b/src/features/amplitude.rs index cad0f5ad..ced5bc9f 100644 --- a/src/features/amplitude.rs +++ b/src/features/amplitude.rs @@ -47,6 +47,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for Amplitude { @@ -63,8 +64,7 @@ impl FeatureEvaluator for Amplitude where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![T::half() * (ts.m.get_max() - ts.m.get_min())]) } } diff --git a/src/features/anderson_darling_normal.rs b/src/features/anderson_darling_normal.rs index 189ea786..8bf947a9 100644 --- a/src/features/anderson_darling_normal.rs +++ 
b/src/features/anderson_darling_normal.rs @@ -46,6 +46,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for AndersonDarlingNormal { @@ -62,10 +63,10 @@ impl FeatureEvaluator for AndersonDarlingNormal where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - let size = self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let size = ts.lenu(); let m_mean = ts.m.get_mean(); + let m_std = ts.m.get_std(); let sum: f64 = ts.m.get_sorted() .as_ref() diff --git a/src/features/bazin_fit.rs b/src/features/bazin_fit.rs index e5ce4b9d..9dc6bae8 100644 --- a/src/features/bazin_fit.rs +++ b/src/features/bazin_fit.rs @@ -108,7 +108,8 @@ lazy_info!( t_required: true, m_required: true, w_required: true, - sorting_required: true, // improve reproducibility + sorting_required: true, // improves reproducibility + variability_required: true, ); struct Params<'a, T> { @@ -418,13 +419,22 @@ mod tests { check_feature!(BazinFit); feature_test!( - bazin_fit_plateau, + bazin_fit_almost_plateau, [BazinFit::default()], [0.0, 0.0, 10.0, 5.0, 5.0, 0.0], // initial model parameters and zero chi2 linspace(0.0, 10.0, 11), - [0.0; 11], + linspace(0.0, 1e-100, 11), // make it a bit non-flat ); + #[test] + fn bazin_fit_plateau() { + let fe = BazinFit::default(); + let t = linspace(0.0, 10.0, 11); + let f = [0.0; 11]; + let mut ts = TimeSeries::new_without_weight(&t, &f); + assert!(fe.eval(&mut ts).is_err()); + } + fn bazin_fit_noisy(eval: BazinFit) { const N: usize = 50; diff --git a/src/features/beyond_n_std.rs b/src/features/beyond_n_std.rs index 6f55e9ab..5c05c971 100644 --- a/src/features/beyond_n_std.rs +++ b/src/features/beyond_n_std.rs @@ -94,6 +94,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Default for 
BeyondNStd @@ -122,8 +123,7 @@ impl FeatureEvaluator for BeyondNStd where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); let threshold = ts.m.get_std() * self.nstd; let count_beyond = ts.m.sample.fold(0, |count, &m| { diff --git a/src/features/bins.rs b/src/features/bins.rs index 42c9e1da..b4a1b442 100644 --- a/src/features/bins.rs +++ b/src/features/bins.rs @@ -62,6 +62,7 @@ where m_required: true, w_required: true, sorting_required: true, + variability_required: false, }; Self { properties: EvaluatorProperties { @@ -94,6 +95,7 @@ where self.properties.info.size += feature.size_hint(); self.properties.info.min_ts_length = usize::max(self.properties.info.min_ts_length, feature.min_ts_length()); + self.properties.info.variability_required |= feature.is_variability_required(); self.properties.names.extend( feature .get_names() @@ -135,7 +137,12 @@ where } fn transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + if ts.lenu() < self.min_ts_length() { + return Err(EvaluatorError::ShortTimeSeries { + actual: ts.lenu(), + minimum: self.min_ts_length(), + }); + } let (t, m, w): (Vec<_>, Vec<_>, Vec<_>) = ts.t.as_slice() .iter() diff --git a/src/features/cusum.rs b/src/features/cusum.rs index 6c5adcf2..0ece8c23 100644 --- a/src/features/cusum.rs +++ b/src/features/cusum.rs @@ -46,6 +46,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Cusum { @@ -62,10 +63,9 @@ impl FeatureEvaluator for Cusum where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std = 
ts.m.get_std(); let (_last_cusum, min_cusum, max_cusum) = ts.m.as_slice().iter().fold( (T::zero(), T::infinity(), -T::infinity()), |(mut cusum, min_cusum, max_cusum), &m| { diff --git a/src/features/duration.rs b/src/features/duration.rs index d26a8187..1c4e473c 100644 --- a/src/features/duration.rs +++ b/src/features/duration.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for Duration { @@ -55,8 +56,7 @@ impl FeatureEvaluator for Duration where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.sample[ts.lenu() - 1] - ts.t.sample[0]]) } } diff --git a/src/features/eta.rs b/src/features/eta.rs index 6c424ca9..92f8a03b 100644 --- a/src/features/eta.rs +++ b/src/features/eta.rs @@ -42,6 +42,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Eta { @@ -58,9 +59,8 @@ impl FeatureEvaluator for Eta where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let m_std2 = ts.m.get_std2(); let value = ts.m.as_slice() .iter() diff --git a/src/features/eta_e.rs b/src/features/eta_e.rs index b4dac558..d19f6168 100644 --- a/src/features/eta_e.rs +++ b/src/features/eta_e.rs @@ -47,6 +47,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for EtaE { @@ -63,9 +64,8 @@ impl FeatureEvaluator for EtaE where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn 
eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let m_std2 = ts.m.get_std2(); let sq_slope_sum = ts.t.as_slice() .iter() diff --git a/src/features/excess_variance.rs b/src/features/excess_variance.rs index ef9a861a..c2ecf68b 100644 --- a/src/features/excess_variance.rs +++ b/src/features/excess_variance.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl ExcessVariance { @@ -58,8 +59,7 @@ impl FeatureEvaluator for ExcessVariance where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let mean_error2 = ts.w.sample.fold(T::zero(), |sum, w| sum + w.recip()) / ts.lenf(); Ok(vec![ (ts.m.get_std2() - mean_error2) / ts.m.get_mean().powi(2), diff --git a/src/features/inter_percentile_range.rs b/src/features/inter_percentile_range.rs index b43969d5..4f96e4b1 100644 --- a/src/features/inter_percentile_range.rs +++ b/src/features/inter_percentile_range.rs @@ -42,6 +42,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl InterPercentileRange { @@ -91,8 +92,7 @@ impl FeatureEvaluator for InterPercentileRange where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let ppf_low = ts.m.get_sorted().ppf(self.quantile); let ppf_high = ts.m.get_sorted().ppf(1.0 - self.quantile); let value = ppf_high - ppf_low; diff --git a/src/features/kurtosis.rs b/src/features/kurtosis.rs index c52286ec..707daf70 100644 --- a/src/features/kurtosis.rs +++ b/src/features/kurtosis.rs @@ -43,6 +43,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl FeatureNamesDescriptionsTrait for Kurtosis { @@ 
-59,10 +60,9 @@ impl FeatureEvaluator for Kurtosis where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std2 = get_nonzero_m_std2(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std2 = ts.m.get_std2(); let n = ts.lenf(); let n1 = n + T::one(); let n_1 = n - T::one(); diff --git a/src/features/linear_fit.rs b/src/features/linear_fit.rs index 11831a4f..3e87b7db 100644 --- a/src/features/linear_fit.rs +++ b/src/features/linear_fit.rs @@ -45,6 +45,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for LinearFit { @@ -69,8 +70,7 @@ impl FeatureEvaluator for LinearFit where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = fit_straight_line(ts, true); Ok(vec![ result.slope, diff --git a/src/features/linear_trend.rs b/src/features/linear_trend.rs index d0916347..42dc57e6 100644 --- a/src/features/linear_trend.rs +++ b/src/features/linear_trend.rs @@ -43,6 +43,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for LinearTrend { @@ -63,8 +64,7 @@ impl FeatureEvaluator for LinearTrend where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = fit_straight_line(ts, false); Ok(vec![ result.slope, diff --git a/src/features/magnitude_percentage_ratio.rs b/src/features/magnitude_percentage_ratio.rs index cde34cbd..11546c2a 100644 --- a/src/features/magnitude_percentage_ratio.rs +++ b/src/features/magnitude_percentage_ratio.rs @@ -41,6 +41,7 @@ lazy_info!( 
m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl MagnitudePercentageRatio { @@ -112,18 +113,13 @@ impl FeatureEvaluator for MagnitudePercentageRatio where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_sorted = ts.m.get_sorted(); let numerator = m_sorted.ppf(1.0 - self.quantile_numerator) - m_sorted.ppf(self.quantile_numerator); let denumerator = m_sorted.ppf(1.0 - self.quantile_denominator) - m_sorted.ppf(self.quantile_denominator); - if numerator.is_zero() & denumerator.is_zero() { - Err(EvaluatorError::FlatTimeSeries) - } else { - Ok(vec![numerator / denumerator]) - } + Ok(vec![numerator / denumerator]) } } diff --git a/src/features/maximum_slope.rs b/src/features/maximum_slope.rs index b39ac491..84a5cd84 100644 --- a/src/features/maximum_slope.rs +++ b/src/features/maximum_slope.rs @@ -33,6 +33,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: true, + variability_required: false, ); impl MaximumSlope { @@ -57,8 +58,7 @@ impl FeatureEvaluator for MaximumSlope where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let result = ts.t.as_slice() .iter() diff --git a/src/features/maximum_time_interval.rs b/src/features/maximum_time_interval.rs index 8bece763..1a30c00f 100644 --- a/src/features/maximum_time_interval.rs +++ b/src/features/maximum_time_interval.rs @@ -40,6 +40,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for MaximumTimeInterval { @@ -56,8 +57,7 @@ impl FeatureEvaluator for MaximumTimeInterval where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - 
self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let dt = ts.t.as_slice() .iter() diff --git a/src/features/mean.rs b/src/features/mean.rs index d95bf695..ba1b792b 100644 --- a/src/features/mean.rs +++ b/src/features/mean.rs @@ -28,6 +28,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Mean { @@ -54,7 +55,7 @@ impl FeatureEvaluator for Mean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { self.check_ts_length(ts)?; Ok(vec![ts.m.get_mean()]) } diff --git a/src/features/mean_variance.rs b/src/features/mean_variance.rs index 2712e21b..625118ab 100644 --- a/src/features/mean_variance.rs +++ b/src/features/mean_variance.rs @@ -27,6 +27,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MeanVariance { @@ -53,8 +54,7 @@ impl FeatureEvaluator for MeanVariance where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_std() / ts.m.get_mean()]) } } diff --git a/src/features/median.rs b/src/features/median.rs index ed71471a..451c2405 100644 --- a/src/features/median.rs +++ b/src/features/median.rs @@ -27,6 +27,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl Median { @@ -53,8 +54,7 @@ impl FeatureEvaluator for Median where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_median()]) } } diff --git a/src/features/median_absolute_deviation.rs b/src/features/median_absolute_deviation.rs index 
d4b6e442..2dae9af9 100644 --- a/src/features/median_absolute_deviation.rs +++ b/src/features/median_absolute_deviation.rs @@ -1,5 +1,5 @@ +use crate::data::SortedArray; use crate::evaluator::*; -use crate::sorted_array::SortedArray; macro_const! { const DOC: &'static str = r" @@ -30,6 +30,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MedianAbsoluteDeviation { @@ -56,8 +57,7 @@ impl FeatureEvaluator for MedianAbsoluteDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_median = ts.m.get_median(); let sorted_deviation: SortedArray<_> = ts.m.sample diff --git a/src/features/median_buffer_range_percentage.rs b/src/features/median_buffer_range_percentage.rs index 6e8c9d6f..16b39efd 100644 --- a/src/features/median_buffer_range_percentage.rs +++ b/src/features/median_buffer_range_percentage.rs @@ -38,6 +38,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl MedianBufferRangePercentage @@ -102,8 +103,7 @@ impl FeatureEvaluator for MedianBufferRangePercentage where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_median = ts.m.get_median(); let amplitude = T::half() * (ts.m.get_max() - ts.m.get_min()); let threshold = self.quantile * amplitude; diff --git a/src/features/minimum_time_interval.rs b/src/features/minimum_time_interval.rs index 24298c4e..a7a6da6c 100644 --- a/src/features/minimum_time_interval.rs +++ b/src/features/minimum_time_interval.rs @@ -40,6 +40,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: true, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for MinimumTimeInterval { @@ -56,8 
+57,7 @@ impl FeatureEvaluator for MinimumTimeInterval where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let dt = ts.t.as_slice() .iter() diff --git a/src/features/observation_count.rs b/src/features/observation_count.rs index c04c8c30..ab2ba097 100644 --- a/src/features/observation_count.rs +++ b/src/features/observation_count.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for ObservationCount { @@ -55,8 +56,7 @@ impl FeatureEvaluator for ObservationCount where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.lenf()]) } } diff --git a/src/features/otsu_split.rs b/src/features/otsu_split.rs index b0edbcf5..165d72a8 100644 --- a/src/features/otsu_split.rs +++ b/src/features/otsu_split.rs @@ -1,5 +1,5 @@ +use crate::data::DataSample; use crate::evaluator::*; -use crate::time_series::DataSample; use conv::prelude::*; use ndarray::{s, Array1, ArrayView1, Axis, Zip}; use ndarray_stats::QuantileExt; @@ -36,6 +36,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl OtsuSplit { @@ -47,28 +48,17 @@ impl OtsuSplit { DOC } - pub fn threshold<'a, 'b, T>( + fn threshold_no_ds_check<'a, 'b, T>( ds: &'b mut DataSample<'a, T>, - ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError> + ) -> (T, ArrayView1<'b, T>, ArrayView1<'b, T>) where 'a: 'b, T: Float, { - if ds.sample.len() < 2 { - return Err(EvaluatorError::ShortTimeSeries { - actual: ds.sample.len(), - minimum: 2, - }); - } - let count = ds.sample.len(); let countf = count.approx().unwrap(); let sorted = ds.get_sorted(); - if sorted.minimum() == 
sorted.maximum() { - return Err(EvaluatorError::FlatTimeSeries); - } - // size is (count - 1) let cumsum1: Array1<_> = sorted .iter() @@ -110,7 +100,30 @@ impl OtsuSplit { let index = inter_class_variance.argmax().unwrap(); let (lower, upper) = sorted.0.view().split_at(Axis(0), index + 1); - Ok((sorted.0[index + 1], lower, upper)) + (sorted.0[index + 1], lower, upper) + } + + pub fn threshold<'a, 'b, T>( + ds: &'b mut DataSample<'a, T>, + ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError> + where + 'a: 'b, + T: Float, + { + if ds.sample.len() < 2 { + return Err(EvaluatorError::ShortTimeSeries { + actual: ds.sample.len(), + minimum: 2, + }); + } + + // Sorted array will be cached inside ds, we will reuse it in threshold_no_ds_check + let sorted = ds.get_sorted(); + if sorted.minimum() == sorted.maximum() { + return Err(EvaluatorError::FlatTimeSeries); + } + + return Ok(Self::threshold_no_ds_check(ds)); } } @@ -136,9 +149,7 @@ impl FeatureEvaluator for OtsuSplit where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let (_, lower, upper) = Self::threshold(&mut ts.m)?; let mut lower: DataSample<_> = lower.into(); let mut upper: DataSample<_> = upper.into(); diff --git a/src/features/percent_amplitude.rs b/src/features/percent_amplitude.rs index cb4f1156..899fb468 100644 --- a/src/features/percent_amplitude.rs +++ b/src/features/percent_amplitude.rs @@ -30,6 +30,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl PercentAmplitude { @@ -56,8 +57,7 @@ impl FeatureEvaluator for PercentAmplitude where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_min = ts.m.get_min(); let m_max = ts.m.get_max(); let 
m_median = ts.m.get_median(); diff --git a/src/features/percent_difference_magnitude_percentile.rs b/src/features/percent_difference_magnitude_percentile.rs index 6b4ec505..61b839f5 100644 --- a/src/features/percent_difference_magnitude_percentile.rs +++ b/src/features/percent_difference_magnitude_percentile.rs @@ -38,6 +38,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl PercentDifferenceMagnitudePercentile { @@ -94,8 +95,7 @@ impl FeatureEvaluator for PercentDifferenceMagnitudePercentile where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let nominator = ts.m.get_sorted().ppf(1.0 - self.quantile) - ts.m.get_sorted().ppf(self.quantile); let denominator = ts.m.get_median(); diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 4a2a0210..407dd93f 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -49,6 +49,7 @@ impl PeriodogramPeaks { m_required: true, w_required: false, sorting_required: true, + variability_required: false, }; let names = (0..peaks) .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) @@ -102,6 +103,7 @@ impl EvaluatorInfoTrait for PeriodogramPeaks { &self.properties.info } } + impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { fn get_names(&self) -> Vec<&str> { self.properties.names.iter().map(String::as_str).collect() @@ -120,8 +122,7 @@ impl FeatureEvaluator for PeriodogramPeaks where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); Ok(peak_indices .iter() @@ -310,6 +311,7 @@ where m_required: true, w_required: false, sorting_required: true, + 
variability_required: false, }; Self { properties: EvaluatorProperties { diff --git a/src/features/reduced_chi2.rs b/src/features/reduced_chi2.rs index 42076043..66eeef98 100644 --- a/src/features/reduced_chi2.rs +++ b/src/features/reduced_chi2.rs @@ -33,6 +33,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl ReducedChi2 { @@ -59,8 +60,7 @@ impl FeatureEvaluator for ReducedChi2 where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.get_m_reduced_chi2()]) } } diff --git a/src/features/skew.rs b/src/features/skew.rs index c5f41a7e..ab7f2a16 100644 --- a/src/features/skew.rs +++ b/src/features/skew.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: true, ); impl Skew { @@ -58,10 +59,9 @@ impl FeatureEvaluator for Skew where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let m_std = get_nonzero_m_std(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let m_mean = ts.m.get_mean(); + let m_std = ts.m.get_std(); let n = ts.lenf(); let n_1 = n - T::one(); let n_2 = n_1 - T::one(); diff --git a/src/features/standard_deviation.rs b/src/features/standard_deviation.rs index 55196f6c..88f5d3fb 100644 --- a/src/features/standard_deviation.rs +++ b/src/features/standard_deviation.rs @@ -32,6 +32,7 @@ lazy_info!( m_required: true, w_required: false, sorting_required: false, + variability_required: false, ); impl StandardDeviation { @@ -58,8 +59,7 @@ impl FeatureEvaluator for StandardDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.m.get_std()]) 
} } diff --git a/src/features/stetson_k.rs b/src/features/stetson_k.rs index e70b7138..a28d3c06 100644 --- a/src/features/stetson_k.rs +++ b/src/features/stetson_k.rs @@ -34,6 +34,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: true, ); impl StetsonK { @@ -60,9 +61,8 @@ impl FeatureEvaluator for StetsonK where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - let chi2 = get_nonzero_reduced_chi2(ts)? * (ts.lenf() - T::one()); + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let chi2 = ts.get_m_reduced_chi2() * (ts.lenf() - T::one()); let mean = ts.get_m_weighted_mean(); let value = Zip::from(&ts.m.sample) .and(&ts.w.sample) diff --git a/src/features/time_mean.rs b/src/features/time_mean.rs index 36e0bbf2..d8fc3c2d 100644 --- a/src/features/time_mean.rs +++ b/src/features/time_mean.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for TimeMean { fn get_names(&self) -> Vec<&str> { @@ -53,8 +54,7 @@ impl FeatureEvaluator for TimeMean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.get_mean()]) } } diff --git a/src/features/time_standard_deviation.rs b/src/features/time_standard_deviation.rs index ece52bc8..3b733419 100644 --- a/src/features/time_standard_deviation.rs +++ b/src/features/time_standard_deviation.rs @@ -39,6 +39,7 @@ lazy_info!( m_required: false, w_required: false, sorting_required: false, + variability_required: false, ); impl FeatureNamesDescriptionsTrait for TimeStandardDeviation { @@ -55,8 +56,7 @@ impl FeatureEvaluator for TimeStandardDeviation where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - 
self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.t.get_std()]) } } diff --git a/src/features/transformed.rs b/src/features/transformed.rs index b33ed241..3054c76f 100644 --- a/src/features/transformed.rs +++ b/src/features/transformed.rs @@ -52,6 +52,7 @@ where m_required: feature.is_m_required(), w_required: feature.is_w_required(), sorting_required: feature.is_sorting_required(), + variability_required: feature.is_variability_required(), }; let names = transformer.names(&feature.get_names()); let descriptions = transformer.descriptions(&feature.get_descriptions()); @@ -110,15 +111,17 @@ where F: FeatureEvaluator, Tr: TransformerTrait, { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + Ok(self + .transformer + .transform(self.feature.eval_no_ts_check(ts)?)) + } + fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(self.transformer.transform(self.feature.eval(ts)?)) } // We keep default implementation of eval_or_fill - - fn check_ts_length(&self, ts: &TimeSeries) -> Result { - self.feature.check_ts_length(ts) - } } #[derive(Serialize, Deserialize, JsonSchema)] diff --git a/src/features/villar_fit.rs b/src/features/villar_fit.rs index c9b794a8..ca1f8041 100644 --- a/src/features/villar_fit.rs +++ b/src/features/villar_fit.rs @@ -126,7 +126,8 @@ lazy_info!( t_required: true, m_required: true, w_required: true, - sorting_required: true, // improve reproducibility + sorting_required: true, // improves reproducibility + variability_required: true, ); impl FitModelTrait for VillarFit @@ -609,13 +610,22 @@ mod tests { check_feature!(VillarFit); feature_test!( - villar_fit_plateau, + villar_fit_almost_plateau, [VillarFit::default()], [0.0, 0.0, 10.0, 5.0, 5.0, 0.0, 1.0, 0.0], // initial model parameters and zero chi2 linspace(0.0, 10.0, 11), - [0.0; 11], + linspace(0.0, 1e-100, 11), // make it a bit non-flat ); + #[test] + fn villar_fit_plateau() { + let fe = 
VillarFit::default(); + let t = linspace(0.0, 10.0, 11); + let f = [0.0; 11]; + let mut ts = TimeSeries::new_without_weight(&t, &f); + assert!(fe.eval(&mut ts).is_err()); + } + #[cfg(any( feature = "gsl", any(feature = "ceres-source", feature = "ceres-system") diff --git a/src/features/weighted_mean.rs b/src/features/weighted_mean.rs index f728ffef..dd2bcb86 100644 --- a/src/features/weighted_mean.rs +++ b/src/features/weighted_mean.rs @@ -28,6 +28,7 @@ lazy_info!( m_required: true, w_required: true, sorting_required: false, + variability_required: false, ); impl WeightedMean { @@ -54,8 +55,7 @@ impl FeatureEvaluator for WeightedMean where T: Float, { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { Ok(vec![ts.get_m_weighted_mean()]) } } diff --git a/src/lib.rs b/src/lib.rs index 1c6cf709..36c6c79e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![doc = include_str!("../README.md")] +extern crate core; + #[cfg(test)] #[macro_use] mod tests; @@ -7,6 +9,9 @@ mod tests; #[macro_use] mod macros; +mod data; +pub use data::{DataSample, TimeSeries}; + mod evaluator; pub use evaluator::{EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait}; @@ -27,6 +32,9 @@ pub use float_trait::Float; mod lnerfc; +mod multicolor; +pub use multicolor::*; + mod nl_fit; pub use nl_fit::evaluator::FitFeatureEvaluatorGettersTrait; #[cfg(any(feature = "ceres-source", feature = "ceres-system"))] @@ -46,8 +54,6 @@ pub use periodogram::{ pub mod prelude; -mod sorted_array; - mod straight_line_fit; #[doc(hidden)] pub use straight_line_fit::fit_straight_line; @@ -59,9 +65,6 @@ mod peak_indices; #[doc(hidden)] pub use peak_indices::peak_indices; -mod time_series; -pub use time_series::{DataSample, TimeSeries}; - mod types; pub use ndarray; diff --git a/src/macros.rs b/src/macros.rs index 443279aa..603ba302 100644 --- a/src/macros.rs +++ b/src/macros.rs 
@@ -8,6 +8,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_static! { static ref $name: EvaluatorInfo = EvaluatorInfo { @@ -17,6 +18,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, }; } }; @@ -29,6 +31,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_info!( $name, @@ -38,6 +41,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, ); impl EvaluatorInfoTrait for $feature { @@ -56,6 +60,7 @@ macro_rules! lazy_info { m_required: $m: expr, w_required: $w: expr, sorting_required: $sort: expr, + variability_required: $var: expr, ) => { lazy_info!( $name, @@ -65,6 +70,7 @@ macro_rules! lazy_info { m_required: $m, w_required: $w, sorting_required: $sort, + variability_required: $var, ); impl EvaluatorInfoTrait for $feature { @@ -80,7 +86,7 @@ macro_rules! lazy_info { /// - `transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError>` macro_rules! transformer_eval { () => { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let arrays = self.transform_ts(ts)?; let mut new_ts = arrays.ts(); self.feature_extractor.eval(&mut new_ts) @@ -121,9 +127,7 @@ macro_rules! json_schema { /// - declare `const NPARAMS: usize` in your code macro_rules! 
fit_eval { () => { - fn eval(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; - + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { let norm_data = NormalizedData::::from_ts(ts); let (x0, lower, upper) = { diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs new file mode 100644 index 00000000..9639f875 --- /dev/null +++ b/src/multicolor/features/color_of_maximum.rs @@ -0,0 +1,111 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMaximum

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + name: String, + description: String, +} + +impl

ColorOfMaximum

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!("color_max_{}_{}", passbands[0].name(), passbands[1].name()), + description: format!( + "difference of maximum value magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + } + } +} + +lazy_info!( + COLOR_OF_MAXIMUM_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMaximum

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_OF_MAXIMUM_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMaximum

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMaximum

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMaximum

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let mut maxima = [T::zero(); 2]; + for ((_passband, mcts), maximum) in mcts + .mapping_mut() + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(maxima.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *maximum = mcts.m.get_max() + } + Ok(vec![maxima[0] - maxima[1]]) + } +} diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs new file mode 100644 index 00000000..a6779e4b --- /dev/null +++ b/src/multicolor/features/color_of_median.rs @@ -0,0 +1,125 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait, +}; +use crate::features::Median; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMedian

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + median: Median, + name: String, + description: String, +} + +impl

ColorOfMedian

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!( + "color_median_{}_{}", + passbands[0].name(), + passbands[1].name() + ), + description: format!( + "difference of median magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + median: Median {}, + } + } +} + +lazy_info!( + COLOR_OF_MEDIAN_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMedian

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_OF_MEDIAN_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMedian

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMedian

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMedian

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let mut medians = [T::zero(); 2]; + for ((passband, mcts), median) in mcts + .mapping_mut() + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(medians.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *median = self.median.eval(mcts).map_err(|error| { + MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + } + })?[0] + } + Ok(vec![medians[0] - medians[1]]) + } +} diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs new file mode 100644 index 00000000..72de5c7d --- /dev/null +++ b/src/multicolor/features/color_of_minimum.rs @@ -0,0 +1,111 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +pub struct ColorOfMinimum

+where + P: Ord, +{ + passband_set: PassbandSet

, + passbands: [P; 2], + name: String, + description: String, +} + +impl

ColorOfMinimum

+where + P: PassbandTrait, +{ + pub fn new(passbands: [P; 2]) -> Self { + let set: BTreeSet<_> = passbands.clone().into(); + Self { + passband_set: set.into(), + name: format!("color_min_{}_{}", passbands[0].name(), passbands[1].name()), + description: format!( + "difference of minimum value magnitudes {}-{}", + passbands[0].name(), + passbands[1].name() + ), + passbands, + } + } +} + +lazy_info!( + COLOR_OF_MINIMUM_INFO, + size: 1, + min_ts_length: 1, + t_required: false, + m_required: true, + w_required: false, + sorting_required: false, + variability_required: false, +); + +impl

EvaluatorInfoTrait for ColorOfMinimum

+where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &COLOR_OF_MINIMUM_INFO + } +} + +impl

FeatureNamesDescriptionsTrait for ColorOfMinimum

+where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + vec![self.name.as_str()] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec![self.description.as_str()] + } +} + +impl

MultiColorPassbandSetTrait

for ColorOfMinimum

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for ColorOfMinimum

+where + P: PassbandTrait, + T: Float, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let mut minima = [T::zero(); 2]; + for ((_passband, mcts), minimum) in mcts + .mapping_mut() + .iter_matched_passbands_mut(self.passbands.iter()) + .zip(minima.iter_mut()) + { + let mcts = mcts.expect("MultiColorTimeSeries must have all required passbands"); + *minimum = mcts.m.get_min() + } + Ok(vec![minima[0] - minima[1]]) + } +} diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs new file mode 100644 index 00000000..54b5a4c7 --- /dev/null +++ b/src/multicolor/features/mod.rs @@ -0,0 +1,8 @@ +mod color_of_maximum; +pub use color_of_maximum::ColorOfMaximum; + +mod color_of_median; +pub use color_of_median::ColorOfMedian; + +mod color_of_minimum; +pub use color_of_minimum::ColorOfMinimum; diff --git a/src/multicolor/mod.rs b/src/multicolor/mod.rs new file mode 100644 index 00000000..fb09f714 --- /dev/null +++ b/src/multicolor/mod.rs @@ -0,0 +1,16 @@ +mod features; + +mod monochrome_feature; +pub use monochrome_feature::MonochromeFeature; + +mod multicolor_evaluator; +pub use multicolor_evaluator::{MultiColorEvaluator, MultiColorPassbandSetTrait, PassbandSet}; + +mod multicolor_extractor; +pub use multicolor_extractor::MultiColorExtractor; + +mod multicolor_feature; +pub use multicolor_feature::MultiColorFeature; + +mod passband; +pub use passband::*; diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs new file mode 100644 index 00000000..2b483d29 --- /dev/null +++ b/src/multicolor/monochrome_feature.rs @@ -0,0 +1,132 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, + FeatureNamesDescriptionsTrait, +}; +use 
crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; + +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound( + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" +))] +pub struct MonochromeFeature +where + P: Ord, +{ + feature: F, + passband_set: PassbandSet

, + properties: Box, + phantom: PhantomData, +} + +impl MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + pub fn new(feature: F, passband_set: BTreeSet

) -> Self { + let names = passband_set + .iter() + .cartesian_product(feature.get_names()) + .map(|(passband, name)| format!("{}_{}", name, passband.name())) + .collect(); + let descriptions = passband_set + .iter() + .cartesian_product(feature.get_descriptions()) + .map(|(passband, description)| format!("{}, passband {}", description, passband.name())) + .collect(); + let info = { + let mut info = feature.get_info().clone(); + info.size *= passband_set.len(); + info + }; + Self { + properties: EvaluatorProperties { + info, + names, + descriptions, + } + .into(), + feature, + passband_set: passband_set.into(), + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MonochromeFeature +where + P: Ord, +{ + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl EvaluatorInfoTrait for MonochromeFeature +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl MultiColorPassbandSetTrait

for MonochromeFeature +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MonochromeFeature +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + match &self.passband_set { + PassbandSet::FixedSet(set) => { + mcts.mapping_mut().iter_matched_passbands_mut(set.iter()) + .map(|(passband, ts)| { + self.feature.eval_no_ts_check( + ts.expect("we checked all needed passbands are in mcts, but we still cannot find one") + ).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError { + passband: passband.name().into(), + error, + }) + }).flatten_ok().collect() + } + PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"), + } + } +} diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs new file mode 100644 index 00000000..c60dddf7 --- /dev/null +++ b/src/multicolor/multicolor_evaluator.rs @@ -0,0 +1,277 @@ +pub use crate::data::MultiColorTimeSeries; +pub use crate::error::MultiColorEvaluatorError; +pub use crate::evaluator::{ + EvaluatorError, EvaluatorInfo, EvaluatorInfoTrait, EvaluatorProperties, FeatureEvaluator, + FeatureNamesDescriptionsTrait, +}; +pub use crate::feature::Feature; +pub use crate::float_trait::Float; +pub use crate::multicolor::PassbandTrait; + +use enum_dispatch::enum_dispatch; +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[enum_dispatch] +pub trait MultiColorPassbandSetTrait

+where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

; +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] +#[non_exhaustive] +pub enum PassbandSet

+where + P: Ord, +{ + FixedSet(BTreeSet

), + AllAvailable, +} + +impl

From> for PassbandSet

+where + P: Ord, +{ + fn from(value: BTreeSet

) -> Self { + Self::FixedSet(value) + } +} + +enum InternalMctsError { + MultiColorEvaluatorError(MultiColorEvaluatorError), + InternalWrongPassbandSet, +} + +impl InternalMctsError { + fn into_multi_color_evaluator_error<'mcts, 'a, 'ps, P, T>( + self, + mcts: &'mcts MultiColorTimeSeries<'a, P, T>, + ps: &'ps PassbandSet

, + ) -> MultiColorEvaluatorError + where + 'ps: 'a, + 'a: 'mcts, + P: PassbandTrait, + T: Float, + { + match self { + InternalMctsError::MultiColorEvaluatorError(e) => e, + InternalMctsError::InternalWrongPassbandSet => { + MultiColorEvaluatorError::wrong_passbands_error( + mcts.passbands(), + match ps { + PassbandSet::FixedSet(ps) => ps.iter(), + PassbandSet::AllAvailable => { + panic!("PassbandSet cannot be ::AllAvailable here") + } + }, + ) + } + } + } +} + +#[enum_dispatch] +pub trait MultiColorEvaluator: + FeatureNamesDescriptionsTrait + + EvaluatorInfoTrait + + MultiColorPassbandSetTrait

+ + Clone + + Serialize +where + P: PassbandTrait, + T: Float, +{ + /// Version of [MultiColorEvaluator::eval_multicolor] without basic [MultiColorTimeSeries] checks + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts; + + /// Vector of feature values or `EvaluatorError` + fn eval_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { + self.check_mcts(mcts)?; + self.eval_multicolor_no_mcts_check(mcts) + } + + /// Returns vector of feature values and fill invalid components with given value + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { + Ok(match self.eval_multicolor(mcts) { + Ok(v) => v, + Err(_) => vec![fill_value; self.size_hint()], + }) + } + + /// Check [MultiColorTimeSeries] to have required passbands and individual [TimeSeries] are valid + fn check_mcts<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result<(), MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: 'a, + { + mcts.mapping_mut() + .iter_passband_set_mut(self.get_passband_set()) + .map(|(p, maybe_ts)| { + maybe_ts + .ok_or(InternalMctsError::InternalWrongPassbandSet) + .and_then(|ts| { + self.check_ts(ts).map_err(|error| { + InternalMctsError::MultiColorEvaluatorError( + MultiColorEvaluatorError::MonochromeEvaluatorError { + error, + passband: p.name().into(), + }, + ) + }) + }) + .map(|_| ()) + }) + .try_collect() + .map_err(|err| err.into_multi_color_evaluator_error(mcts, self.get_passband_set())) + } +} + +#[cfg(test)] +#[allow(clippy::unreadable_literal)] +#[allow(clippy::excessive_precision)] +mod tests { + use 
super::*; + use crate::data::TimeSeries; + use crate::multicolor::MonochromePassband; + + use std::collections::BTreeMap; + + #[derive(Clone, Debug, Serialize)] + struct TestTimeMultiColorFeature { + passband_set: PassbandSet>, + } + + lazy_info!( + TEST_TIME_FEATURE_INFO, + TestTimeMultiColorFeature, + size: 1, + min_ts_length: 1, + t_required: true, + m_required: false, + w_required: false, + sorting_required: true, + variability_required: false, + ); + + impl FeatureNamesDescriptionsTrait for TestTimeMultiColorFeature { + fn get_names(&self) -> Vec<&str> { + vec!["zero"] + } + + fn get_descriptions(&self) -> Vec<&str> { + vec!["zero"] + } + } + + impl MultiColorPassbandSetTrait> for TestTimeMultiColorFeature { + fn get_passband_set(&self) -> &PassbandSet> { + &self.passband_set + } + } + + impl MultiColorEvaluator, T> for TestTimeMultiColorFeature + where + T: Float, + { + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + _mcts: &'mcts mut MultiColorTimeSeries<'a, MonochromePassband<'static, f64>, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + Ok(vec![T::zero()]) + } + } + + #[test] + fn test_check_mcts_passbands() { + let t = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; + let m = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; + let passband_b_capital = MonochromePassband::new(4400e-8, "B"); + let passband_v_capital = MonochromePassband::new(5500e-8, "V"); + let passband_r_capital = MonochromePassband::new(6400e-8, "R"); + let mut mcts = { + let mut mapping = BTreeMap::new(); + mapping.insert( + passband_b_capital.clone(), + TimeSeries::new_without_weight(&t, &m), + ); + mapping.insert( + passband_v_capital.clone(), + TimeSeries::new_without_weight(&t, &m), + ); + MultiColorTimeSeries::from_map(mapping) + }; + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::AllAvailable, + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: 
PassbandSet::FixedSet( + [passband_b_capital.clone(), passband_v_capital.clone()].into(), + ), + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet([passband_b_capital.clone()].into()), + }; + assert!(feature.eval_multicolor(&mut mcts).is_ok()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet([passband_r_capital.clone()].into()), + }; + assert!(feature.eval_multicolor(&mut mcts).is_err()); + + let feature = TestTimeMultiColorFeature { + passband_set: PassbandSet::FixedSet( + [ + passband_b_capital.clone(), + passband_r_capital.clone(), + passband_r_capital.clone(), + ] + .into(), + ), + }; + assert!(feature.eval_multicolor(&mut mcts).is_err()); + } +} diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs new file mode 100644 index 00000000..5fe4d70d --- /dev/null +++ b/src/multicolor/multicolor_extractor.rs @@ -0,0 +1,194 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; + +use itertools::Itertools; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + into = "MultiColorExtractorParameters", + from = "MultiColorExtractorParameters", + bound( + serialize = "P: PassbandTrait, T: Float, MCF: MultiColorEvaluator", + deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, MCF: MultiColorEvaluator + Deserialize<'de>" + ) +)] +pub struct MultiColorExtractor +where + P: Ord, +{ + features: Vec, + info: Box, + passband_set: PassbandSet

, + phantom: PhantomData, +} + +impl MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + pub fn new(features: Vec) -> Self { + let passband_set = { + let set: BTreeSet<_> = features + .iter() + .filter_map(|f| match f.get_passband_set() { + PassbandSet::AllAvailable => None, + PassbandSet::FixedSet(set) => Some(set), + }) + .flatten() + .cloned() + .collect(); + if set.is_empty() { + PassbandSet::AllAvailable + } else { + PassbandSet::FixedSet(set) + } + }; + + let info = EvaluatorInfo { + size: features.iter().map(|x| x.size_hint()).sum(), + min_ts_length: features + .iter() + .map(|x| x.min_ts_length()) + .max() + .unwrap_or(0), + t_required: features.iter().any(|x| x.is_t_required()), + m_required: features.iter().any(|x| x.is_m_required()), + w_required: features.iter().any(|x| x.is_w_required()), + sorting_required: features.iter().any(|x| x.is_sorting_required()), + variability_required: features.iter().any(|x| x.is_variability_required()), + } + .into(); + + Self { + features, + passband_set, + info, + phantom: PhantomData, + } + } +} + +impl FeatureNamesDescriptionsTrait for MultiColorExtractor +where + P: Ord, + MCF: FeatureNamesDescriptionsTrait, +{ + /// Get feature names + fn get_names(&self) -> Vec<&str> { + self.features.iter().flat_map(|x| x.get_names()).collect() + } + + /// Get feature descriptions + fn get_descriptions(&self) -> Vec<&str> { + self.features + .iter() + .flat_map(|x| x.get_descriptions()) + .collect() + } +} + +impl EvaluatorInfoTrait for MultiColorExtractor +where + P: Ord, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.info + } +} + +impl MultiColorPassbandSetTrait

for MultiColorExtractor +where + P: PassbandTrait, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &self.passband_set + } +} + +impl MultiColorEvaluator for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let mut vec = Vec::with_capacity(self.size_hint()); + for x in &self.features { + vec.extend(x.eval_multicolor_no_mcts_check(mcts)?); + } + Ok(vec) + } + + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + self.features + .iter() + .map(|x| x.eval_or_fill_multicolor(mcts, fill_value)) + .flatten_ok() + .collect() + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename = "MultiColorExtractor")] +struct MultiColorExtractorParameters { + features: Vec, +} + +impl From> for MultiColorExtractorParameters +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(f: MultiColorExtractor) -> Self { + Self { + features: f.features, + } + } +} + +impl From> for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: MultiColorEvaluator, +{ + fn from(p: MultiColorExtractorParameters) -> Self { + Self::new(p.features) + } +} + +impl JsonSchema for MultiColorExtractor +where + P: PassbandTrait, + T: Float, + MCF: JsonSchema, +{ + json_schema!(MultiColorExtractorParameters, true); +} diff --git a/src/multicolor/multicolor_feature.rs b/src/multicolor/multicolor_feature.rs new file mode 100644 index 00000000..61a63e75 --- /dev/null +++ b/src/multicolor/multicolor_feature.rs @@ -0,0 +1,47 @@ +use crate::data::{MultiColorTimeSeries, TimeSeries}; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::{EvaluatorInfo, EvaluatorInfoTrait, FeatureNamesDescriptionsTrait}; +use crate::feature::Feature; +use crate::float_trait::Float; 
+use crate::multicolor::features::{ColorOfMaximum, ColorOfMedian, ColorOfMinimum}; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{MonochromeFeature, MultiColorExtractor}; + +use enum_dispatch::enum_dispatch; +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; +use std::fmt::Debug; + +#[enum_dispatch(MultiColorEvaluator, FeatureNamesDescriptionsTrait, EvaluatorInfoTrait, MultiColorPassbandSetTrait

)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float"))] +#[non_exhaustive] +pub enum MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + // Extractor + MultiColorExtractor(MultiColorExtractor>), + // Monochrome Features + MonochromeFeature(MonochromeFeature>), + // Features + ColorOfMaximum(ColorOfMaximum

), + ColorOfMedian(ColorOfMedian

), + ColorOfMinimum(ColorOfMinimum

), +} + +impl MultiColorFeature +where + P: PassbandTrait, + T: Float, +{ + pub fn from_monochrome_feature(feature: F, passband_set: BTreeSet

) -> Self + where + F: Into>, + { + MonochromeFeature::new(feature.into(), passband_set).into() + } +} diff --git a/src/multicolor/passband/dump_passband.rs b/src/multicolor/passband/dump_passband.rs new file mode 100644 index 00000000..ffd73cde --- /dev/null +++ b/src/multicolor/passband/dump_passband.rs @@ -0,0 +1,14 @@ +use crate::PassbandTrait; + +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +pub struct DumpPassband {} + +impl PassbandTrait for DumpPassband { + fn name(&self) -> &str { + "" + } +} diff --git a/src/multicolor/passband/mod.rs b/src/multicolor/passband/mod.rs new file mode 100644 index 00000000..1fbef55e --- /dev/null +++ b/src/multicolor/passband/mod.rs @@ -0,0 +1,8 @@ +mod monochrome_passband; +pub use monochrome_passband::MonochromePassband; + +mod dump_passband; +pub use dump_passband::DumpPassband; + +mod passband_trait; +pub use passband_trait::PassbandTrait; diff --git a/src/multicolor/passband/monochrome_passband.rs b/src/multicolor/passband/monochrome_passband.rs new file mode 100644 index 00000000..987bc79a --- /dev/null +++ b/src/multicolor/passband/monochrome_passband.rs @@ -0,0 +1,69 @@ +use crate::float_trait::Float; +use crate::multicolor::PassbandTrait; + +pub use lazy_static::lazy_static; +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct MonochromePassband<'a, T> { + pub name: &'a str, + pub wavelength: T, +} + +impl<'a, T> MonochromePassband<'a, T> +where + T: Float, +{ + pub fn new(wavelength: T, name: &'a str) -> Self { + assert!( + wavelength.is_normal(), + "wavelength must be a positive normal number" + ); + assert!( + wavelength.is_sign_positive(), + "wavelength must be a positive normal number" + ); + Self { wavelength, name } + } +} + 
+impl<'a, T> PartialEq for MonochromePassband<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.wavelength.eq(&other.wavelength) + } +} + +impl<'a, T> Eq for MonochromePassband<'a, T> where T: Float {} + +impl<'a, T> PartialOrd for MonochromePassband<'a, T> +where + T: Float, +{ + fn partial_cmp(&self, other: &Self) -> Option { + (self.wavelength).partial_cmp(&other.wavelength) + } +} + +impl<'a, T> Ord for MonochromePassband<'a, T> +where + T: Float, +{ + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl<'a, T> PassbandTrait for MonochromePassband<'a, T> +where + T: Float, +{ + fn name(&self) -> &str { + self.name + } +} diff --git a/src/multicolor/passband/passband_trait.rs b/src/multicolor/passband/passband_trait.rs new file mode 100644 index 00000000..011c972a --- /dev/null +++ b/src/multicolor/passband/passband_trait.rs @@ -0,0 +1,7 @@ +pub use schemars::JsonSchema; +pub use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +pub trait PassbandTrait: Debug + Clone + Send + Sync + Ord + Serialize + JsonSchema { + fn name(&self) -> &str; +} diff --git a/src/nl_fit/data.rs b/src/nl_fit/data.rs index c5fba037..8cd6952b 100644 --- a/src/nl_fit/data.rs +++ b/src/nl_fit/data.rs @@ -1,5 +1,5 @@ +use crate::data::{DataSample, TimeSeries}; use crate::float_trait::Float; -use crate::time_series::{DataSample, TimeSeries}; use conv::ConvUtil; use ndarray::Array1; diff --git a/src/nl_fit/evaluator.rs b/src/nl_fit/evaluator.rs index 3000ebb0..1b83e741 100644 --- a/src/nl_fit/evaluator.rs +++ b/src/nl_fit/evaluator.rs @@ -1,6 +1,6 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::nl_fit::{data::NormalizedData, CurveFitAlgorithm, LikeFloat, LnPrior}; -use crate::time_series::TimeSeries; use schemars::JsonSchema; use serde::de::DeserializeOwned; diff --git a/src/periodogram/freq.rs b/src/periodogram/freq.rs index d5b81662..177f3f60 100644 --- a/src/periodogram/freq.rs +++ 
b/src/periodogram/freq.rs @@ -1,5 +1,5 @@ +use crate::data::SortedArray; use crate::float_trait::Float; -use crate::sorted_array::SortedArray; use conv::{ConvAsUtil, ConvUtil, RoundToNearest}; use enum_dispatch::enum_dispatch; diff --git a/src/periodogram/mod.rs b/src/periodogram/mod.rs index af7a29df..6bc477ea 100644 --- a/src/periodogram/mod.rs +++ b/src/periodogram/mod.rs @@ -1,7 +1,7 @@ //! Periodogram-related stuff +use crate::data::TimeSeries; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use conv::ConvAsUtil; use enum_dispatch::enum_dispatch; @@ -107,8 +107,8 @@ where mod tests { use super::*; + use crate::data::SortedArray; use crate::peak_indices::peak_indices_reverse_sorted; - use crate::sorted_array::SortedArray; use light_curve_common::{all_close, linspace}; use rand::prelude::*; diff --git a/src/periodogram/power_direct.rs b/src/periodogram/power_direct.rs index 1b659f27..a8020318 100644 --- a/src/periodogram/power_direct.rs +++ b/src/periodogram/power_direct.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use crate::periodogram::recurrent_sin_cos::*; -use crate::time_series::TimeSeries; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/src/periodogram/power_fft.rs b/src/periodogram/power_fft.rs index cb8d8647..40526df9 100644 --- a/src/periodogram/power_fft.rs +++ b/src/periodogram/power_fft.rs @@ -1,8 +1,8 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::fft::*; use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; -use crate::time_series::TimeSeries; use conv::{ConvAsUtil, RoundToNearest}; use schemars::JsonSchema; diff --git a/src/periodogram/power_trait.rs b/src/periodogram/power_trait.rs index 0ea6b723..46c271be 100644 --- a/src/periodogram/power_trait.rs +++ b/src/periodogram/power_trait.rs @@ -1,6 +1,6 @@ +use 
crate::data::TimeSeries; use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; -use crate::time_series::TimeSeries; use enum_dispatch::enum_dispatch; use std::fmt::Debug; diff --git a/src/prelude.rs b/src/prelude.rs index 704e1d47..0ecb0672 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,3 +1,4 @@ +pub use crate::data::TimeSeries; pub use crate::error::EvaluatorError; pub use crate::evaluator::{EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait}; pub use crate::extractor::FeatureExtractor; @@ -5,4 +6,3 @@ pub use crate::feature::Feature; pub use crate::features::*; pub use crate::float_trait::Float; pub use crate::nl_fit::evaluator::*; -pub use crate::time_series::TimeSeries; diff --git a/src/straight_line_fit.rs b/src/straight_line_fit.rs index 7b68cba3..e582dea7 100644 --- a/src/straight_line_fit.rs +++ b/src/straight_line_fit.rs @@ -1,5 +1,5 @@ +use crate::data::TimeSeries; use crate::float_trait::Float; -use crate::time_series::TimeSeries; use ndarray::Zip; diff --git a/src/tests.rs b/src/tests.rs index 544a767a..b271391d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,8 +1,8 @@ +pub use crate::data::TimeSeries; pub use crate::evaluator::*; pub use crate::extractor::FeatureExtractor; pub use crate::feature::Feature; pub use crate::float_trait::Float; -pub use crate::time_series::TimeSeries; pub use light_curve_common::{all_close, linspace}; pub use ndarray::{Array1, ArrayView1}; @@ -119,6 +119,8 @@ pub fn eval_info_tests( .as_ref() .map(check_size); } + + eval_info_variability_required_test(&eval, &t_sorted, &w, &mut rng); } fn eval_info_ts_length_test( @@ -265,6 +267,43 @@ fn eval_info_sorting_required_test( Some(v) } +fn eval_info_variability_required_test( + eval: &Feature, + t: &[f64], + w: &[f64], + rng: &mut StdRng, +) { + assert!( + !eval.is_variability_required() || eval.is_m_required(), + "variability_required is treu, but m_required is false" + ); + + let m = vec![rng.sample::(StandardNormal).abs(); 
t.len()]; + let mut ts = TimeSeries::new(t, &m, w); + assert_eq!(eval.is_variability_required(), eval.eval(&mut ts).is_err()); + + match ( + std::panic::catch_unwind(|| eval.eval_no_ts_check(&mut TimeSeries::new(t, &m, w))), + eval.is_variability_required(), + ) { + (Ok(_result), true) => {} + // |-- This doesn't work sometimes because of float rounding issues + // v + // (Ok(result), true) => assert!(result + // .map(|v| assert!( + // !v.iter().copied().all(f64::is_finite), + // "{:?} are all finite", + // v + // )) + // .is_err()), + (Ok(result), false) => assert!(result + .map(|v| assert!(v.into_iter().all(f64::is_finite))) + .is_ok()), + (Err(_err), true) => {} + (Err(err), false) => panic!("{:?}", err), + } +} + #[macro_export] macro_rules! serialization_name_test { ($feature_type: ty, $feature_expr: expr) => { diff --git a/src/time_series.rs b/src/time_series.rs deleted file mode 100644 index b2a99e4d..00000000 --- a/src/time_series.rs +++ /dev/null @@ -1,520 +0,0 @@ -use crate::float_trait::Float; -use crate::sorted_array::SortedArray; -use crate::types::CowArray1; - -use conv::prelude::*; -use itertools::Itertools; -use ndarray::{s, Array1, ArrayView1, Zip}; -use ndarray_stats::SummaryStatisticsExt; - -/// A [`TimeSeries`] component -#[derive(Clone, Debug)] -pub struct DataSample<'a, T> -where - T: Float, -{ - pub sample: CowArray1<'a, T>, - sorted: Option>, - min: Option, - max: Option, - mean: Option, - median: Option, - std: Option, - std2: Option, -} - -macro_rules! 
data_sample_getter { - ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> T { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some(match self.sorted.as_ref() { - Some(sorted) => sorted.$method_sorted(), - None => $func(self), - }); - self.$attr.unwrap() - } - } - } - }; - ($attr: ident, $getter: ident, $func: expr) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> T { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some($func(self)); - self.$attr.unwrap() - } - } - } - }; -} - -impl<'a, T> DataSample<'a, T> -where - T: Float, -{ - pub fn new(sample: CowArray1<'a, T>) -> Self { - Self { - sample, - sorted: None, - min: None, - max: None, - mean: None, - median: None, - std: None, - std2: None, - } - } - - pub fn as_slice(&mut self) -> &[T] { - if !self.sample.is_standard_layout() { - let owned: Array1<_> = self.sample.iter().copied().collect::>().into(); - self.sample = owned.into(); - } - self.sample.as_slice().unwrap() - } - - pub fn get_sorted(&mut self) -> &SortedArray { - if self.sorted.is_none() { - self.sorted = Some(self.sample.to_vec().into()); - } - self.sorted.as_ref().unwrap() - } - - fn set_min_max(&mut self) { - let (min, max) = - self.sample - .slice(s![1..]) - .fold((self.sample[0], self.sample[0]), |(min, max), &x| { - if x > max { - (min, x) - } else if x < min { - (x, max) - } else { - (min, max) - } - }); - self.min = Some(min); - self.max = Some(max); - } - - data_sample_getter!( - min, - get_min, - |ds: &mut DataSample<'a, T>| { - ds.set_min_max(); - ds.min.unwrap() - }, - minimum - ); - data_sample_getter!( - max, - get_max, - |ds: &mut DataSample<'a, T>| { - ds.set_min_max(); - 
ds.max.unwrap() - }, - maximum - ); - data_sample_getter!(mean, get_mean, |ds: &mut DataSample<'a, T>| { - ds.sample.mean().expect("time series must be non-empty") - }); - data_sample_getter!(median, get_median, |ds: &mut DataSample<'a, T>| { - ds.get_sorted().median() - }); - data_sample_getter!(std, get_std, |ds: &mut DataSample<'a, T>| { - ds.get_std2().sqrt() - }); - data_sample_getter!(std2, get_std2, |ds: &mut DataSample<'a, T>| { - // Benchmarks show that it is faster than `ndarray::ArrayBase::var(T::one)` - let mean = ds.get_mean(); - ds.sample - .fold(T::zero(), |sum, &x| sum + (x - mean).powi(2)) - / (ds.sample.len() - 1).approx().unwrap() - }); - - pub fn signal_to_noise(&mut self, value: T) -> T { - if self.get_std().is_zero() { - T::zero() - } else { - (value - self.get_mean()) / self.get_std() - } - } -} - -impl<'a, T, Slice: ?Sized> From<&'a Slice> for DataSample<'a, T> -where - T: Float, - Slice: AsRef<[T]>, -{ - fn from(s: &'a Slice) -> Self { - ArrayView1::from(s).into() - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(v: Vec) -> Self { - Array1::from(v).into() - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: ArrayView1<'a, T>) -> Self { - Self::new(a.into()) - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: Array1) -> Self { - Self::new(a.into()) - } -} - -impl<'a, T> From> for DataSample<'a, T> -where - T: Float, -{ - fn from(a: CowArray1<'a, T>) -> Self { - Self::new(a) - } -} - -/// Time series object to be put into [Feature](crate::Feature) -/// -/// This struct caches it's properties, like mean magnitude value, etc., that's why mutable -/// reference is required fot feature evaluation -#[derive(Clone, Debug)] -pub struct TimeSeries<'a, T> -where - T: Float, -{ - pub t: DataSample<'a, T>, - pub m: DataSample<'a, T>, - pub w: DataSample<'a, T>, - m_weighted_mean: Option, - m_reduced_chi2: Option, - t_max_m: Option, - t_min_m: Option, - 
plateau: Option, -} - -macro_rules! time_series_getter { - ($t: ty, $attr: ident, $getter: ident, $func: expr) => { - // This lint is false-positive in macros - // https://github.com/rust-lang/rust-clippy/issues/1553 - #[allow(clippy::redundant_closure_call)] - pub fn $getter(&mut self) -> $t { - match self.$attr { - Some(x) => x, - None => { - self.$attr = Some($func(self)); - self.$attr.unwrap() - } - } - } - }; - - ($attr: ident, $getter: ident, $func: expr) => { - time_series_getter!(T, $attr, $getter, $func); - }; -} - -impl<'a, T> TimeSeries<'a, T> -where - T: Float, -{ - /// Construct `TimeSeries` from array-like objects - /// - /// `t` is time, `m` is magnitude (or flux), `w` is weights. - /// - /// All arrays must have the same length, `t` must increase monotonically. Input arrays could be - /// [`ndarray::Array1`], [`ndarray::ArrayView1`], 1-D [`ndarray::CowArray`], or `&[T]`. Several - /// features assumes that `w` array corresponds to inverse square errors of `m`. - pub fn new( - t: impl Into>, - m: impl Into>, - w: impl Into>, - ) -> Self { - let t = t.into(); - let m = m.into(); - let w = w.into(); - - assert_eq!( - t.sample.len(), - m.sample.len(), - "t and m should have the same size" - ); - assert_eq!( - m.sample.len(), - w.sample.len(), - "m and err should have the same size" - ); - - Self { - t, - m, - w, - m_weighted_mean: None, - m_reduced_chi2: None, - t_max_m: None, - t_min_m: None, - plateau: None, - } - } - - /// Construct [`TimeSeries`] from time and magnitude (flux) - /// - /// It is the same as [`TimeSeries::new`], but sets unity weights. It doesn't recommended to use - /// it for features dependent on weights / observation errors like [`crate::StetsonK`] or - /// [`crate::LinearFit`]. 
- pub fn new_without_weight( - t: impl Into>, - m: impl Into>, - ) -> Self { - let t = t.into(); - let m = m.into(); - - assert_eq!( - t.sample.len(), - m.sample.len(), - "t and m should have the same size" - ); - - let w = T::array0_unity().broadcast(t.sample.len()).unwrap().into(); - - Self { - t, - m, - w, - m_weighted_mean: None, - m_reduced_chi2: None, - t_max_m: None, - t_min_m: None, - plateau: None, - } - } - - /// Time series length - #[inline] - pub fn lenu(&self) -> usize { - self.t.sample.len() - } - - /// Float approximating time series length - pub fn lenf(&self) -> T { - self.lenu().approx().unwrap() - } - - time_series_getter!( - m_weighted_mean, - get_m_weighted_mean, - |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } - ); - - time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< - T, - >| { - let m_weighed_mean = ts.get_m_weighted_mean(); - let m_reduced_chi2 = Zip::from(&ts.m.sample) - .and(&ts.w.sample) - .fold(T::zero(), |chi2, &m, &w| { - chi2 + (m - m_weighed_mean).powi(2) * w - }) - / (ts.lenf() - T::one()); - if m_reduced_chi2.is_zero() { - ts.plateau = Some(true); - } - m_reduced_chi2 - }); - - time_series_getter!(bool, plateau, is_plateau, |ts: &mut TimeSeries| { - if ts.m.max.is_some() && ts.m.max == ts.m.min { - return true; - } - if ts.m.std2 == Some(T::zero()) { - return true; - } - let m0 = ts.m.sample[0]; - // all() returns true for the empty slice, i.e. 
one-point time series - Zip::from(ts.m.sample.slice(s![1..])).all(|&m| m == m0) - }); - - fn set_t_min_max_m(&mut self) { - let (i_min, i_max) = self - .m - .as_slice() - .iter() - .position_minmax() - .into_option() - .expect("time series must be non-empty"); - self.t_min_m = Some(self.t.sample[i_min]); - self.t_max_m = Some(self.t.sample[i_max]); - } - - pub fn get_t_min_m(&mut self) -> T { - if self.t_min_m.is_none() { - self.set_t_min_max_m(); - } - self.t_min_m.unwrap() - } - - pub fn get_t_max_m(&mut self) -> T { - if self.t_max_m.is_none() { - self.set_t_min_max_m(); - } - self.t_max_m.unwrap() - } -} - -// We really don't want it to be public, it is a private helper for test-util functions -#[cfg(test)] -impl<'a, T, D> From<(D, D, D)> for TimeSeries<'a, T> -where - T: Float, - D: Into>, -{ - fn from(v: (D, D, D)) -> Self { - Self::new(v.0, v.1, v.2) - } -} - -#[cfg(test)] -impl<'a, T> From<&'a (Array1, Array1, Array1)> for TimeSeries<'a, T> -where - T: Float, -{ - fn from(v: &'a (Array1, Array1, Array1)) -> Self { - Self::new(v.0.view(), v.1.view(), v.2.view()) - } -} - -#[cfg(test)] -#[allow(clippy::unreadable_literal)] -#[allow(clippy::excessive_precision)] -mod tests { - use super::*; - - use light_curve_common::all_close; - - macro_rules! data_sample_test { - ($name: ident, $method: ident, $desired: tt, $x: tt $(,)?) 
=> { - #[test] - fn $name() { - let x = $x; - let desired = $desired; - - let mut ds: DataSample<_> = DataSample::from(&x); - all_close(&[ds.$method()], &desired[..], 1e-6); - all_close(&[ds.$method()], &desired[..], 1e-6); - - let mut ds: DataSample<_> = DataSample::from(&x); - ds.get_sorted(); - all_close(&[ds.$method()], &desired[..], 1e-6); - all_close(&[ds.$method()], &desired[..], 1e-6); - } - }; - } - - data_sample_test!( - data_sample_min, - get_min, - [-7.79420906], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_max, - get_max, - [6.73375373], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_mean, - get_mean, - [-0.21613426], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_median_odd, - get_median, - [3.28436964], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - data_sample_test!( - data_sample_median_even, - get_median, - [5.655794743124782], - [9.47981408, 3.86815751, 9.90299294, -2.986894, 7.44343197, 1.52751816], - ); - - data_sample_test!( - data_sample_std, - get_std, - [6.7900544035968435], - [3.92948846, 3.28436964, 6.73375373, -7.79420906, -7.23407407], - ); - - #[test] - fn time_series_m_weighted_mean() { - let t: Vec<_> = (0..5).map(|i| i as f64).collect(); - let m = [ - 12.77883145, - 18.89988406, - 17.55633632, - 18.36073996, - 11.83854198, - ]; - let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 0.10746144]; - let mut ts = TimeSeries::new(&t, &m, &w); - // np.average(m, weights=w) - let desired = [16.31817047752941]; - all_close(&[ts.get_m_weighted_mean()], &desired[..], 1e-6); - } - - #[test] - fn time_series_m_reduced_chi2() { - let t: Vec<_> = (0..5).map(|i| i as f64).collect(); - let m = [ - 12.77883145, - 18.89988406, - 17.55633632, - 18.36073996, - 11.83854198, - ]; - let w = [0.1282489, 0.10576467, 0.32102692, 0.12962352, 
0.10746144]; - let mut ts = TimeSeries::new(&t, &m, &w); - let desired = [1.3752251301435465]; - all_close(&[ts.get_m_reduced_chi2()], &desired[..], 1e-6); - } - - /// https://github.com/light-curve/light-curve-feature/issues/95 - #[test] - fn time_series_std2_overflow() { - const N: usize = (1 << 24) + 2; - // Such a large integer cannot be represented as a float32 - let x = Array1::linspace(0.0_f32, 1.0, N); - let mut ds = DataSample::new(x.into()); - // This should not panic - let _std2 = ds.get_std2(); - } -}