Skip to content

Commit 8139dee

Browse files
authored
Merge pull request #512 from light-curve/asarray
Cast inputs to a common dtype
2 parents e08e203 + a627526 commit 8139dee

File tree

8 files changed

+343
-113
lines changed

8 files changed

+343
-113
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Mark the module as no-GIL, which enables free-threaded Python (can be built from source, not provided so far via
1313
PyPI/conda) https://github.com/light-curve/light-curve-python/pull/499
14+
- Allow non-numpy inputs and casting mismatched f32 arrays to f64 for the feature extractions with newly added
15+
`cast: bool = False` argument. We plan to change the default value to `True` in a future 0.x version.
16+
https://github.com/light-curve/light-curve-python/issues/509 https://github.com/light-curve/light-curve-python/pull/512
1417

1518
### Changed
1619

light-curve/Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

light-curve/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ serde = { version = "1", features = ["derive"] }
5959
serde-pickle = "1"
6060
serde_json = "1"
6161
thiserror = "2"
62+
unarray = "0.1.4"
6263
unzip3 = "1.0.0"
6364

6465
[dependencies.light-curve-dmdt]

light-curve/src/dmdt.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,21 +1142,27 @@ impl DmDt {
11421142
/// Time moments, must be sorted
11431143
/// sorted : bool or None, optional
11441144
/// `True` guarantees that `t` is sorted
1145+
/// cast : bool
1146+
/// If `False` allow np.ndarray input only, `True` allows casting.
1147+
/// Casting provides more flexibility with input types at the cost of
1148+
/// performance.
11451149
///
11461150
/// Returns
11471151
/// 1d-array of float
11481152
///
1149-
#[pyo3(signature=(t, *, sorted=None))]
1153+
#[pyo3(signature=(t, *, sorted=None, cast=false))]
11501154
fn count_dt<'py>(
11511155
&self,
11521156
py: Python<'py>,
11531157
t: Bound<'py, PyAny>,
11541158
sorted: Option<bool>,
1159+
cast: bool,
11551160
) -> Res<Bound<'py, PyUntypedArray>> {
11561161
dtype_dispatch!(
11571162
|t| self.dmdt_f32.py_count_dt(py, t, sorted),
11581163
|t| self.dmdt_f64.py_count_dt(py, t, sorted),
1159-
t
1164+
t;
1165+
cast=cast
11601166
)
11611167
}
11621168

@@ -1203,24 +1209,30 @@ impl DmDt {
12031209
/// Magnitudes
12041210
/// sorted : bool or None, optional
12051211
/// `True` guarantees that the light curve is sorted
1212+
/// cast : bool
1213+
/// If `False` allow np.ndarray input only, `True` allows casting.
1214+
/// Casting provides more flexibility with input types at the cost of
1215+
/// performance.
12061216
///
12071217
/// Returns
12081218
/// -------
12091219
/// 2d-ndarray of float
12101220
///
1211-
#[pyo3(signature = (t, m, *, sorted=None))]
1221+
#[pyo3(signature = (t, m, *, sorted=None, cast=false))]
12121222
fn points<'py>(
12131223
&self,
12141224
py: Python<'py>,
12151225
t: Bound<'py, PyAny>,
12161226
m: Bound<'py, PyAny>,
12171227
sorted: Option<bool>,
1228+
cast: bool,
12181229
) -> Res<Bound<'py, PyUntypedArray>> {
12191230
dtype_dispatch!(
12201231
|t, m| self.dmdt_f32.py_points(py, t, m, sorted),
12211232
|t, m| self.dmdt_f64.py_points(py, t, m, sorted),
12221233
t,
1223-
=m
1234+
=m;
1235+
cast=cast
12241236
)
12251237
}
12261238

@@ -1372,26 +1384,31 @@ impl DmDt {
13721384
/// Uncertainties
13731385
/// sorted : bool or None, optional
13741386
/// `True` guarantees that the light curve is sorted
1375-
///
1387+
/// cast : bool
1388+
/// If `False` allow np.ndarray input only, `True` allows casting.
1389+
/// Casting provides more flexibility with input types at the cost of
1390+
/// performance.
13761391
/// Returns
13771392
/// -------
13781393
/// 2d-array of float
13791394
///
1380-
#[pyo3(signature = (t, m, sigma, *, sorted=None))]
1395+
#[pyo3(signature = (t, m, sigma, *, sorted=None, cast=false))]
13811396
fn gausses<'py>(
13821397
&self,
13831398
py: Python<'py>,
13841399
t: Bound<'py, PyAny>,
13851400
m: Bound<'py, PyAny>,
13861401
sigma: Bound<'py, PyAny>,
13871402
sorted: Option<bool>,
1403+
cast: bool,
13881404
) -> Res<Bound<'py, PyUntypedArray>> {
13891405
dtype_dispatch!(
13901406
|t, m, sigma| self.dmdt_f32.py_gausses(py, t, m, sigma, sorted),
13911407
|t, m, sigma| self.dmdt_f64.py_gausses(py, t, m, sigma, sorted),
13921408
t,
13931409
=m,
1394-
=sigma
1410+
=sigma;
1411+
cast=cast
13951412
)
13961413
}
13971414

