Skip to content

Commit c440a5d

Browse files
committed
Flatter MultiColorTimeSeries
1 parent e854863 commit c440a5d

File tree

7 files changed

+180
-90
lines changed

7 files changed

+180
-90
lines changed

src/data/multi_color_time_series.rs

Lines changed: 72 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,71 @@
11
use crate::data::TimeSeries;
22
use crate::float_trait::Float;
33
use crate::multicolor::PassbandTrait;
4-
use crate::PassbandSet;
4+
use crate::{DataSample, PassbandSet};
55

66
use itertools::Either;
77
use itertools::EitherOrBoth;
88
use itertools::Itertools;
99
use std::collections::BTreeMap;
10-
use std::ops::{Deref, DerefMut};
1110

12-
pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(BTreeMap<P, TimeSeries<'a, T>>);
11+
pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
12+
mapping: BTreeMap<P, TimeSeries<'a, T>>,
13+
flat: Option<FlatMultiColorTimeSeries<'static, P, T>>,
14+
}
1315

1416
impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T>
1517
where
1618
P: PassbandTrait + 'p,
1719
T: Float,
1820
{
1921
pub fn new(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
20-
Self(map.into())
22+
Self {
23+
mapping: map.into(),
24+
flat: None,
25+
}
26+
}
27+
28+
pub fn flatten(&mut self) -> &FlatMultiColorTimeSeries<'static, P, T> {
29+
self.flat
30+
.get_or_insert_with(|| FlatMultiColorTimeSeries::from_mapping(&mut self.mapping))
31+
}
32+
33+
pub fn passbands<'slf>(
34+
&'slf self,
35+
) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>>
36+
where
37+
'a: 'slf,
38+
{
39+
self.mapping.keys()
2140
}
2241

2342
pub fn iter_passband_set<'slf, 'ps>(
2443
&'slf self,
2544
passband_set: &'ps PassbandSet<P>,
26-
) -> impl Iterator<Item = (&P, Option<&TimeSeries<'a, T>>)> + 'ps
45+
) -> impl Iterator<Item = (&P, Option<&TimeSeries<'a, T>>)> + 'slf
2746
where
28-
'a: 'ps,
29-
'slf: 'ps,
30-
'ps: 'slf,
47+
'a: 'slf,
48+
'ps: 'a,
3149
{
3250
match passband_set {
33-
PassbandSet::AllAvailable => Either::Left(self.0.iter().map(|(p, ts)| (p, Some(ts)))),
51+
PassbandSet::AllAvailable => {
52+
Either::Left(self.mapping.iter().map(|(p, ts)| (p, Some(ts))))
53+
}
3454
PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())),
3555
}
3656
}
3757

