Skip to content

Commit a190d64

Browse files
authored
Merge pull request #532 from light-curve/periodogram-power-pickle
Fix Periodogram pickling
2 parents 2a34273 + 12546c7 commit a190d64

File tree

3 files changed

+164
-29
lines changed

3 files changed

+164
-29
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525

2626
### Fixed
2727

28-
--
28+
- A problem with pickling of `Periodogram` which caused wrong results from `.power` and `.freq_power` for a deserialized
29+
object https://github.com/light-curve/light-curve-python/pull/532
2930

3031
### Security
3132

light-curve/src/features.rs

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use rayon::prelude::*;
2222
use serde::{Deserialize, Serialize};
2323
use std::collections::HashMap;
2424
use std::convert::TryInto;
25+
use std::ops::Deref;
2526
// Details of pickle support implementation
2627
// ----------------------------------------
2728
// [PyFeatureEvaluator] implements __getstate__ and __setstate__ required for pickle serialisation,
@@ -588,28 +589,6 @@ impl PyFeatureEvaluator {
588589
self.feature_evaluator_f64.get_descriptions()
589590
}
590591

591-
/// Used by pickle.load / pickle.loads
592-
fn __setstate__(&mut self, state: Bound<PyBytes>) -> Res<()> {
593-
*self = serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
594-
.map_err(|err| {
595-
Exception::UnpicklingError(format!(
596-
r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
597-
))
598-
})?;
599-
Ok(())
600-
}
601-
602-
/// Used by pickle.dump / pickle.dumps
603-
fn __getstate__<'py>(&self, py: Python<'py>) -> Res<Bound<'py, PyBytes>> {
604-
let vec_bytes =
605-
serde_pickle::to_vec(&self, serde_pickle::SerOptions::new()).map_err(|err| {
606-
Exception::PicklingError(format!(
607-
r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
608-
))
609-
})?;
610-
Ok(PyBytes::new(py, &vec_bytes))
611-
}
612-
613592
/// Used by copy.copy
614593
fn __copy__(&self) -> Self {
615594
self.clone()
@@ -621,9 +600,43 @@ impl PyFeatureEvaluator {
621600
}
622601
}
623602

603+
macro_rules! impl_pickle_serialisation {
604+
($name: ident) => {
605+
#[pymethods]
606+
impl $name {
607+
/// Used by pickle.load / pickle.loads
608+
fn __setstate__(mut slf: PyRefMut<'_, Self>, state: Bound<PyBytes>) -> Res<()> {
609+
let (super_rust, self_rust): (PyFeatureEvaluator, Self) = serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
610+
.map_err(|err| {
611+
Exception::UnpicklingError(format!(
612+
r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
613+
))
614+
})?;
615+
*slf.as_mut() = super_rust;
616+
*slf = self_rust;
617+
Ok(())
618+
}
619+
620+
/// Used by pickle.dump / pickle.dumps
621+
fn __getstate__<'py>(slf: PyRef<'py, Self>) -> Res<Bound<'py, PyBytes>> {
622+
let supr = slf.as_super();
623+
let vec_bytes = serde_pickle::to_vec(&(supr.deref(), slf.deref()), serde_pickle::SerOptions::new()).map_err(|err| {
624+
Exception::PicklingError(format!(
625+
r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
626+
))
627+
})?;
628+
Ok(PyBytes::new(slf.py(), &vec_bytes))
629+
}
630+
}
631+
}
632+
}
633+
634+
#[derive(Serialize, Deserialize)]
624635
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
625636
pub struct Extractor {}
626637

