Skip to content

Commit c9b1728

Browse files
authored
Merge pull request #510 from light-curve/fix-casting-exceptions
Fix error messages for invalid inputs
2 parents 6d804da + 286d2c3 commit c9b1728

File tree

2 files changed

+73
-32
lines changed

2 files changed

+73
-32
lines changed

light-curve/src/np_array.rs

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::errors::{Exception, Res};
22

33
use numpy::prelude::*;
4-
use numpy::{Element, PyArray1, PyReadonlyArray1};
4+
use numpy::{Element, PyArray1, PyReadonlyArray1, PyUntypedArray, PyUntypedArrayMethods};
55
use pyo3::prelude::*;
66

77
pub(crate) type Arr<'a, T> = PyReadonlyArray1<'a, T>;
@@ -22,6 +22,30 @@ impl DType for f64 {
2222
}
2323
}
2424

25+
pub(crate) fn unknown_type_exception(name: &str, obj: Bound<PyAny>) -> Exception {
26+
let message = if let Ok(arr) = obj.downcast::<PyUntypedArray>() {
27+
let ndim = arr.ndim();
28+
if ndim != 1 {
29+
format!("'{name}' is a {ndim}-d array, only 1-d arrays are supported.")
30+
} else {
31+
let dtype = match arr.dtype().str() {
32+
Ok(s) => s,
33+
Err(err) => return err.into(),
34+
};
35+
format!("'{name}' has dtype {dtype}, but only float32 and float64 are supported.")
36+
}
37+
} else {
38+
let tp = match obj.get_type().name() {
39+
Ok(s) => s,
40+
Err(err) => return err.into(),
41+
};
42+
format!(
43+
"'{name}' has type '{tp}', float32 or float64 1-d numpy array was supported. Try to cast with np.asarray."
44+
)
45+
};
46+
Exception::TypeError(message)
47+
}
48+
2549
pub(crate) fn extract_matched_array<'py, T>(
2650
y_name: &'static str,
2751
y: Bound<'py, PyAny>,
@@ -39,43 +63,37 @@ where
3963
Ok(y)
4064
} else {
4165
Err(Exception::ValueError(format!(
42-
"Mismatched length ({}: {}, {}: {})",
43-
y_name,
44-
y.len(),
66+
"Mismatched lengths: '{}': {}, '{}': {}",
4567
x_name,
4668
x.len(),
69+
y_name,
70+
y.len(),
4771
)))
4872
}
4973
} else {
5074
Ok(y)
5175
}
5276
} else {
53-
let y_type = y
54-
.get_type()
55-
.name()
56-
.map(|name| {
57-
if name == "ndarray" {
58-
format!(
59-
"ndarray[{}]",
60-
y.getattr("dtype")
61-
.map(|dtype| dtype
62-
.getattr("name")
63-
.map(|p| p.to_string())
64-
.unwrap_or("unknown".into()))
65-
.unwrap_or("unknown".into())
66-
)
67-
} else {
68-
name.to_string()
69-
}
70-
})
71-
.unwrap_or("unknown".into());
72-
Err(Exception::TypeError(format!(
73-
"Mismatched types ({}: np.ndarray[{}], {}: {})",
74-
x_name,
75-
T::dtype_name(),
76-
y_name,
77-
y_type
78-
)))
77+
let error_message = if let Ok(y_arr) = y.downcast::<PyUntypedArray>() {
78+
if y_arr.ndim() != 1 {
79+
format!(
80+
"'{}' is a {}-d array, only 1-d arrays are supported.",
81+
y_name,
82+
y_arr.ndim()
83+
)
84+
} else {
85+
format!(
86+
"Mismatched dtypes: '{}': {}, '{}': {}",
87+
x_name,
88+
x.dtype().str()?,
89+
y_name,
90+
y_arr.dtype().str()?
91+
)
92+
}
93+
} else {
94+
format!("'{y_name}' must be a numpy array of the same shape and dtype as '{x_name}', '{x_name}' has type 'np.ndarray[{x_dtype}]', '{y_name}' has type '{y_type}')", y_type=y.get_type().name()?, x_dtype=T::dtype_name())
95+
};
96+
Err(Exception::TypeError(error_message))
7997
}
8098
}
8199

@@ -102,7 +120,7 @@ macro_rules! dtype_dispatch {
102120
let f64 = $f64;
103121
f64(x64)
104122
} else {
105-
Err(crate::errors::Exception::TypeError("Unsupported dtype".into()).into())
123+
Err(crate::np_array::unknown_type_exception(stringify!($first_arg), $first_arg.clone()))
106124
}
107125
}};
108126
($f32:expr, $f64:expr, $first_arg:expr $(,$eq:tt $arg:expr)+ $(,)?) => {{
@@ -116,7 +134,7 @@ macro_rules! dtype_dispatch {
116134
let f64 = $f64;
117135
f64(x64.clone(), $(crate::np_array::extract_matched_array(stringify!($arg), $arg, x_name, &x64, _distinguish_eq_symbol!($eq))?,)*)
118136
} else {
119-
Err(crate::errors::Exception::TypeError("Unsupported dtype".into()).into())
137+
Err(crate::np_array::unknown_type_exception(stringify!($first_arg), $first_arg.clone()))
120138
}
121139
}};
122140
}

light-curve/tests/light_curve_ext/test_feature.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,26 @@ def test_json_deserialization():
352352
from_json = lc.feature_from_json(json)
353353
assert isinstance(from_json, lc._FeatureEvaluator)
354354
from_json(*gen_lc(128))
355+
356+
357+
def test_raises_for_wrong_inputs():
358+
fe = lc.Amplitude()
359+
360+
# First argument
361+
with pytest.raises(TypeError, match="'t' has type 'list'"):
362+
fe([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
363+
with pytest.raises(TypeError, match="'t' is a 2-d array"):
364+
fe(np.array([[1.0, 2.0, 3.0]]), np.array([1.0, 2.0, 3.0]))
365+
with pytest.raises(TypeError, match="'t' has dtype int64"):
366+
fe(np.array([1, 2, 3], dtype=np.int64), np.array([[1.0, 2.0, 3.0]]))
367+
368+
# Second and third arguments
369+
t = np.arange(10, dtype=np.float64)
370+
with pytest.raises(ValueError, match="Mismatched lengths:"):
371+
fe(t, np.arange(11, dtype=np.float64))
372+
with pytest.raises(TypeError, match="'m' must be a numpy array"):
373+
fe(t, list(t))
374+
with pytest.raises(TypeError, match="'sigma' is a 0-d array"):
375+
fe(t, t, np.array(1.0))
376+
with pytest.raises(TypeError, match="Mismatched dtypes:"):
377+
fe(t, t.astype(np.float32))

0 commit comments

Comments
 (0)