light-curve/src/features.rs

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,49 +47,54 @@ names : list of str
4747
descriptions : list of str
4848
Feature descriptions"#;
4949

50-
const METHOD_CALL_DOC: &str = r#"__call__(self, t, m, sigma=None, *, sorted=None, check=True, fill_value=None)
50+
const METHOD_CALL_DOC: &str = r#"__call__(self, t, m, sigma=None, *, fill_value=None, sorted=None, check=True, cast=False)
5151
Extract features and return them as a numpy array
5252
5353
Parameters
5454
----------
5555
t : numpy.ndarray of np.float32 or np.float64 dtype
5656
Time moments
57-
m : numpy.ndarray of the same dtype and size as t
57+
m : numpy.ndarray
5858
Signal in magnitude or fluxes. Refer to the feature description to
5959
decide which would work better in your case
60-
sigma : numpy.ndarray of the same dtype and size as t, optional
60+
sigma : numpy.ndarray, optional
6161
Observation error, if None it is assumed to be unity
62+
fill_value : float or None, optional
63+
Value to fill invalid feature values, for example if count of
64+
observations is not enough to find a proper value.
65+
None causes exception for invalid features
6266
sorted : bool or None, optional
6367
Specifies if input arrays are sorted by time moments.
6468
True is for certainly sorted, False is for unsorted.
6569
If None is specified then sorting is checked and an exception is
6670
raised for unsorted `t`
6771
check : bool, optional
6872
Check all input arrays for NaNs, `t` and `m` for infinite values
69-
fill_value : float or None, optional
70-
Value to fill invalid feature values, for example if count of
71-
observations is not enough to find a proper value.
72-
None causes exception for invalid features
73-
73+
cast : bool, optional
74+
Allows non-numpy input and casting of arrays to a common dtype.
75+
If `False`, inputs must be `np.ndarray` instances with matched dtypes.
76+
Casting provides more flexibility with input types at the cost of
77+
performance.
7478
Returns
7579
-------
7680
ndarray of np.float32 or np.float64
7781
Extracted feature array"#;
7882

7983
macro_const! {
8084
const METHOD_MANY_DOC: &str = r#"
81-
many(self, lcs, *, sorted=None, check=True, fill_value=None, n_jobs=-1)
85+
many(self, lcs, *, fill_value=None, sorted=None, check=True, n_jobs=-1)
8286
Parallel light curve feature extraction
8387
8488
It is a parallel executed equivalent of
85-
>>> def many(self, lcs, *, sorted=None, check=True, fill_value=None):
89+
>>> def many(self, lcs, *, fill_value=None, sorted=None, check=True):
8690
... return np.stack(
8791
... [
8892
... self(
8993
... *lc,
94+
... fill_value=fill_value,
9095
... sorted=sorted,
9196
... check=check,
92-
... fill_value=fill_value
97+
... cast=False,
9398
... )
9499
... for lc in lcs
95100
... ]
@@ -101,13 +106,13 @@ many(self, lcs, *, sorted=None, check=True, fill_value=None, n_jobs=-1)
101106
A collection of light curves packed into three-tuples, all light curves
102107
must be represented by numpy.ndarray of the same dtype. See __call__
103108
documentation for details
109+
fill_value : float or None, optional
110+
Fill invalid values by this or raise an exception if None
104111
sorted : bool or None, optional
105112
Specifies if input arrays are sorted by time moments, see __call__
106113
documentation for details
107114
check : bool, optional
108115
Check all input arrays for NaNs, `t` and `m` for infinite values
109-
fill_value : float or None, optional
110-
Fill invalid values by this or raise an exception if None
111116
n_jobs : int
112117
Number of tasks to run in parallel. Default is -1 which means run as
113118
many jobs as CPU count. See rayon rust crate documentation for
@@ -440,19 +445,21 @@ impl PyFeatureEvaluator {
440445
m,
441446
sigma = None,
442447
*,
448+
fill_value = None,
443449
sorted = None,
444450
check = true,
445-
fill_value = None
451+
cast = false,
446452
))]
447453
fn __call__<'py>(
448454
&self,
449455
py: Python<'py>,
450456
t: Bound<'py, PyAny>,
451457
m: Bound<'py, PyAny>,
452458
sigma: Option<Bound<'py, PyAny>>,
459+
fill_value: Option<f64>,
453460
sorted: Option<bool>,
454461
check: bool,
455-
fill_value: Option<f64>,
462+
cast: bool,
456463
) -> Res<Bound<'py, PyUntypedArray>> {
457464
if let Some(sigma) = sigma {
458465
dtype_dispatch!(
@@ -484,7 +491,8 @@ impl PyFeatureEvaluator {
484491
},
485492
t,
486493
=m,
487-
=sigma
494+
=sigma;
495+
cast=cast
488496
)
489497
} else {
490498
dtype_dispatch!(
@@ -515,20 +523,21 @@ impl PyFeatureEvaluator {
515523
)
516524
},
517525
t,
518-
=m,
526+
=m;
527+
cast=cast
519528
)
520529
}
521530
}
522531

