@@ -22,6 +22,7 @@ use rayon::prelude::*;
22
22
use serde:: { Deserialize , Serialize } ;
23
23
use std:: collections:: HashMap ;
24
24
use std:: convert:: TryInto ;
25
+ use std:: ops:: Deref ;
25
26
// Details of pickle support implementation
26
27
// ----------------------------------------
27
28
// [PyFeatureEvaluator] implements __getstate__ and __setstate__ required for pickle serialisation,
@@ -588,28 +589,6 @@ impl PyFeatureEvaluator {
588
589
self . feature_evaluator_f64 . get_descriptions ( )
589
590
}
590
591
591
- /// Used by pickle.load / pickle.loads
592
- fn __setstate__ ( & mut self , state : Bound < PyBytes > ) -> Res < ( ) > {
593
- * self = serde_pickle:: from_slice ( state. as_bytes ( ) , serde_pickle:: DeOptions :: new ( ) )
594
- . map_err ( |err| {
595
- Exception :: UnpicklingError ( format ! (
596
- r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
597
- ) )
598
- } ) ?;
599
- Ok ( ( ) )
600
- }
601
-
602
- /// Used by pickle.dump / pickle.dumps
603
- fn __getstate__ < ' py > ( & self , py : Python < ' py > ) -> Res < Bound < ' py , PyBytes > > {
604
- let vec_bytes =
605
- serde_pickle:: to_vec ( & self , serde_pickle:: SerOptions :: new ( ) ) . map_err ( |err| {
606
- Exception :: PicklingError ( format ! (
607
- r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
608
- ) )
609
- } ) ?;
610
- Ok ( PyBytes :: new ( py, & vec_bytes) )
611
- }
612
-
613
592
/// Used by copy.copy
614
593
fn __copy__ ( & self ) -> Self {
615
594
self . clone ( )
@@ -621,9 +600,43 @@ impl PyFeatureEvaluator {
621
600
}
622
601
}
623
602
603
+ macro_rules! impl_pickle_serialisation {
604
+ ( $name: ident) => {
605
+ #[ pymethods]
606
+ impl $name {
607
+ /// Used by pickle.load / pickle.loads
608
+ fn __setstate__( mut slf: PyRefMut <' _, Self >, state: Bound <PyBytes >) -> Res <( ) > {
609
+ let ( super_rust, self_rust) : ( PyFeatureEvaluator , Self ) = serde_pickle:: from_slice( state. as_bytes( ) , serde_pickle:: DeOptions :: new( ) )
610
+ . map_err( |err| {
611
+ Exception :: UnpicklingError ( format!(
612
+ r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
613
+ ) )
614
+ } ) ?;
615
+ * slf. as_mut( ) = super_rust;
616
+ * slf = self_rust;
617
+ Ok ( ( ) )
618
+ }
619
+
620
+ /// Used by pickle.dump / pickle.dumps
621
+ fn __getstate__<' py>( slf: PyRef <' py, Self >) -> Res <Bound <' py, PyBytes >> {
622
+ let supr = slf. as_super( ) ;
623
+ let vec_bytes = serde_pickle:: to_vec( & ( supr. deref( ) , slf. deref( ) ) , serde_pickle:: SerOptions :: new( ) ) . map_err( |err| {
624
+ Exception :: PicklingError ( format!(
625
+ r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
626
+ ) )
627
+ } ) ?;
628
+ Ok ( PyBytes :: new( slf. py( ) , & vec_bytes) )
629
+ }
630
+ }
631
+ }
632
+ }
633
+
634
+ #[ derive( Serialize , Deserialize ) ]
624
635
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
625
636
pub struct Extractor { }
626
637
638
+ impl_pickle_serialisation ! ( Extractor ) ;
639
+
627
640
#[ pymethods]
628
641
impl Extractor {
629
642
#[ new]
@@ -702,11 +715,14 @@ macro_rules! impl_stock_transform {
702
715
703
716
macro_rules! evaluator {
704
717
( $name: ident, $eval: ty, $default_transform: expr $( , ) ?) => {
718
+ #[ derive( Serialize , Deserialize ) ]
705
719
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
706
720
pub struct $name { }
707
721
708
722
impl_stock_transform!( $name, $default_transform) ;
709
723
724
+ impl_pickle_serialisation!( $name) ;
725
+
710
726
#[ pymethods]
711
727
impl $name {
712
728
#[ new]
@@ -806,9 +822,12 @@ pub(crate) enum FitLnPrior {
806
822
807
823
macro_rules! fit_evaluator {
808
824
( $name: ident, $eval: ty, $ib: ty, $transform: expr, $nparam: literal, $ln_prior_by_str: tt, $ln_prior_doc: literal $( , ) ?) => {
825
+ #[ derive( Serialize , Deserialize ) ]
809
826
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
810
827
pub struct $name { }
811
828
829
+ impl_pickle_serialisation!( $name) ;
830
+
812
831
impl $name {
813
832
fn supported_algorithms_str( ) -> String {
814
833
return SUPPORTED_ALGORITHMS_CURVE_FIT . join( ", " ) ;
@@ -1051,7 +1070,7 @@ macro_rules! fit_evaluator {
1051
1070
Number of Ceres iterations, default is {niter}
1052
1071
ceres_loss_reg : float, optional
1053
1072
Ceres loss regularization, default is to use square norm as is, if set to
1054
- a number, the loss function is reqgualized to descriminate outlier
1073
+ a number, the loss function is regularized to descriminate outlier
1055
1074
residuals larger than this value.
1056
1075
Default is None which means no regularization.
1057
1076
"# ,
@@ -1158,10 +1177,12 @@ evaluator!(
1158
1177
StockTransformer :: Lg
1159
1178
) ;
1160
1179
1180
+ #[ derive( Serialize , Deserialize ) ]
1161
1181
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1162
1182
pub struct BeyondNStd { }
1163
1183
1164
1184
impl_stock_transform ! ( BeyondNStd , StockTransformer :: Identity ) ;
1185
+ impl_pickle_serialisation ! ( BeyondNStd ) ;
1165
1186
1166
1187
#[ pymethods]
1167
1188
impl BeyondNStd {
@@ -1219,9 +1240,12 @@ fit_evaluator!(
1219
1240
"'no': no prior" ,
1220
1241
) ;
1221
1242
1243
+ #[ derive( Serialize , Deserialize ) ]
1222
1244
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1223
1245
pub struct Bins { }
1224
1246
1247
+ impl_pickle_serialisation ! ( Bins ) ;
1248
+
1225
1249
#[ pymethods]
1226
1250
impl Bins {
1227
1251
#[ new]
@@ -1318,10 +1342,12 @@ evaluator!(
1318
1342
StockTransformer :: Identity
1319
1343
) ;
1320
1344
1345
+ #[ derive( Serialize , Deserialize ) ]
1321
1346
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1322
1347
pub struct InterPercentileRange { }
1323
1348
1324
1349
impl_stock_transform ! ( InterPercentileRange , StockTransformer :: Identity ) ;
1350
+ impl_pickle_serialisation ! ( InterPercentileRange ) ;
1325
1351
1326
1352
#[ pymethods]
1327
1353
impl InterPercentileRange {
@@ -1385,10 +1411,12 @@ fit_evaluator!(
1385
1411
"'no': no prior" ,
1386
1412
) ;
1387
1413
1414
+ #[ derive( Serialize , Deserialize ) ]
1388
1415
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1389
1416
pub struct MagnitudePercentageRatio { }
1390
1417
1391
1418
impl_stock_transform ! ( MagnitudePercentageRatio , StockTransformer :: Identity ) ;
1419
+ impl_pickle_serialisation ! ( MagnitudePercentageRatio ) ;
1392
1420
1393
1421
#[ pymethods]
1394
1422
impl MagnitudePercentageRatio {
@@ -1474,10 +1502,12 @@ evaluator!(
1474
1502
StockTransformer :: Identity
1475
1503
) ;
1476
1504
1505
+ #[ derive( Serialize , Deserialize ) ]
1477
1506
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1478
1507
pub struct MedianBufferRangePercentage { }
1479
1508
1480
1509
impl_stock_transform ! ( MedianBufferRangePercentage , StockTransformer :: Identity ) ;
1510
+ impl_pickle_serialisation ! ( MedianBufferRangePercentage ) ;
1481
1511
1482
1512
#[ pymethods]
1483
1513
impl MedianBufferRangePercentage {
@@ -1526,13 +1556,15 @@ evaluator!(
1526
1556
StockTransformer :: Identity
1527
1557
) ;
1528
1558
1559
+ #[ derive( Serialize , Deserialize ) ]
1529
1560
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1530
1561
pub struct PercentDifferenceMagnitudePercentile { }
1531
1562
1532
1563
impl_stock_transform ! (
1533
1564
PercentDifferenceMagnitudePercentile ,
1534
1565
StockTransformer :: ClippedLg
1535
1566
) ;
1567
+ impl_pickle_serialisation ! ( PercentDifferenceMagnitudePercentile ) ;
1536
1568
1537
1569
#[ pymethods]
1538
1570
impl PercentDifferenceMagnitudePercentile {
@@ -1588,12 +1620,15 @@ enum NyquistArgumentOfPeriodogram {
1588
1620
Float ( f32 ) ,
1589
1621
}
1590
1622
1623
+ #[ derive( Serialize , Deserialize ) ]
1591
1624
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
1592
1625
pub struct Periodogram {
1593
1626
eval_f32 : LcfPeriodogram < f32 > ,
1594
1627
eval_f64 : LcfPeriodogram < f64 > ,
1595
1628
}
1596
1629
1630
+ impl_pickle_serialisation ! ( Periodogram ) ;
1631
+
1597
1632
impl Periodogram {
1598
1633
fn create_evals (
1599
1634
peaks : Option < usize > ,
@@ -2005,9 +2040,12 @@ evaluator!(
2005
2040
StockTransformer :: Identity
2006
2041
) ;
2007
2042
2043
+ #[ derive( Serialize , Deserialize ) ]
2008
2044
#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
2009
2045
pub struct OtsuSplit { }
2010
2046
2047
+ impl_pickle_serialisation ! ( OtsuSplit ) ;
2048
+
2011
2049
#[ pymethods]
2012
2050
impl OtsuSplit {
2013
2051
#[ new]
@@ -2066,9 +2104,12 @@ evaluator!(
2066
2104
) ;
2067
2105
2068
2106
/// Feature evaluator deserialized from JSON string
2107
+ #[ derive( Serialize , Deserialize ) ]
2069
2108
#[ pyclass( name = "JSONDeserializedFeature" , extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
2070
2109
pub struct JsonDeserializedFeature { }
2071
2110
2111
+ impl_pickle_serialisation ! ( JsonDeserializedFeature ) ;
2112
+
2072
2113
#[ pymethods]
2073
2114
impl JsonDeserializedFeature {
2074
2115
#[ new]
0 commit comments