Skip to content

Commit 50027d2

Browse files
committed
Refactor mcts checking
1 parent 1f08fff commit 50027d2

File tree

8 files changed

+274
-88
lines changed

8 files changed

+274
-88
lines changed

src/data/multi_color_time_series.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,69 @@
11
use crate::data::TimeSeries;
22
use crate::float_trait::Float;
33
use crate::multicolor::PassbandTrait;
4+
use crate::PassbandSet;
45

6+
use itertools::Either;
7+
use itertools::EitherOrBoth;
8+
use itertools::Itertools;
59
use std::collections::BTreeMap;
610
use std::ops::{Deref, DerefMut};
711

812
pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap<P, TimeSeries<'a, T>>);
913

14+
impl<'a, P, T> MultiColorTimeSeries<'a, P, T>
15+
where
16+
P: PassbandTrait,
17+
T: Float,
18+
{
19+
pub fn new(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
20+
Self(map.into())
21+
}
22+
23+
pub fn iter_passband_set<'slf, 'ps>(
24+
&'slf self,
25+
passband_set: &'ps PassbandSet<P>,
26+
) -> impl Iterator<Item = (&P, Option<&TimeSeries<'a, T>>)> + 'ps
27+
where
28+
'a: 'ps,
29+
'slf: 'ps,
30+
'ps: 'slf,
31+
{
32+
match passband_set {
33+
PassbandSet::AllAvailable => Either::Left(self.0.iter().map(|(p, ts)| (p, Some(ts)))),
34+
PassbandSet::FixedSet(set) => Either::Right(set.iter().map(|p| (p, self.0.get(p)))),
35+
}
36+
}
37+
38+
pub fn iter_passband_set_mut<'slf, 'ps>(
39+
&'slf mut self,
40+
passband_set: &'ps PassbandSet<P>,
41+
) -> impl Iterator<Item = (&P, Option<&mut TimeSeries<'a, T>>)> + 'ps
42+
where
43+
'a: 'ps,
44+
'slf: 'ps,
45+
'ps: 'slf,
46+
{
47+
match passband_set {
48+
PassbandSet::AllAvailable => {
49+
Either::Left(self.0.iter_mut().map(|(p, ts)| (p, Some(ts))))
50+
}
51+
PassbandSet::FixedSet(set) => Either::Right(
52+
set.iter()
53+
.merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2))
54+
.filter_map(|either_or_both| match either_or_both {
55+
// mcts misses required passband
56+
EitherOrBoth::Left(p) => Some((p, None)),
57+
// mcts has some passban passband_set doesn't require
58+
EitherOrBoth::Right(_) => None,
59+
// passbands match
60+
EitherOrBoth::Both(p, (_, ts)) => Some((p, Some(ts))),
61+
}),
62+
),
63+
}
64+
}
65+
}
66+
1067
impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> {
1168
type Target = BTreeMap<P, TimeSeries<'a, T>>;
1269

src/evaluator.rs

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ pub trait EvaluatorInfoTrait {
7575
fn is_variability_required(&self) -> bool {
7676
self.get_info().variability_required
7777
}
78+
79+
fn check_ts<F>(&self, ts: &mut TimeSeries<F>) -> Result<(), EvaluatorError>
80+
where
81+
F: Float,
82+
{
83+
self.check_ts_length(ts)?;
84+
self.check_ts_variability(ts)
85+
}
86+
87+
/// Checks if [TimeSeries] has enough points to evaluate the feature
88+
fn check_ts_length<F>(&self, ts: &TimeSeries<F>) -> Result<(), EvaluatorError>
89+
where
90+
F: Float,
91+
{
92+
let length = ts.lenu();
93+
if length < self.min_ts_length() {
94+
Err(EvaluatorError::ShortTimeSeries {
95+
actual: length,
96+
minimum: self.min_ts_length(),
97+
})
98+
} else {
99+
Ok(())
100+
}
101+
}
102+
103+
/// Checks if [TimeSeries] meets variability requirement
104+
fn check_ts_variability<F>(&self, ts: &mut TimeSeries<F>) -> Result<(), EvaluatorError>
105+
where
106+
F: Float,
107+
{
108+
if self.is_variability_required() && ts.is_plateau() {
109+
Err(EvaluatorError::FlatTimeSeries)
110+
} else {
111+
Ok(())
112+
}
113+
}
78114
}
79115

80116
// impl<P> EvaluatorInfoTrait for P
@@ -146,33 +182,6 @@ pub trait FeatureEvaluator<T: Float>:
146182
Err(_) => vec![fill_value; self.size_hint()],
147183
}
148184
}
149-
150-
fn check_ts(&self, ts: &mut TimeSeries<T>) -> Result<(), EvaluatorError> {
151-
self.check_ts_length(ts)?;
152-
self.check_ts_variability(ts)
153-
}
154-
155-
/// Checks if [TimeSeries] has enough points to evaluate the feature
156-
fn check_ts_length(&self, ts: &TimeSeries<T>) -> Result<(), EvaluatorError> {
157-
let length = ts.lenu();
158-
if length < self.min_ts_length() {
159-
Err(EvaluatorError::ShortTimeSeries {
160-
actual: length,
161-
minimum: self.min_ts_length(),
162-
})
163-
} else {
164-
Ok(())
165-
}
166-
}
167-
168-
/// Checks if [TimeSeries] meets variability requirement
169-
fn check_ts_variability(&self, ts: &mut TimeSeries<T>) -> Result<(), EvaluatorError> {
170-
if self.is_variability_required() && ts.is_plateau() {
171-
Err(EvaluatorError::FlatTimeSeries)
172-
} else {
173-
Ok(())
174-
}
175-
}
176185
}
177186

178187
pub trait OwnedArrays<T>

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![doc = include_str!("../README.md")]
22

3+
extern crate core;
4+
35
#[cfg(test)]
46
#[macro_use]
57
mod tests;

src/multicolor/features/color_of_median.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ where
5252
}
5353

5454
lazy_info!(
55-
COLOR_MEDIAN_INFO,
55+
COLOR_OF_MEDIAN_INFO,
5656
size: 1,
5757
min_ts_length: 1,
5858
t_required: false,
@@ -67,7 +67,7 @@ where
6767
P: Ord,
6868
{
6969
fn get_info(&self) -> &EvaluatorInfo {
70-
&COLOR_MEDIAN_INFO
70+
&COLOR_OF_MEDIAN_INFO
7171
}
7272
}
7373

@@ -98,11 +98,10 @@ where
9898
P: PassbandTrait,
9999
T: Float,
100100
{
101-
fn eval_multicolor(
101+
fn eval_multicolor_no_mcts_check(
102102
&self,
103103
mcts: &mut MultiColorTimeSeries<P, T>,
104104
) -> Result<Vec<T>, MultiColorEvaluatorError> {
105-
self.check_mcts_passabands(mcts)?;
106105
let mut medians = [T::zero(); 2];
107106
for (median, passband) in medians.iter_mut().zip(self.passbands.iter()) {
108107
*median = self

src/multicolor/monochrome_feature.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,10 @@ where
106106
T: Float,
107107
F: FeatureEvaluator<T>,
108108
{
109-
fn eval_multicolor(
109+
fn eval_multicolor_no_mcts_check(
110110
&self,
111111
mcts: &mut MultiColorTimeSeries<P, T>,
112112
) -> Result<Vec<T>, MultiColorEvaluatorError> {
113-
self.check_mcts_passabands(mcts)?;
114113
match &self.passband_set {
115114
PassbandSet::FixedSet(set) => set
116115
.iter()

0 commit comments

Comments
 (0)