@@ -22,6 +22,7 @@ use rayon::prelude::*;
2222use serde:: { Deserialize , Serialize } ;
2323use std:: collections:: HashMap ;
2424use std:: convert:: TryInto ;
25+ use std:: ops:: Deref ;
2526// Details of pickle support implementation
2627// ----------------------------------------
2728// [PyFeatureEvaluator] implements __getstate__ and __setstate__ required for pickle serialisation,
@@ -588,28 +589,6 @@ impl PyFeatureEvaluator {
588589 self . feature_evaluator_f64 . get_descriptions ( )
589590 }
590591
591- /// Used by pickle.load / pickle.loads
592- fn __setstate__ ( & mut self , state : Bound < PyBytes > ) -> Res < ( ) > {
593- * self = serde_pickle:: from_slice ( state. as_bytes ( ) , serde_pickle:: DeOptions :: new ( ) )
594- . map_err ( |err| {
595- Exception :: UnpicklingError ( format ! (
596- r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
597- ) )
598- } ) ?;
599- Ok ( ( ) )
600- }
601-
602- /// Used by pickle.dump / pickle.dumps
603- fn __getstate__ < ' py > ( & self , py : Python < ' py > ) -> Res < Bound < ' py , PyBytes > > {
604- let vec_bytes =
605- serde_pickle:: to_vec ( & self , serde_pickle:: SerOptions :: new ( ) ) . map_err ( |err| {
606- Exception :: PicklingError ( format ! (
607- r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
608- ) )
609- } ) ?;
610- Ok ( PyBytes :: new ( py, & vec_bytes) )
611- }
612-
613592 /// Used by copy.copy
614593 fn __copy__ ( & self ) -> Self {
615594 self . clone ( )
@@ -621,9 +600,43 @@ impl PyFeatureEvaluator {
621600 }
622601}
623602
603+ macro_rules! impl_pickle_serialisation {
604+ ( $name: ident) => {
605+ #[ pymethods]
606+ impl $name {
607+ /// Used by pickle.load / pickle.loads
608+ fn __setstate__( mut slf: PyRefMut <' _, Self >, state: Bound <PyBytes >) -> Res <( ) > {
609+ let ( super_rust, self_rust) : ( PyFeatureEvaluator , Self ) = serde_pickle:: from_slice( state. as_bytes( ) , serde_pickle:: DeOptions :: new( ) )
610+ . map_err( |err| {
611+ Exception :: UnpicklingError ( format!(
612+ r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
613+ ) )
614+ } ) ?;
615+ * slf. as_mut( ) = super_rust;
616+ * slf = self_rust;
617+ Ok ( ( ) )
618+ }
619+
620+ /// Used by pickle.dump / pickle.dumps
621+ fn __getstate__<' py>( slf: PyRef <' py, Self >) -> Res <Bound <' py, PyBytes >> {
622+ let supr = slf. as_super( ) ;
623+ let vec_bytes = serde_pickle:: to_vec( & ( supr. deref( ) , slf. deref( ) ) , serde_pickle:: SerOptions :: new( ) ) . map_err( |err| {
624+ Exception :: PicklingError ( format!(
625+ r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
626+ ) )
627+ } ) ?;
628+ Ok ( PyBytes :: new( slf. py( ) , & vec_bytes) )
629+ }
630+ }
631+ }
632+ }
633+
634+ #[ derive( Serialize , Deserialize ) ]
624635#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
625636pub struct Extractor { }
626637
638+ impl_pickle_serialisation ! ( Extractor ) ;
639+
627640#[ pymethods]
628641impl Extractor {
629642 #[ new]
@@ -702,11 +715,14 @@ macro_rules! impl_stock_transform {
702715
703716macro_rules! evaluator {
704717 ( $name: ident, $eval: ty, $default_transform: expr $( , ) ?) => {
718+ #[ derive( Serialize , Deserialize ) ]
705719 #[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
706720 pub struct $name { }
707721
708722 impl_stock_transform!( $name, $default_transform) ;
709723
724+ impl_pickle_serialisation!( $name) ;
725+
710726 #[ pymethods]
711727 impl $name {
712728 #[ new]
@@ -806,9 +822,12 @@ pub(crate) enum FitLnPrior {
806822
807823macro_rules! fit_evaluator {
808824 ( $name: ident, $eval: ty, $ib: ty, $transform: expr, $nparam: literal, $ln_prior_by_str: tt, $ln_prior_doc: literal $( , ) ?) => {
825+ #[ derive( Serialize , Deserialize ) ]
809826 #[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
810827 pub struct $name { }
811828
829+ impl_pickle_serialisation!( $name) ;
830+
812831 impl $name {
813832 fn supported_algorithms_str( ) -> String {
814833 return SUPPORTED_ALGORITHMS_CURVE_FIT . join( ", " ) ;
@@ -1158,10 +1177,12 @@ evaluator!(
11581177 StockTransformer :: Lg
11591178) ;
11601179
1180+ #[ derive( Serialize , Deserialize ) ]
11611181#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
11621182pub struct BeyondNStd { }
11631183
11641184impl_stock_transform ! ( BeyondNStd , StockTransformer :: Identity ) ;
1185+ impl_pickle_serialisation ! ( BeyondNStd ) ;
11651186
11661187#[ pymethods]
11671188impl BeyondNStd {
@@ -1219,9 +1240,12 @@ fit_evaluator!(
12191240 "'no': no prior" ,
12201241) ;
12211242
1243+ #[ derive( Serialize , Deserialize ) ]
12221244#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
12231245pub struct Bins { }
12241246
1247+ impl_pickle_serialisation ! ( Bins ) ;
1248+
12251249#[ pymethods]
12261250impl Bins {
12271251 #[ new]
@@ -1318,10 +1342,12 @@ evaluator!(
13181342 StockTransformer :: Identity
13191343) ;
13201344
1345+ #[ derive( Serialize , Deserialize ) ]
13211346#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
13221347pub struct InterPercentileRange { }
13231348
13241349impl_stock_transform ! ( InterPercentileRange , StockTransformer :: Identity ) ;
1350+ impl_pickle_serialisation ! ( InterPercentileRange ) ;
13251351
13261352#[ pymethods]
13271353impl InterPercentileRange {
@@ -1385,10 +1411,12 @@ fit_evaluator!(
13851411 "'no': no prior" ,
13861412) ;
13871413
1414+ #[ derive( Serialize , Deserialize ) ]
13881415#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
13891416pub struct MagnitudePercentageRatio { }
13901417
13911418impl_stock_transform ! ( MagnitudePercentageRatio , StockTransformer :: Identity ) ;
1419+ impl_pickle_serialisation ! ( MagnitudePercentageRatio ) ;
13921420
13931421#[ pymethods]
13941422impl MagnitudePercentageRatio {
@@ -1474,10 +1502,12 @@ evaluator!(
14741502 StockTransformer :: Identity
14751503) ;
14761504
1505+ #[ derive( Serialize , Deserialize ) ]
14771506#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
14781507pub struct MedianBufferRangePercentage { }
14791508
14801509impl_stock_transform ! ( MedianBufferRangePercentage , StockTransformer :: Identity ) ;
1510+ impl_pickle_serialisation ! ( MedianBufferRangePercentage ) ;
14811511
14821512#[ pymethods]
14831513impl MedianBufferRangePercentage {
@@ -1526,13 +1556,15 @@ evaluator!(
15261556 StockTransformer :: Identity
15271557) ;
15281558
1559+ #[ derive( Serialize , Deserialize ) ]
15291560#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
15301561pub struct PercentDifferenceMagnitudePercentile { }
15311562
15321563impl_stock_transform ! (
15331564 PercentDifferenceMagnitudePercentile ,
15341565 StockTransformer :: ClippedLg
15351566) ;
1567+ impl_pickle_serialisation ! ( PercentDifferenceMagnitudePercentile ) ;
15361568
15371569#[ pymethods]
15381570impl PercentDifferenceMagnitudePercentile {
@@ -1588,12 +1620,15 @@ enum NyquistArgumentOfPeriodogram {
15881620 Float ( f32 ) ,
15891621}
15901622
1623+ #[ derive( Serialize , Deserialize ) ]
15911624#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
15921625pub struct Periodogram {
15931626 eval_f32 : LcfPeriodogram < f32 > ,
15941627 eval_f64 : LcfPeriodogram < f64 > ,
15951628}
15961629
1630+ impl_pickle_serialisation ! ( Periodogram ) ;
1631+
15971632impl Periodogram {
15981633 fn create_evals (
15991634 peaks : Option < usize > ,
@@ -2005,9 +2040,12 @@ evaluator!(
20052040 StockTransformer :: Identity
20062041) ;
20072042
2043+ #[ derive( Serialize , Deserialize ) ]
20082044#[ pyclass( extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
20092045pub struct OtsuSplit { }
20102046
2047+ impl_pickle_serialisation ! ( OtsuSplit ) ;
2048+
20112049#[ pymethods]
20122050impl OtsuSplit {
20132051 #[ new]
@@ -2066,9 +2104,12 @@ evaluator!(
20662104) ;
20672105
20682106/// Feature evaluator deserialized from JSON string
2107+ #[ derive( Serialize , Deserialize ) ]
20692108#[ pyclass( name = "JSONDeserializedFeature" , extends = PyFeatureEvaluator , module="light_curve.light_curve_ext" ) ]
20702109pub struct JsonDeserializedFeature { }
20712110
2111+ impl_pickle_serialisation ! ( JsonDeserializedFeature ) ;
2112+
20722113#[ pymethods]
20732114impl JsonDeserializedFeature {
20742115 #[ new]
0 commit comments