Skip to content

Commit 4f8d081

Browse files
committed
More variants converters for MCTS
1 parent b749f41 commit 4f8d081

File tree

4 files changed

+229
-13
lines changed

4 files changed

+229
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- `m_chi2` attribute and `get_m_chi2` method for `TimeSeries`
13+
- `take_mut` dependency
1314

1415
### Changed
1516

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ num-traits = "^0.2"
5454
paste = "1"
5555
schemars = "^0.8"
5656
serde = { version = "1", features = ["derive"] }
57+
take_mut = "0.2.2"
5758
thiserror = "1"
5859
thread_local = "1.1"
5960
unzip3 = "1"

src/data/multi_color_time_series.rs

Lines changed: 224 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use itertools::Itertools;
1010
use std::collections::{BTreeMap, BTreeSet};
1111
use std::ops::{Deref, DerefMut};
1212

13+
#[derive(Clone, Debug)]
1314
pub enum MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
1415
Mapping(MappedMultiColorTimeSeries<'a, P, T>),
1516
Flat(FlatMultiColorTimeSeries<'a, P, T>),
@@ -40,6 +41,15 @@ where
4041
}
4142
}
4243

44+
pub fn passband_count(&self) -> usize {
45+
match self {
46+
Self::Mapping(mapping) => mapping.passband_count(),
47+
Self::Flat(flat) => flat.passband_count(),
48+
// Both flat and mapping have the same number of passbands and should be equally fast
49+
Self::MappingFlat { flat, .. } => flat.passband_count(),
50+
}
51+
}
52+
4353
pub fn from_map(map: impl Into<BTreeMap<P, TimeSeries<'a, T>>>) -> Self {
4454
Self::Mapping(MappedMultiColorTimeSeries::new(map))
4555
}
@@ -53,21 +63,43 @@ where
5363
Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands))
5464
}
5565

56-
pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> {
66+
fn ensure_mapping(&mut self) -> &mut Self {
5767
if matches!(self, MultiColorTimeSeries::Flat(_)) {
58-
let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new()));
59-
*self = match std::mem::replace(self, dummy_self) {
68+
take_mut::take(self, |slf| match slf {
6069
Self::Flat(mut flat) => {
6170
let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat);
6271
Self::MappingFlat { mapping, flat }
6372
}
64-
_ => unreachable!(),
65-
}
73+
_ => unreachable!("We just checked that we are in ::Flat variant"),
74+
});
6675
}
76+
self
77+
}
78+
79+
fn enforce_mapping(&mut self) -> &mut Self {
6780
match self {
81+
Self::Mapping(_) => {}
82+
Self::Flat(_flat) => take_mut::take(self, |slf| match slf {
83+
Self::Flat(flat) => Self::Mapping(flat.into()),
84+
_ => unreachable!("We just checked that we are in ::Flat variant"),
85+
}),
86+
Self::MappingFlat { .. } => {
87+
take_mut::take(self, |slf| match slf {
88+
Self::MappingFlat { mapping, .. } => Self::Mapping(mapping),
89+
_ => unreachable!("We just checked that we are in ::MappingFlat variant"),
90+
});
91+
}
92+
}
93+
self
94+
}
95+
96+
pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> {
97+
match self.ensure_mapping() {
6898
Self::Mapping(mapping) => mapping,
6999
Self::Flat(_flat) => {
70-
unreachable!("::Flat variant is already transofrmed to ::MappingFlat")
100+
unreachable!(
101+
"::Flat variant is already transformed to ::MappingFlat in ensure_mapping"
102+
)
71103
}
72104
Self::MappingFlat { mapping, .. } => mapping,
73105
}
@@ -81,20 +113,25 @@ where
81113
}
82114
}
83115

