Skip to content

Commit c472fa2

Browse files
committed
Add cast=False for some dmdt methods
1 parent 745a034 commit c472fa2

File tree

5 files changed

+58
-16
lines changed

5 files changed

+58
-16
lines changed

light-curve/src/dmdt.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,21 +1142,27 @@ impl DmDt {
11421142
/// Time moments, must be sorted
11431143
/// sorted : bool or None, optional
11441144
/// `True` guarantees that `t` is sorted
1145+
/// cast : bool
1146+
/// If `False` allow np.ndarray input only, `True` allows casting.
1147+
/// Casting provides more flexibility with input types at the cost of
1148+
// performance.
11451149
///
11461150
/// Returns
11471151
/// 1d-array of float
11481152
///
1149-
#[pyo3(signature=(t, *, sorted=None))]
1153+
#[pyo3(signature=(t, *, sorted=None, cast=false))]
11501154
fn count_dt<'py>(
11511155
&self,
11521156
py: Python<'py>,
11531157
t: Bound<'py, PyAny>,
11541158
sorted: Option<bool>,
1159+
cast: bool,
11551160
) -> Res<Bound<'py, PyUntypedArray>> {
11561161
dtype_dispatch!(
11571162
|t| self.dmdt_f32.py_count_dt(py, t, sorted),
11581163
|t| self.dmdt_f64.py_count_dt(py, t, sorted),
1159-
t
1164+
t;
1165+
cast=cast
11601166
)
11611167
}
11621168

@@ -1203,24 +1209,30 @@ impl DmDt {
12031209
/// Magnitudes
12041210
/// sorted : bool or None, optional
12051211
/// `True` guarantees that the light curve is sorted
1212+
/// cast : bool
1213+
/// If `False` allow np.ndarray input only, `True` allows casting.
1214+
/// Casting provides more flexibility with input types at the cost of
1215+
// performance.
12061216
///
12071217
/// Returns
12081218
/// -------
12091219
/// 2d-ndarray of float
12101220
///
1211-
#[pyo3(signature = (t, m, *, sorted=None))]
1221+
#[pyo3(signature = (t, m, *, sorted=None, cast=false))]
12121222
fn points<'py>(
12131223
&self,
12141224
py: Python<'py>,
12151225
t: Bound<'py, PyAny>,
12161226
m: Bound<'py, PyAny>,
12171227
sorted: Option<bool>,
1228+
cast: bool,
12181229
) -> Res<Bound<'py, PyUntypedArray>> {
12191230
dtype_dispatch!(
12201231
|t, m| self.dmdt_f32.py_points(py, t, m, sorted),
12211232
|t, m| self.dmdt_f64.py_points(py, t, m, sorted),
12221233
t,
1223-
=m
1234+
=m;
1235+
cast=cast
12241236
)
12251237
}
12261238

@@ -1372,26 +1384,31 @@ impl DmDt {
13721384
/// Uncertainties
13731385
/// sorted : bool or None, optional
13741386
/// `True` guarantees that the light curve is sorted
1375-
///
1387+
/// cast : bool
1388+
/// If `False` allow np.ndarray input only, `True` allows casting.
1389+
/// Casting provides more flexibility with input types at the cost of
1390+
// performance.
13761391
/// Returns
13771392
/// -------
13781393
/// 2d-array of float
13791394
///
1380-
#[pyo3(signature = (t, m, sigma, *, sorted=None))]
1395+
#[pyo3(signature = (t, m, sigma, *, sorted=None, cast=false))]
13811396
fn gausses<'py>(
13821397
&self,
13831398
py: Python<'py>,
13841399
t: Bound<'py, PyAny>,
13851400
m: Bound<'py, PyAny>,
13861401
sigma: Bound<'py, PyAny>,
13871402
sorted: Option<bool>,
1403+
cast: bool,
13881404
) -> Res<Bound<'py, PyUntypedArray>> {
13891405
dtype_dispatch!(
13901406
|t, m, sigma| self.dmdt_f32.py_gausses(py, t, m, sigma, sorted),
13911407
|t, m, sigma| self.dmdt_f64.py_gausses(py, t, m, sigma, sorted),
13921408
t,
13931409
=m,
1394-
=sigma
1410+
=sigma;
1411+
cast=cast
13951412
)
13961413
}
13971414