3858
pub fn iter_passband_set_mut<'slf, 'ps>(
3959
&'slf mut self,
4060
passband_set: &'ps PassbandSet<P>,
41-
) -> impl Iterator<Item = (&P, Option<&mut TimeSeries<'a, T>>)> + 'ps
61+
) -> impl Iterator<Item = (&P, Option<&mut TimeSeries<'a, T>>)> + 'slf
4262
where
43-
'a: 'ps,
44-
'slf: 'ps,
45-
'ps: 'slf,
63+
'a: 'slf,
64+
'ps: 'a,
4665
{
4766
match passband_set {
4867
PassbandSet::AllAvailable => {
49-
Either::Left(self.0.iter_mut().map(|(p, ts)| (p, Some(ts))))
68+
Either::Left(self.mapping.iter_mut().map(|(p, ts)| (p, Some(ts))))
5069
}
5170
PassbandSet::FixedSet(set) => {
5271
Either::Right(self.iter_matched_passbands_mut(set.iter()))
@@ -58,15 +77,15 @@ where
5877
&self,
5978
passband_it: impl Iterator<Item = &'p P>,
6079
) -> impl Iterator<Item = (&'p P, Option<&TimeSeries<'a, T>>)> {
61-
passband_it.map(|p| (p, self.0.get(p)))
80+
passband_it.map(|p| (p, self.mapping.get(p)))
6281
}
6382

6483
pub fn iter_matched_passbands_mut(
6584
&mut self,
6685
passband_it: impl Iterator<Item = &'p P>,
6786
) -> impl Iterator<Item = (&'p P, Option<&mut TimeSeries<'a, T>>)> {
6887
passband_it
69-
.merge_join_by(self.0.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2))
88+
.merge_join_by(self.mapping.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2))
7089
.filter_map(|either_or_both| match either_or_both {
7190
// mcts misses required passband
7291
EitherOrBoth::Left(p) => Some((p, None)),
@@ -78,24 +97,48 @@ where
7897
}
7998
}
8099

81-
impl<'a, P: PassbandTrait, T: Float> Deref for MultiColorTimeSeries<'a, P, T> {
82-
type Target = BTreeMap<P, TimeSeries<'a, T>>;
83-
84-
fn deref(&self) -> &Self::Target {
85-
&self.0
100+
impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)>
101+
for MultiColorTimeSeries<'a, P, T>
102+
{
103+
fn from_iter<I: IntoIterator<Item = (P, TimeSeries<'a, T>)>>(iter: I) -> Self {
104+
Self {
105+
mapping: iter.into_iter().collect(),
106+
flat: None,
107+
}
86108
}
87109
}
88110

89-
impl<'a, P: PassbandTrait, T: Float> DerefMut for MultiColorTimeSeries<'a, P, T> {
90-
fn deref_mut(&mut self) -> &mut Self::Target {
91-
&mut self.0
92-
}
111+
pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
112+
pub t: DataSample<'a, T>,
113+
pub m: DataSample<'a, T>,
114+
pub w: DataSample<'a, T>,
115+
pub passbands: Vec<P>,
93116
}
94117

95-
impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)>
96-
for MultiColorTimeSeries<'a, P, T>
118+
impl<P, T> FlatMultiColorTimeSeries<'static, P, T>
119+
where
120+
P: PassbandTrait,
121+
T: Float,
97122
{
98-
fn from_iter<I: IntoIterator<Item = (P, TimeSeries<'a, T>)>>(iter: I) -> Self {
99-
Self(iter.into_iter().collect())
123+
pub fn from_mapping(mapping: &mut BTreeMap<P, TimeSeries<T>>) -> Self {
124+
let (t, m, w, passbands): (Vec<_>, Vec<_>, Vec<_>, _) = mapping
125+
.iter_mut()
126+
.map(|(p, ts)| {
127+
itertools::multizip((
128+
ts.t.as_slice().iter().copied(),
129+
ts.m.as_slice().iter().copied(),
130+
ts.w.as_slice().iter().copied(),
131+
std::iter::repeat(p.clone()),
132+
))
133+
})
134+
.kmerge_by(|(t1, _, _, _), (t2, _, _, _)| t1 <= t2)
135+
.multiunzip();
136+
137+
Self {
138+
t: t.into(),
139+
m: m.into(),
140+
w: w.into(),
141+
passbands,
142+
}
100143
}
101144
}

src/multicolor/features/color_of_maximum.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,14 @@ where
8989
P: PassbandTrait,
9090
T: Float,
9191
{
92-
fn eval_multicolor_no_mcts_check(
93-
&self,
94-
mcts: &mut MultiColorTimeSeries<P, T>,
95-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
92+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
93+
&'slf self,
94+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
95+
) -> Result<Vec<T>, MultiColorEvaluatorError>
96+
where
97+
'slf: 'a,
98+
'a: 'mcts,
99+
{
96100
let mut maxima = [T::zero(); 2];
97101
for ((_passband, mcts), maximum) in mcts
98102
.iter_matched_passbands_mut(self.passbands.iter())

src/multicolor/features/color_of_median.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,14 @@ where
9898
P: PassbandTrait,
9999
T: Float,
100100
{
101-
fn eval_multicolor_no_mcts_check(
102-
&self,
103-
mcts: &mut MultiColorTimeSeries<P, T>,
104-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
101+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
102+
&'slf self,
103+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
104+
) -> Result<Vec<T>, MultiColorEvaluatorError>
105+
where
106+
'slf: 'a,
107+
'a: 'mcts,
108+
{
105109
let mut medians = [T::zero(); 2];
106110
for ((passband, mcts), median) in mcts
107111
.iter_matched_passbands_mut(self.passbands.iter())

src/multicolor/features/color_of_minimum.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,14 @@ where
8989
P: PassbandTrait,
9090
T: Float,
9191
{
92-
fn eval_multicolor_no_mcts_check(
93-
&self,
94-
mcts: &mut MultiColorTimeSeries<P, T>,
95-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
92+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
93+
&'slf self,
94+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
95+
) -> Result<Vec<T>, MultiColorEvaluatorError>
96+
where
97+
'slf: 'a,
98+
'a: 'mcts,
99+
{
96100
let mut minima = [T::zero(); 2];
97101
for ((_passband, mcts), minimum) in mcts
98102
.iter_matched_passbands_mut(self.passbands.iter())

src/multicolor/monochrome_feature.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,26 @@ where
106106
T: Float,
107107
F: FeatureEvaluator<T>,
108108
{
109-
fn eval_multicolor_no_mcts_check(
110-
&self,
111-
mcts: &mut MultiColorTimeSeries<P, T>,
112-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
109+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
110+
&'slf self,
111+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
112+
) -> Result<Vec<T>, MultiColorEvaluatorError>
113+
where
114+
'slf: 'a,
115+
'a: 'mcts,
116+
{
113117
match &self.passband_set {
114-
PassbandSet::FixedSet(set) => set
115-
.iter()
116-
.map(|passband| {
117-
self.feature.eval(mcts.get_mut(passband).expect(
118-
"we checked all needed passbands are in mcts, but we still cannot find one",
119-
)).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError {
120-
passband: passband.name().into(),
121-
error,
122-
})
123-
})
124-
.flatten_ok()
125-
.collect(),
118+
PassbandSet::FixedSet(set) => {
119+
mcts.iter_matched_passbands_mut(set.iter())
120+
.map(|(passband, ts)| {
121+
self.feature.eval_no_ts_check(
122+
ts.expect("we checked all needed passbands are in mcts, but we still cannot find one")
123+
).map_err(|error| MultiColorEvaluatorError::MonochromeEvaluatorError {
124+
passband: passband.name().into(),
125+
error,
126+
})
127+
}).flatten_ok().collect()
128+
}
126129
PassbandSet::AllAvailable => panic!("passband_set must be FixedSet variant here"),
127130
}
128131
}

src/multicolor/multicolor_evaluator.rs

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,22 @@ enum InternalMctsError {
5050
}
5151

5252
impl InternalMctsError {
53-
fn into_multi_color_evaluator_error<P, T>(
53+
fn into_multi_color_evaluator_error<'mcts, 'a, 'ps, P, T>(
5454
self,
55-
mcts: &MultiColorTimeSeries<P, T>,
56-
ps: &PassbandSet<P>,
55+
mcts: &'mcts MultiColorTimeSeries<'a, P, T>,
56+
ps: &'ps PassbandSet<P>,
5757
) -> MultiColorEvaluatorError
5858
where
59+
'ps: 'a,
60+
'a: 'mcts,
5961
P: PassbandTrait,
6062
T: Float,
6163
{
6264
match self {
6365
InternalMctsError::MultiColorEvaluatorError(e) => e,
6466
InternalMctsError::InternalWrongPassbandSet => {
6567
MultiColorEvaluatorError::wrong_passbands_error(
66-
mcts.keys(),
68+
mcts.passbands(),
6769
match ps {
6870
PassbandSet::FixedSet(ps) => ps.iter(),
6971
PassbandSet::AllAvailable => {
@@ -88,37 +90,55 @@ where
8890
T: Float,
8991
{
9092
/// Version of [MultiColorEvaluator::eval_multicolor] without basic [MultiColorTimeSeries] checks
91-
fn eval_multicolor_no_mcts_check(
92-
&self,
93-
mcts: &mut MultiColorTimeSeries<P, T>,
94-
) -> Result<Vec<T>, MultiColorEvaluatorError>;
93+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
94+
&'slf self,
95+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
96+
) -> Result<Vec<T>, MultiColorEvaluatorError>
97+
where
98+
'slf: 'a,
99+
'a: 'mcts;
95100

96101
/// Vector of feature values or `EvaluatorError`
97-
fn eval_multicolor(
98-
&self,
99-
mcts: &mut MultiColorTimeSeries<P, T>,
100-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
102+
fn eval_multicolor<'slf, 'a, 'mcts>(
103+
&'slf self,
104+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
105+
) -> Result<Vec<T>, MultiColorEvaluatorError>
106+
where
107+
'slf: 'a,
108+
'a: 'mcts,
109+
P: 'a,
110+
{
101111
self.check_mcts(mcts)?;
102112
self.eval_multicolor_no_mcts_check(mcts)
103113
}
104114

105115
/// Returns vector of feature values and fill invalid components with given value
106-
fn eval_or_fill_multicolor(
107-
&self,
108-
mcts: &mut MultiColorTimeSeries<P, T>,
116+
fn eval_or_fill_multicolor<'slf, 'a, 'mcts>(
117+
&'slf self,
118+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
109119
fill_value: T,
110-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
120+
) -> Result<Vec<T>, MultiColorEvaluatorError>
121+
where
122+
'slf: 'a,
123+
'a: 'mcts,
124+
P: 'a,
125+
{
111126
Ok(match self.eval_multicolor(mcts) {
112127
Ok(v) => v,
113128
Err(_) => vec![fill_value; self.size_hint()],
114129
})
115130
}
116131

117132
/// Check [MultiColorTimeSeries] to have required passbands and individual [TimeSeries] are valid
118-
fn check_mcts(
119-
&self,
120-
mcts: &mut MultiColorTimeSeries<P, T>,
121-
) -> Result<(), MultiColorEvaluatorError> {
133+
fn check_mcts<'slf, 'a, 'mcts>(
134+
&'slf self,
135+
mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>,
136+
) -> Result<(), MultiColorEvaluatorError>
137+
where
138+
'slf: 'a,
139+
'a: 'mcts,
140+
P: 'a,
141+
{
122142
mcts.iter_passband_set_mut(self.get_passband_set())
123143
.map(|(p, maybe_ts)| {
124144
maybe_ts
@@ -187,10 +207,14 @@ mod tests {
187207
where
188208
T: Float,
189209
{
190-
fn eval_multicolor_no_mcts_check(
191-
&self,
192-
_mcts: &mut MultiColorTimeSeries<MonochromePassband<'static, f64>, T>,
193-
) -> Result<Vec<T>, MultiColorEvaluatorError> {
210+
fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>(
211+
&'slf self,
212+
_mcts: &'mcts mut MultiColorTimeSeries<'a, MonochromePassband<'static, f64>, T>,
213+
) -> Result<Vec<T>, MultiColorEvaluatorError>
214+
where
215+
'slf: 'a,
216+
'a: 'mcts,
217+
{
194218
Ok(vec![T::zero()])
195219
}
196220
}

0 commit comments

Comments
 (0)