84-
pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> {
116+
fn ensure_flat(&mut self) -> &mut Self {
85117
if matches!(self, MultiColorTimeSeries::Mapping(_)) {
86-
let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new()));
87-
*self = match std::mem::replace(self, dummy_self) {
118+
take_mut::take(self, |slf| match slf {
88119
Self::Mapping(mut mapping) => {
89120
let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping);
90121
Self::MappingFlat { mapping, flat }
91122
}
92-
_ => unreachable!(),
93-
}
123+
_ => unreachable!("We just checked that we are in ::Mapping variant"),
124+
});
94125
}
95-
match self {
126+
self
127+
}
128+
129+
pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> {
130+
match self.ensure_flat() {
96131
Self::Mapping(_mapping) => {
97-
unreachable!("::Mapping veriant is already transformed to ::MappingFlat")
132+
unreachable!(
133+
"::Mapping variant is already transformed to ::MappingFlat in ensure_flat"
134+
)
98135
}
99136
Self::Flat(flat) => flat,
100137
Self::MappingFlat { flat, .. } => flat,
@@ -124,12 +161,45 @@ where
124161
Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()),
125162
}
126163
}
164+
165+
/// Inserts new pair of passband and time series into the multicolor time series.
166+
///
167+
/// It always converts [MultiColorTimeSeries] to [MultiColorTimeSeries::Mapping] variant.
168+
/// Also it replaces existing time series if passband is already present, and returns old time
169+
/// series.
170+
pub fn insert(&mut self, passband: P, ts: TimeSeries<'a, T>) -> Option<TimeSeries<'a, T>> {
171+
match self.enforce_mapping() {
172+
Self::Mapping(mapping) => mapping.0.insert(passband, ts),
173+
_ => unreachable!("We just converted self to ::Mapping variant"),
174+
}
175+
}
127176
}
128177

178+
impl<'a, P, T> Default for MultiColorTimeSeries<'a, P, T>
179+
where
180+
P: PassbandTrait,
181+
T: Float,
182+
{
183+
fn default() -> Self {
184+
Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new()))
185+
}
186+
}
187+
188+
#[derive(Debug, Clone)]
129189
pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>(
130190
BTreeMap<P, TimeSeries<'a, T>>,
131191
);
132192

193+
impl<'a, P, T> PartialEq for MappedMultiColorTimeSeries<'a, P, T>
194+
where
195+
P: PassbandTrait,
196+
T: Float,
197+
{
198+
fn eq(&self, other: &Self) -> bool {
199+
self.0.eq(&other.0)
200+
}
201+
}
202+
133203
impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T>
134204
where
135205
P: PassbandTrait + 'p,
@@ -173,6 +243,10 @@ where
173243
self.total_lenu().value_as::<T>().unwrap()
174244
}
175245

246+
pub fn passband_count(&self) -> usize {
247+
self.0.len()
248+
}
249+
176250
pub fn passbands<'slf>(
177251
&'slf self,
178252
) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>>
@@ -267,6 +341,7 @@ impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a,
267341
}
268342
}
269343

344+
#[derive(Debug, Clone)]
270345
pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
271346
pub t: DataSample<'a, T>,
272347
pub m: DataSample<'a, T>,
@@ -275,6 +350,19 @@ pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> {
275350
passband_set: BTreeSet<P>,
276351
}
277352