638+
impl_pickle_serialisation!(Extractor);
639+
627640
#[pymethods]
628641
impl Extractor {
629642
#[new]
@@ -702,11 +715,14 @@ macro_rules! impl_stock_transform {
702715

703716
macro_rules! evaluator {
704717
($name: ident, $eval: ty, $default_transform: expr $(,)?) => {
718+
#[derive(Serialize, Deserialize)]
705719
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
706720
pub struct $name {}
707721

708722
impl_stock_transform!($name, $default_transform);
709723

724+
impl_pickle_serialisation!($name);
725+
710726
#[pymethods]
711727
impl $name {
712728
#[new]
@@ -806,9 +822,12 @@ pub(crate) enum FitLnPrior {
806822

807823
macro_rules! fit_evaluator {
808824
($name: ident, $eval: ty, $ib: ty, $transform: expr, $nparam: literal, $ln_prior_by_str: tt, $ln_prior_doc: literal $(,)?) => {
825+
#[derive(Serialize, Deserialize)]
809826
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
810827
pub struct $name {}
811828

829+
impl_pickle_serialisation!($name);
830+
812831
impl $name {
813832
fn supported_algorithms_str() -> String {
814833
return SUPPORTED_ALGORITHMS_CURVE_FIT.join(", ");
@@ -1051,7 +1070,7 @@ macro_rules! fit_evaluator {
10511070
Number of Ceres iterations, default is {niter}
10521071
ceres_loss_reg : float, optional
10531072
Ceres loss regularization, default is to use square norm as is, if set to
1054-
a number, the loss function is reqgualized to descriminate outlier
1073+
a number, the loss function is regularized to descriminate outlier
10551074
residuals larger than this value.
10561075
Default is None which means no regularization.
10571076
"#,
@@ -1158,10 +1177,12 @@ evaluator!(
11581177
StockTransformer::Lg
11591178
);
11601179

1180+
#[derive(Serialize, Deserialize)]
11611181
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
11621182
pub struct BeyondNStd {}
11631183

11641184
impl_stock_transform!(BeyondNStd, StockTransformer::Identity);
1185+
impl_pickle_serialisation!(BeyondNStd);
11651186

11661187
#[pymethods]
11671188
impl BeyondNStd {
@@ -1219,9 +1240,12 @@ fit_evaluator!(
12191240
"'no': no prior",
12201241
);
12211242

1243+
#[derive(Serialize, Deserialize)]
12221244
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
12231245
pub struct Bins {}
12241246

1247+
impl_pickle_serialisation!(Bins);
1248+
12251249
#[pymethods]
12261250
impl Bins {
12271251
#[new]
@@ -1318,10 +1342,12 @@ evaluator!(
13181342
StockTransformer::Identity
13191343
);
13201344

1345+
#[derive(Serialize, Deserialize)]
13211346
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
13221347
pub struct InterPercentileRange {}
13231348

13241349
impl_stock_transform!(InterPercentileRange, StockTransformer::Identity);
1350+
impl_pickle_serialisation!(InterPercentileRange);
13251351

13261352
#[pymethods]
13271353
impl InterPercentileRange {
@@ -1385,10 +1411,12 @@ fit_evaluator!(
13851411
"'no': no prior",
13861412
);
13871413

1414+
#[derive(Serialize, Deserialize)]
13881415
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
13891416
pub struct MagnitudePercentageRatio {}
13901417

13911418
impl_stock_transform!(MagnitudePercentageRatio, StockTransformer::Identity);
1419+
impl_pickle_serialisation!(MagnitudePercentageRatio);
13921420

13931421
#[pymethods]
13941422
impl MagnitudePercentageRatio {
@@ -1474,10 +1502,12 @@ evaluator!(
14741502
StockTransformer::Identity
14751503
);
14761504

1505+
#[derive(Serialize, Deserialize)]
14771506
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
14781507
pub struct MedianBufferRangePercentage {}
14791508

14801509
impl_stock_transform!(MedianBufferRangePercentage, StockTransformer::Identity);
1510+
impl_pickle_serialisation!(MedianBufferRangePercentage);
14811511

14821512
#[pymethods]
14831513
impl MedianBufferRangePercentage {
@@ -1526,13 +1556,15 @@ evaluator!(
15261556
StockTransformer::Identity
15271557
);
15281558

1559+
#[derive(Serialize, Deserialize)]
15291560
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
15301561
pub struct PercentDifferenceMagnitudePercentile {}
15311562

15321563
impl_stock_transform!(
15331564
PercentDifferenceMagnitudePercentile,
15341565
StockTransformer::ClippedLg
15351566
);
1567+
impl_pickle_serialisation!(PercentDifferenceMagnitudePercentile);
15361568

15371569
#[pymethods]
15381570
impl PercentDifferenceMagnitudePercentile {
@@ -1588,12 +1620,15 @@ enum NyquistArgumentOfPeriodogram {
15881620
Float(f32),
15891621
}
15901622

1623+
#[derive(Serialize, Deserialize)]
15911624
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
15921625
pub struct Periodogram {
15931626
eval_f32: LcfPeriodogram<f32>,
15941627
eval_f64: LcfPeriodogram<f64>,
15951628
}
15961629

1630+
impl_pickle_serialisation!(Periodogram);
1631+
15971632
impl Periodogram {
15981633
fn create_evals(
15991634
peaks: Option<usize>,
@@ -2005,9 +2040,12 @@ evaluator!(
20052040
StockTransformer::Identity
20062041
);
20072042

2043+
#[derive(Serialize, Deserialize)]
20082044
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
20092045
pub struct OtsuSplit {}
20102046

2047+
impl_pickle_serialisation!(OtsuSplit);
2048+
20112049
#[pymethods]
20122050
impl OtsuSplit {
20132051
#[new]
@@ -2066,9 +2104,12 @@ evaluator!(
20662104
);
20672105

20682106
/// Feature evaluator deserialized from JSON string
2107+
#[derive(Serialize, Deserialize)]
20692108
#[pyclass(name = "JSONDeserializedFeature", extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
20702109
pub struct JsonDeserializedFeature {}
20712110

2111+
impl_pickle_serialisation!(JsonDeserializedFeature);
2112+
20722113
#[pymethods]
20732114
impl JsonDeserializedFeature {
20742115
#[new]

0 commit comments

Comments
 (0)