Skip to content

Commit ec3d525

Browse files
committed
Make MultiColorTimeSeries an enum
1 parent 44d9875 commit ec3d525

File tree

6 files changed

+201
-33
lines changed

6 files changed

+201
-33
lines changed

src/data/multi_color_time_series.rs

Lines changed: 191 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,146 @@ use crate::{DataSample, PassbandSet};
66
use itertools::Either;
77
use itertools::EitherOrBoth;
88
use itertools::Itertools;
9-
use std::collections::BTreeMap;
9+
use std::collections::{BTreeMap, BTreeSet};
10+
use std::ops::{Deref, DerefMut};
1011

11-
pub struct MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
12-
mapping: BTreeMap<P, TimeSeries<'a, T>>,
13-
flat: Option<FlatMultiColorTimeSeries<'static, P, T>>,
12+
pub enum MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
13+
Mapping(MappedMultiColorTimeSeries<'a, P, T>),
14+
Flat(FlatMultiColorTimeSeries<'a, P, T>),
15+
MappingFlat {
16+
mapping: MappedMultiColorTimeSeries<'a, P, T>,
17+
flat: FlatMultiColorTimeSeries<'a, P, T>,
18+
},
1419
}
1520

1621
impl<'a, 'p, P, T> MultiColorTimeSeries<'a, P, T>
1722
where
1823
P: PassbandTrait + 'p,
1924
T: Float,
2025
{
21-
pub fn new(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
22-
Self {
23-
mapping: map.into(),
24-
flat: None,
26+
pub fn from_map(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
27+
Self::Mapping(MappedMultiColorTimeSeries::new(map))
28+
}
29+
30+
pub fn from_flat(
31+
t: impl Into<DataSample<'a, T>>,
32+
m: impl Into<DataSample<'a, T>>,
33+
w: impl Into<DataSample<'a, T>>,
34+
passbands: impl Into<Vec<P>>,
35+
) -> Self {
36+
Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands))
37+
}
38+
39+
pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> {
40+
if matches!(self, MultiColorTimeSeries::Flat(_)) {
41+
let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new()));
42+
*self = match std::mem::replace(self, dummy_self) {
43+
Self::Flat(mut flat) => {
44+
let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat);
45+
Self::MappingFlat { mapping, flat }
46+
}
47+
_ => unreachable!(),
48+
}
49+
}
50+
match self {
51+
Self::Mapping(mapping) => mapping,
52+
Self::Flat(_flat) => {
53+
unreachable!("::Flat variant is already transofrmed to ::MappingFlat")
54+
}
55+
Self::MappingFlat { mapping, .. } => mapping,
56+
}
57+
}
58+
59+
pub fn mapping(&self) -> Option<&MappedMultiColorTimeSeries<'a, P, T>> {
60+
match self {
61+
Self::Mapping(mapping) => Some(mapping),
62+
Self::Flat(_flat) => None,
63+
Self::MappingFlat { mapping, .. } => Some(mapping),
64+
}
65+
}
66+
67+
pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> {
68+
if matches!(self, MultiColorTimeSeries::Mapping(_)) {
69+
let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new()));
70+
*self = match std::mem::replace(self, dummy_self) {
71+
Self::Mapping(mut mapping) => {
72+
let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping);
73+
Self::MappingFlat { mapping, flat }
74+
}
75+
_ => unreachable!(),
76+
}
77+
}
78+
match self {
79+
Self::Mapping(_mapping) => {
80+
unreachable!("::Mapping veriant is already transformed to ::MappingFlat")
81+
}
82+
Self::Flat(flat) => flat,
83+
Self::MappingFlat { flat, .. } => flat,
84+
}
85+
}
86+
87+
pub fn flat(&self) -> Option<&FlatMultiColorTimeSeries<'a, P, T>> {
88+
match self {
89+
Self::Mapping(_mapping) => None,
90+
Self::Flat(flat) => Some(flat),
91+
Self::MappingFlat { flat, .. } => Some(flat),
92+
}
93+
}
94+
95+
pub fn passbands<'slf>(
96+
&'slf self,
97+
) -> Either<
98+
std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>>,
99+
std::collections::btree_set::Iter<P>,
100+
>
101+
where
102+
'a: 'slf,
103+
{
104+
match self {
105+
Self::Mapping(mapping) => Either::Left(mapping.passbands()),
106+
Self::Flat(flat) => Either::Right(flat.passband_set.iter()),
107+
Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()),
25108
}
26109
}
110+
}
111+
112+
pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(
113+
BTreeMap<P, TimeSeries<'a, T>>,
114+
);
27115