353+
impl<'a, P, T> PartialEq for FlatMultiColorTimeSeries<'a, P, T>
354+
where
355+
P: PassbandTrait,
356+
T: Float,
357+
{
358+
fn eq(&self, other: &Self) -> bool {
359+
self.t == other.t
360+
&& self.m == other.m
361+
&& self.w == other.w
362+
&& self.passbands == other.passbands
363+
}
364+
}
365+
278366
impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T>
279367
where
280368
P: PassbandTrait,
@@ -347,4 +435,127 @@ where
347435
pub fn total_lenf(&self) -> T {
348436
self.t.sample.len().value_as::<T>().unwrap()
349437
}
438+
439+
pub fn passband_count(&self) -> usize {
440+
self.passband_set.len()
441+
}
442+
}
443+
444+
impl<'a, P, T> From<FlatMultiColorTimeSeries<'a, P, T>> for MappedMultiColorTimeSeries<'a, P, T>
445+
where
446+
P: PassbandTrait,
447+
T: Float,
448+
{
449+
fn from(mut flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self {
450+
Self::from_flat(&mut flat)
451+
}
452+
}
453+
454+
impl<'a, P, T> From<MappedMultiColorTimeSeries<'a, P, T>> for FlatMultiColorTimeSeries<'a, P, T>
455+
where
456+
P: PassbandTrait,
457+
T: Float,
458+
{
459+
fn from(mut mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self {
460+
Self::from_mapping(&mut mapped.0)
461+
}
462+
}
463+
464+
impl<'a, P, T> From<FlatMultiColorTimeSeries<'a, P, T>> for MultiColorTimeSeries<'a, P, T>
465+
where
466+
P: PassbandTrait,
467+
T: Float,
468+
{
469+
fn from(flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self {
470+
Self::Flat(flat)
471+
}
472+
}
473+
474+
impl<'a, P, T> From<MappedMultiColorTimeSeries<'a, P, T>> for MultiColorTimeSeries<'a, P, T>
475+
where
476+
P: PassbandTrait,
477+
T: Float,
478+
{
479+
fn from(mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self {
480+
Self::Mapping(mapped)
481+
}
482+
}
483+
484+
impl<'a, P, T> From<MultiColorTimeSeries<'a, P, T>> for FlatMultiColorTimeSeries<'a, P, T>
485+
where
486+
P: PassbandTrait,
487+
T: Float,
488+
{
489+
fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self {
490+
match mcts {
491+
MultiColorTimeSeries::Flat(flat) => flat,
492+
MultiColorTimeSeries::Mapping(mapped) => mapped.into(),
493+
MultiColorTimeSeries::MappingFlat { flat, .. } => flat,
494+
}
495+
}
496+
}
497+
498+
impl<'a, P, T> From<MultiColorTimeSeries<'a, P, T>> for MappedMultiColorTimeSeries<'a, P, T>
499+
where
500+
P: PassbandTrait,
501+
T: Float,
502+
{
503+
fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self {
504+
match mcts {
505+
MultiColorTimeSeries::Flat(flat) => flat.into(),
506+
MultiColorTimeSeries::Mapping(mapping) => mapping,
507+
MultiColorTimeSeries::MappingFlat { mapping, .. } => mapping,
508+
}
509+
}
510+
}
511+
512+
#[cfg(test)]
513+
mod tests {
514+
use super::*;
515+
516+
use crate::MonochromePassband;
517+
518+
use ndarray::Array1;
519+
520+
#[test]
521+
fn multi_color_ts_insert() {
522+
let mut mcts = MultiColorTimeSeries::default();
523+
mcts.insert(
524+
MonochromePassband::new(4700.0, "g"),
525+
TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)),
526+
);
527+
assert_eq!(mcts.passband_count(), 1);
528+
assert_eq!(mcts.total_lenu(), 11);
529+
mcts.insert(
530+
MonochromePassband::new(6200.0, "r"),
531+
TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)),
532+
);
533+
assert_eq!(mcts.passband_count(), 2);
534+
assert_eq!(mcts.total_lenu(), 17);
535+
}
536+
537+
fn compare_variants<P: PassbandTrait, T: Float>(mcts: MultiColorTimeSeries<P, T>) {
538+
let flat: FlatMultiColorTimeSeries<_, _> = mcts.clone().into();
539+
let mapped: MappedMultiColorTimeSeries<_, _> = mcts.clone().into();
540+
let mapped_from_flat: MappedMultiColorTimeSeries<_, _> = flat.clone().into();
541+
let flat_from_mapped: FlatMultiColorTimeSeries<_, _> = mapped.clone().into();
542+
assert_eq!(mapped, mapped_from_flat);
543+
assert_eq!(flat, flat_from_mapped);
544+
}
545+
546+
#[test]
547+
fn convert_between_variants() {
548+
let mut mcts = MultiColorTimeSeries::default();
549+
compare_variants(mcts.clone());
550+
mcts.insert(
551+
MonochromePassband::new(4700.0, "g"),
552+
TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)),
553+
);
554+
compare_variants(mcts.clone());
555+
mcts.insert(
556+
MonochromePassband::new(6200.0, "r"),
557+
TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)),
558+
);
559+
compare_variants(mcts.clone());
560+
}
350561
}

src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ pub enum MultiColorEvaluatorError {
3939

4040
#[error(r#"Underlying feature caused an error: "{0:?}""#)]
4141
UnderlyingEvaluatorError(#[from] EvaluatorError),
42+
43+
#[error("All time-series are flat")]
44+
AllTimeSeriesAreFlat,
4245
}
4346

4447
impl MultiColorEvaluatorError {

0 commit comments

Comments
 (0)