Skip to content

Commit a627526

Browse files
committed
Support cast with freq_power and model
1 parent c472fa2 commit a627526

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

light-curve/src/features.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ const SUPPORTED_ALGORITHMS_CURVE_FIT: [&str; N_ALGO_CURVE_FIT] = [
775775
];
776776

777777
macro_const! {
778-
const FIT_METHOD_MODEL_DOC: &str = r#"model(t, params)
778+
const FIT_METHOD_MODEL_DOC: &str = r#"model(t, params, *, cast=False)
779779
Underlying parametric model function
780780
781781
Parameters
@@ -786,6 +786,8 @@ macro_const! {
786786
Parameters of the model, this array can be longer than actual parameter
787787
list, the beginning part of the array will be used in this case, see
788788
Examples section in the class documentation.
789+
cast : bool, optional
790+
Cast inputs to np.ndarray of the same dtype
789791
790792
Returns
791793
-------
@@ -1024,14 +1026,16 @@ macro_rules! fit_evaluator {
10241026

10251027
#[doc = FIT_METHOD_MODEL_DOC!()]
10261028
#[staticmethod]
1029+
#[pyo3(signature = (t, params, *, cast=false))]
10271030
fn model<'py>(
10281031
py: Python<'py>,
10291032
t: Bound<'py, PyAny>,
10301033
params: Bound<'py, PyAny>,
1034+
cast: bool
10311035
) -> Res<Bound<'py, PyUntypedArray>> {
10321036
dtype_dispatch!({
10331037
|t, params| Ok(Self::model_impl(t, params).into_pyarray(py).as_untyped().clone())
1034-
}(t, !=params))
1038+
}(t, !=params; cast=cast))
10351039
}
10361040

10371041
#[classattr]
@@ -1715,17 +1719,20 @@ impl Periodogram {
17151719
}
17161720

17171721
/// Angular frequencies and periodogram values
1722+
#[pyo3(signature = (t, m, *, cast=false))]
17181723
fn freq_power<'py>(
17191724
&self,
17201725
py: Python<'py>,
17211726
t: Bound<PyAny>,
17221727
m: Bound<PyAny>,
1728+
cast: bool,
17231729
) -> Res<(Bound<'py, PyUntypedArray>, Bound<'py, PyUntypedArray>)> {
17241730
dtype_dispatch!(
17251731
|t, m| Ok(Self::freq_power_impl(&self.eval_f32, py, t, m)),
17261732
|t, m| Ok(Self::freq_power_impl(&self.eval_f64, py, t, m)),
17271733
t,
1728-
=m
1734+
=m;
1735+
cast=cast
17291736
)
17301737
}
17311738

@@ -1762,7 +1769,7 @@ transform : None, optional
17621769
constructors
17631770
17641771
{common}
1765-
freq_power(t, m)
1772+
freq_power(t, m, *, cast=False)
17661773
Get periodogram
17671774
17681775
Parameters
@@ -1771,6 +1778,8 @@ freq_power(t, m)
17711778
Time array
17721779
m : np.ndarray of np.float32 or np.float64
17731780
Magnitude (flux) array
1781+
cast : bool, optional
1782+
Cast inputs to np.ndarray objects of the same dtype
17741783
17751784
Returns
17761785
-------

light-curve/src/np_array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ macro_rules! dtype_dispatch {
288288
($func:tt ($first_arg:expr $(,$eq:tt $arg:expr)* $(,)?)) => {
289289
dtype_dispatch!($func, $func, $first_arg $(,$eq $arg)*)
290290
};
291-
($func:tt ($first_arg:expr $(,$eq:tt $arg:expr)*, cast=$cast:expr $(,)?)) => {
291+
($func:tt ($first_arg:expr $(,$eq:tt $arg:expr)*; cast=$cast:expr $(,)?)) => {
292292
dtype_dispatch!($func, $func, $first_arg $(,$eq $arg)*; cast=$cast)
293293
};
294294
($f32:expr, $f64:expr, $first_arg:expr $(,$eq:tt $arg:expr)* $(,)?) => {

0 commit comments

Comments
 (0)