light-curve/src/features.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ const METHOD_CALL_DOC: &str = r#"__call__(self, t, m, sigma=None, *, fill_value=
7373
cast : bool, optional
7474
Allows non-numpy input and casting of arrays to a common dtype.
7575
If `False`, inputs must be `np.ndarray` instances with matched dtypes.
76+
Casting provides more flexibility with input types at the cost of
77+
performance.
7678
Returns
7779
-------
7880
ndarray of np.float32 or np.float64

light-curve/src/np_array.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ fn cast_fail_reason<const N: usize>(
7777
let fail_name = names.get(idx).expect("idx is out of bounds of names slice");
7878
let fail_obj = objects
7979
.get(idx)
80-
.expect("idx is out of bounds of names slice");
80+
.expect("idx is out of bounds of objects slice");
8181

8282
let error_message = if let Ok(fail_arr) = fail_obj.downcast::<PyUntypedArray>() {
8383
if fail_arr.ndim() != 1 {
@@ -212,12 +212,13 @@ fn downcast_objects_no_cast<'py, const N: usize>(
212212
})
213213
.ok_or_else(|| {
214214
let valid_f32_count = f32_arrays.iter().filter(|arr| arr.is_some()).count();
215-
cast_fail_reason(
216-
usize::max(valid_f32_count, valid_f64_count),
217-
names,
218-
objects,
219-
false,
220-
)
215+
let max_count = usize::max(valid_f32_count, valid_f64_count);
216+
if max_count == 0 {
217+
unknown_type_exception(names[0], objects[0])
218+
} else {
219+
let idx = max_count - 1;
220+
cast_fail_reason(idx, names, objects, false)
221+
}
221222
})?;
222223
Ok(GenericPyReadonlyArrays::F64(f64_arrays))
223224
}

light-curve/tests/light_curve_ext/test_dmdt.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,14 @@ def test_dmdt_points_dtype(t_dtype, m_dtype):
282282
t = np.linspace(0, 1, 11, dtype=t_dtype)
283283
m = np.asarray(t, dtype=m_dtype)
284284
dmdt = DmDt.from_borders(min_lgdt=0, max_lgdt=1, max_abs_dm=1, lgdt_size=2, dm_size=2, norm=[])
285-
values = dmdt.points(t, m)
285+
286+
if t_dtype is m_dtype:
287+
context = nullcontext()
288+
else:
289+
context = pytest.raises(TypeError)
290+
with context:
291+
dmdt.points(t, m, cast=False)
292+
values = dmdt.points(t, m, cast=True)
286293
assert values.dtype == np.result_type(t, m)
287294

288295

@@ -292,7 +299,14 @@ def test_dmdt_gausses_dtype(t_dtype, m_dtype, sigma_dtype):
292299
m = np.asarray(t, dtype=m_dtype)
293300
sigma = np.asarray(t, dtype=sigma_dtype)
294301
dmdt = DmDt.from_borders(min_lgdt=0, max_lgdt=1, max_abs_dm=1, lgdt_size=2, dm_size=2, norm=[])
295-
values = dmdt.gausses(t, m, sigma)
302+
303+
if t_dtype is m_dtype is sigma_dtype:
304+
context = nullcontext()
305+
else:
306+
context = pytest.raises(TypeError)
307+
with context:
308+
dmdt.gausses(t, m, sigma, cast=False)
309+
values = dmdt.gausses(t, m, sigma, cast=True)
296310
assert values.dtype == np.result_type(t, m, sigma)
297311

298312

light-curve/tests/light_curve_ext/test_feature.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ def test_raises_for_wrong_inputs():
364364
fe(np.array([[1.0, 2.0, 3.0]]), np.array([1.0, 2.0, 3.0]))
365365
with pytest.raises(TypeError, match="'t' has dtype <U1"):
366366
fe(np.array(["a", "b", "c"]), np.array([1.0, 2.0, 3.0]))
367+
with pytest.raises(TypeError, match="'t' has type 'list'"):
368+
fe([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], cast=False)
369+
# No failure of the last test with cast=True
370+
_ = fe([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], cast=True)
367371

368372
# Second and third arguments
369373
t = np.arange(10, dtype=np.float64)
@@ -375,3 +379,7 @@ def test_raises_for_wrong_inputs():
375379
fe(t, t, np.array(1.0))
376380
with pytest.raises(TypeError, match="Mismatched dtypes:"):
377381
fe(t, t.astype(str) + "x")
382+
with pytest.raises(TypeError, match="Mismatched dtypes:"):
383+
fe(t, t.astype(np.float32), cast=False)
384+
# No failure of the last test with cast=True
385+
_ = fe(t, t.astype(np.float32), cast=True)

0 commit comments

Comments
 (0)