523532
#[doc = METHOD_MANY_DOC!()]
524-
#[pyo3(signature = (lcs, *, sorted=None, check=true, fill_value=None, n_jobs=-1))]
533+
#[pyo3(signature = (lcs, *, fill_value=None, sorted=None, check=true, n_jobs=-1))]
525534
fn many<'py>(
526535
&self,
527536
py: Python<'py>,
528537
lcs: PyLcs<'py>,
538+
fill_value: Option<f64>,
529539
sorted: Option<bool>,
530540
check: bool,
531-
fill_value: Option<f64>,
532541
n_jobs: i64,
533542
) -> Res<Bound<'py, PyUntypedArray>> {
534543
if lcs.is_empty() {
@@ -766,7 +775,7 @@ const SUPPORTED_ALGORITHMS_CURVE_FIT: [&str; N_ALGO_CURVE_FIT] = [
766775
];
767776

768777
macro_const! {
769-
const FIT_METHOD_MODEL_DOC: &str = r#"model(t, params)
778+
const FIT_METHOD_MODEL_DOC: &str = r#"model(t, params, *, cast=False)
770779
Underlying parametric model function
771780
772781
Parameters
@@ -777,6 +786,8 @@ macro_const! {
777786
Parameters of the model, this array can be longer than actual parameter
778787
list, the beginning part of the array will be used in this case, see
779788
Examples section in the class documentation.
789+
cast : bool, optional
790+
Cast inputs to np.ndarray of the same dtype
780791
781792
Returns
782793
-------
@@ -1015,14 +1026,16 @@ macro_rules! fit_evaluator {
10151026

10161027
#[doc = FIT_METHOD_MODEL_DOC!()]
10171028
#[staticmethod]
1029+
#[pyo3(signature = (t, params, *, cast=false))]
10181030
fn model<'py>(
10191031
py: Python<'py>,
10201032
t: Bound<'py, PyAny>,
10211033
params: Bound<'py, PyAny>,
1034+
cast: bool
10221035
) -> Res<Bound<'py, PyUntypedArray>> {
10231036
dtype_dispatch!({
10241037
|t, params| Ok(Self::model_impl(t, params).into_pyarray(py).as_untyped().clone())
1025-
}(t, !=params))
1038+
}(t, !=params; cast=cast))
10261039
}
10271040

10281041
#[classattr]
@@ -1706,17 +1719,20 @@ impl Periodogram {
17061719
}
17071720

17081721
/// Angular frequencies and periodogram values
1722+
#[pyo3(signature = (t, m, *, cast=false))]
17091723
fn freq_power<'py>(
17101724
&self,
17111725
py: Python<'py>,
17121726
t: Bound<PyAny>,
17131727
m: Bound<PyAny>,
1728+
cast: bool,
17141729
) -> Res<(Bound<'py, PyUntypedArray>, Bound<'py, PyUntypedArray>)> {
17151730
dtype_dispatch!(
17161731
|t, m| Ok(Self::freq_power_impl(&self.eval_f32, py, t, m)),
17171732
|t, m| Ok(Self::freq_power_impl(&self.eval_f64, py, t, m)),
17181733
t,
1719-
=m
1734+
=m;
1735+
cast=cast
17201736
)
17211737
}
17221738

@@ -1753,7 +1769,7 @@ transform : None, optional
17531769
constructors
17541770
17551771
{common}
1756-
freq_power(t, m)
1772+
freq_power(t, m, *, cast=False)
17571773
Get periodogram
17581774
17591775
Parameters
@@ -1762,6 +1778,8 @@ freq_power(t, m)
17621778
Time array
17631779
m : np.ndarray of np.float32 or np.float64
17641780
Magnitude (flux) array
1781+
cast : bool, optional
1782+
Cast inputs to np.ndarray objects of the same dtype
17651783
17661784
Returns
17671785
-------

0 commit comments

Comments
 (0)