28-
pub fn flatten(&mut self) -> &FlatMultiColorTimeSeries<'static, P, T> {
29-
self.flat
30-
.get_or_insert_with(|| FlatMultiColorTimeSeries::from_mapping(&mut self.mapping))
116+
impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T>
117+
where
118+
P: PassbandTrait + 'p,
119+
T: Float,
120+
{
121+
pub fn new(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
122+
Self(map.into())
123+
}
124+
125+
pub fn from_flat(flat: &mut FlatMultiColorTimeSeries<P, T>) -> Self {
126+
let mut map = BTreeMap::new();
127+
let groups = itertools::multizip((
128+
flat.t.as_slice().iter(),
129+
flat.m.as_slice().iter(),
130+
flat.w.as_slice().iter(),
131+
flat.passbands.iter(),
132+
))
133+
.group_by(|(_t, _m, _w, p)| (*p).clone());
134+
for (p, group) in &groups {
135+
let (t_vec, m_vec, w_vec) = map
136+
.entry(p.clone())
137+
.or_insert_with(|| (vec![], vec![], vec![]));
138+
for (&t, &m, &w, _p) in group {
139+
t_vec.push(t);
140+
m_vec.push(m);
141+
w_vec.push(w);
142+
}
143+
}
144+
Self(
145+
map.into_iter()
146+
.map(|(p, (t, m, w))| (p, TimeSeries::new(t, m, w)))
147+
.collect(),
148+
)
31149
}
32150

33151
pub fn passbands<'slf>(
@@ -36,7 +154,7 @@ where
36154
where
37155
'a: 'slf,
38156
{
39-
self.mapping.keys()
157+
self.keys()
40158
}
41159

42160
pub fn iter_passband_set<'slf, 'ps>(
@@ -48,9 +166,7 @@ where
48166
'ps: 'a,
49167
{
50168
match passband_set {
51-
PassbandSet::AllAvailable => {
52-
Either::Left(self.mapping.iter().map(|(p, ts)| (p, Some(ts))))
53-
}
169+
PassbandSet::AllAvailable => Either::Left(self.iter().map(|(p, ts)| (p, Some(ts)))),
54170
PassbandSet::FixedSet(set) => Either::Right(self.iter_matched_passbands(set.iter())),
55171
}
56172
}
@@ -64,9 +180,7 @@ where
64180
'ps: 'a,
65181
{
66182
match passband_set {
67-
PassbandSet::AllAvailable => {
68-
Either::Left(self.mapping.iter_mut().map(|(p, ts)| (p, Some(ts))))
69-
}
183+
PassbandSet::AllAvailable => Either::Left(self.iter_mut().map(|(p, ts)| (p, Some(ts)))),
70184
PassbandSet::FixedSet(set) => {
71185
Either::Right(self.iter_matched_passbands_mut(set.iter()))
72186
}
@@ -77,15 +191,15 @@ where
77191
&self,
78192
passband_it: impl Iterator<Item = &'p P>,
79193
) -> impl Iterator<Item = (&'p P, Option<&TimeSeries<'a, T>>)> {
80-
passband_it.map(|p| (p, self.mapping.get(p)))
194+
passband_it.map(|p| (p, self.get(p)))
81195
}
82196

83197
pub fn iter_matched_passbands_mut(
84198
&mut self,
85199
passband_it: impl Iterator<Item = &'p P>,
86200
) -> impl Iterator<Item = (&'p P, Option<&mut TimeSeries<'a, T>>)> {
87201
passband_it
88-
.merge_join_by(self.mapping.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2))
202+
.merge_join_by(self.iter_mut(), |p1, (p2, _ts)| p1.cmp(p2))
89203
.filter_map(|either_or_both| match either_or_both {
90204
// mcts misses required passband
91205
EitherOrBoth::Left(p) => Some((p, None)),
@@ -98,13 +212,24 @@ where
98212
}
99213

100214
impl<'a, P: PassbandTrait, T: Float> FromIterator<(P, TimeSeries<'a, T>)>
101-
for MultiColorTimeSeries<'a, P, T>
215+
for MappedMultiColorTimeSeries<'a, P, T>
102216
{
103217
fn from_iter<I: IntoIterator<Item = (P, TimeSeries<'a, T>)>>(iter: I) -> Self {
104-
Self {
105-
mapping: iter.into_iter().collect(),
106-
flat: None,
107-
}
218+
Self(iter.into_iter().collect())
219+
}
220+
}
221+
222+
impl<'a, P: PassbandTrait, T: Float> Deref for MappedMultiColorTimeSeries<'a, P, T> {
223+
type Target = BTreeMap<P, TimeSeries<'a, T>>;
224+
225+
fn deref(&self) -> &Self::Target {
226+
&self.0
227+
}
228+
}
229+
230+
impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a, P, T> {
231+
fn deref_mut(&mut self) -> &mut Self::Target {
232+
&mut self.0
108233
}
109234
}
110235

@@ -113,13 +238,51 @@ pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
113238
pub m: DataSample<'a, T>,
114239
pub w: DataSample<'a, T>,
115240
pub passbands: Vec<P>,
241+
passband_set: BTreeSet<P>,
116242
}
117243

118-
impl<P, T> FlatMultiColorTimeSeries<'static, P, T>
244+
impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T>
119245
where
120246
P: PassbandTrait,
121247
T: Float,
122248
{
249+
pub fn new(
250+
t: impl Into<DataSample<'a, T>>,
251+
m: impl Into<DataSample<'a, T>>,
252+
w: impl Into<DataSample<'a, T>>,
253+
passbands: impl Into<Vec<P>>,
254+
) -> Self {
255+
let t = t.into();
256+
let m = m.into();
257+
let w = w.into();
258+
let passbands = passbands.into();
259+
let passband_set = passbands.iter().cloned().collect();
260+
261+
assert_eq!(
262+
t.sample.len(),
263+
m.sample.len(),
264+
"t and m should have the same size"
265+
);
266+
assert_eq!(
267+
m.sample.len(),
268+
w.sample.len(),
269+
"m and err should have the same size"
270+
);
271+
assert_eq!(
272+
t.sample.len(),
273+
passbands.len(),
274+
"t and passbands should have the same size"
275+
);
276+
277+
Self {
278+
t,
279+
m,
280+
w,
281+
passbands,
282+
passband_set,
283+
}
284+
}
285+
123286
pub fn from_mapping(mapping: &mut BTreeMap<P, TimeSeries<T>>) -> Self {
124287
let (t, m, w, passbands): (Vec<_>, Vec<_>, Vec<_>, _) = mapping
125288
.iter_mut()
@@ -131,14 +294,15 @@ where
131294
std::iter::repeat(p.clone()),
132295
))
133296
})
134-
.kmerge_by(|(t1, _, _, _), (t2, _, _, _)| t1 <= t2)
297+
.kmerge_by(|(t1, _m1, _w1, _p1), (t2, _m2, _w2, _p2)| t1 <= t2)
135298
.multiunzip();
136299

137300
Self {
138301
t: t.into(),
139302
m: m.into(),
140303
w: w.into(),
141304
passbands,
305+
passband_set: mapping.keys().cloned().collect(),
142306
}
143307
}
144308
}

src/multicolor/features/color_of_maximum.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ where
9999
{
100100
let mut maxima = [T::zero(); 2];
101101
for ((_passband, mcts), maximum) in mcts
102+
.mapping_mut()
102103
.iter_matched_passbands_mut(self.passbands.iter())
103104
.zip(maxima.iter_mut())
104105
{

src/multicolor/features/color_of_median.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ where
108108
{
109109
let mut medians = [T::zero(); 2];
110110
for ((passband, mcts), median) in mcts
111+
.mapping_mut()
111112
.iter_matched_passbands_mut(self.passbands.iter())
112113
.zip(medians.iter_mut())
113114
{

src/multicolor/features/color_of_minimum.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ where
9999
{
100100
let mut minima = [T::zero(); 2];
101101
for ((_passband, mcts), minimum) in mcts
102+
.mapping_mut()
102103
.iter_matched_passbands_mut(self.passbands.iter())
103104
.zip(minima.iter_mut())
104105
{

src/multicolor/monochrome_feature.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ where
116116
{
117117
match &self.passband_set {
118118
PassbandSet::FixedSet(set) => {
119-
mcts.iter_matched_passbands_mut(set.iter())
119+
mcts.mapping_mut().iter_matched_passbands_mut(set.iter())
120120
.map(|(passband, ts)| {
121121
self.feature.eval_no_ts_check(
122122
ts.expect("we checked all needed passbands are in mcts, but we still cannot find one")

src/multicolor/multicolor_evaluator.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ where
139139
'a: 'mcts,
140140
P: 'a,
141141
{
142-
mcts.iter_passband_set_mut(self.get_passband_set())
142+
mcts.mapping_mut()
143+
.iter_passband_set_mut(self.get_passband_set())
143144
.map(|(p, maybe_ts)| {
144145
maybe_ts
145146
.ok_or(InternalMctsError::InternalWrongPassbandSet)
@@ -227,16 +228,16 @@ mod tests {
227228
let passband_v_capital = MonochromePassband::new(5500e-8, "V");
228229
let passband_r_capital = MonochromePassband::new(6400e-8, "R");
229230
let mut mcts = {
230-
let mut passbands = BTreeMap::new();
231-
passbands.insert(
231+
let mut mapping = BTreeMap::new();
232+
mapping.insert(
232233
passband_b_capital.clone(),
233234
TimeSeries::new_without_weight(&t, &m),
234235
);
235-
passbands.insert(
236+
mapping.insert(
236237
passband_v_capital.clone(),
237238
TimeSeries::new_without_weight(&t, &m),
238239
);
239-
MultiColorTimeSeries::new(passbands)
240+
MultiColorTimeSeries::from_map(mapping)
240241
};
241242

242243
let feature = TestTimeMultiColorFeature {

0 commit comments

Comments
 (0)